Tensorflow - 如何为tf.Estimator()CNN使用GPU而不是CPU

Gen*_*337 8 python tensorflow tensorflow-estimator

我认为应该和它一起使用with tf.device("/gpu:0"),但我应该把它放在哪里?我认为不是:

with tf.device("/gpu:0"):
    tf.app.run()
Run Code Online (Sandbox Code Playgroud)

所以,我应该把它在main()功能tf.app,或我使用的估计模型的功能?

编辑:如果这有帮助,这是我的main()功能:

def main(unused_argv):
  """Code to load training folds data pickle or generate one if not present"""

  # Create the Estimator
  mnist_classifier = tf.estimator.Estimator(
      model_fn=cnn_model_fn2, model_dir="F:/python_machine_learning_codes/tmp/custom_age_adience_1")

  # Set up logging for predictions
  # Log the values in the "Softmax" tensor with label "probabilities"
  tensors_to_log = {"probabilities": "softmax_tensor"}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=100)

  # Train the model
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": train_data},
      y=train_labels,
      batch_size=64,
      num_epochs=None,
      shuffle=True)
  mnist_classifier.train(
      input_fn=train_input_fn,
      steps=500,
      hooks=[logging_hook])

  # Evaluate the model and print results
  """Code to load eval fold data pickle or generate one if not present"""

  eval_logs = {"probabilities": "softmax_tensor"}
  eval_hook = tf.train.LoggingTensorHook(
      tensors=eval_logs, every_n_iter=100)
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": eval_data},
      y=eval_labels,
      num_epochs=1,
      shuffle=False)
  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn, hooks=[eval_hook])
Run Code Online (Sandbox Code Playgroud)

正如你所看到的,我在这里的任何地方都没有明确的会话声明,所以我究竟在哪里放with tf.device("/gpu:0")

Win*_*erZ 0

对于估计器,没有任何类似的声明

sess = tf.Session(config = xxxxxxxxxxxxx)
Run Code Online (Sandbox Code Playgroud)

既没有一个声明

sess.run()
Run Code Online (Sandbox Code Playgroud)

所以......不幸的是张量流网络中没有完整的文档。我正在尝试使用 RunConfig 的不同选项

# Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
# trains faster than GPU for this model.
run_config = tf.estimator.RunConfig().replace(
        session_config=tf.ConfigProto(log_device_placement=True,
                                      device_count={'GPU': 0}))
Run Code Online (Sandbox Code Playgroud)

尝试处理这个...实际上我正在处理类似你的任务的事情,所以如果我得到一些进展,我会将其发布在这里。

请看这里: https: //github.com/tensorflow/models/blob/master/official/wide_deep/wide_deep.py 在此示例中,他们使用上面显示的代码和 .replace 语句来确保模型正在运行在CPU上。