estimator的简单使用方式
estimator的官方使用方式介绍了使用自定义的estimator的model,没有涉及到从keras的model来使用estimator。
主要的使用方式来自这篇notebook在使用的时候没有遇上太多障碍。
但有一些细节花了一点时间去调试。
比如estimator能按照dataset重复次数dataset.repeat(n)
作为epoch,因此如果直接使用dataset.repeat()
会在训练时陷入死循环。
model_fn的处理
1 | def model_fn(features, labels, mode): |
通过传递参数是无法打印更多的训练结果,但是可以通过创建一个logging hook来让estimator运行。
In the body of model_fn function for your estimator:
1 | logging_hook = tf.train.LoggingTensorHook({"loss" : loss, |
除了self.estimator.train()
以外,可以使用tf.estimator.train_and_evaluate()
对train
和evaluate
进行更精细地操作。
此外add_metrics(estimator,my_auc)
只是把metrics加入到最终结果的输出里,而不是每一次step,对于每一次step需要在EstimatorSpec(training_hook=[logging_hook])
里添加logging_hook
多gpu出现的
All hooks must be SessionRunHook instances问题在#issues21444 里解决,等待tf-1.11版本。