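I am training a CNN for image classification. Because my dataset is small, I am using transfer learning. Specifically, I am using the pretrained network that Google demonstrates in its retraining example (https://www.tensorflow.org/tutorials/image_retraining).
The model works well and gives very good accuracy. However, my dataset is highly imbalanced, so accuracy is not the best metric for judging the model's performance.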
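To make the imbalance problem concrete, here is a toy sketch in plain Python (not part of retrain.py). It assumes a hypothetical 95/5 class split and a degenerate "model" that always predicts the majority class. The helpers `accuracy` and `recall` are illustrative names, not library functions.

```python
def accuracy(labels, predictions):
    """Fraction of predictions that match the labels."""
    return sum(int(y == p) for y, p in zip(labels, predictions)) / len(labels)


def recall(labels, predictions, positive=1):
    """Fraction of actual positives that the classifier recovers."""
    hits = sum(1 for y, p in zip(labels, predictions) if y == positive and p == positive)
    total = sum(1 for y in labels if y == positive)
    return hits / total if total else 0.0


# Hypothetical imbalanced dataset: 95 negatives, 5 positives.
labels = [0] * 95 + [1] * 5
# Degenerate "model" that always predicts the majority class.
always_negative = [0] * len(labels)

print(accuracy(labels, always_negative))  # 0.95
print(recall(labels, always_negative))    # 0.0
```

The high accuracy comes entirely from the majority class, while recall shows that not a single minority-class example is found.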
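While researching solutions, I found that people suggest either changing the sampling method or changing the performance metric. I chose the latter.
TensorFlow provides useful metrics, including AUC, precision, and recall.
Here is the code for retraining the model: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py
I added the following to the add_evaluation_step(result_tensor, ground_truth_tensor) function: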
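with tf.name_scope('AUC'):
    with tf.name_scope('prediction'):
        prediction = tf.argmax(result_tensor, 1)
    with tf.name_scope('AUC'):
        # tf.metrics.auc returns a (value, update_op) pair; its running counts
        # (true_positives, false_positives, ...) are created as local variables.
        auc_value, update_op = tf.metrics.auc(
            tf.argmax(ground_truth_tensor, 1), prediction, curve='ROC')
# evaluation_step is the accuracy tensor computed earlier in add_evaluation_step.
tf.summary.scalar('accuracy', evaluation_step)
tf.summary.scalar('AUC', auc_value)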
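One detail in the snippet above: the `predictions` argument is `prediction = tf.argmax(result_tensor, 1)`, i.e. hard class labels, whereas AUC is normally computed from a continuous score. With 0/1 predictions the ROC curve has only one point between (0, 0) and (1, 1), so the area collapses to (TPR + TNR) / 2, the balanced accuracy. The plain-Python sketch below compares the two cases; the labels and probabilities are made up, and `roc_auc` and `balanced_accuracy` are illustrative helpers, not TensorFlow functions. Assuming a two-class problem where class index 1 is the positive class, passing the positive-class probability (for example `result_tensor[:, 1]`) instead of the argmax would be one way to keep the ranking information.

```python
def roc_auc(labels, scores):
    """Exact ROC AUC: fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # a tie counts as half
    return wins / (len(pos) * len(neg))


def balanced_accuracy(labels, hard_preds):
    """Mean of the true positive rate and the true negative rate."""
    tp = sum(1 for y, p in zip(labels, hard_preds) if y == 1 and p == 1)
    tn = sum(1 for y, p in zip(labels, hard_preds) if y == 0 and p == 0)
    return (tp / labels.count(1) + tn / labels.count(0)) / 2


labels = [0, 0, 0, 0, 1, 1]
probs = [0.1, 0.2, 0.3, 0.6, 0.7, 0.9]       # hypothetical P(class 1) per example
hard = [1 if p > 0.5 else 0 for p in probs]  # what argmax over two classes yields

print(roc_auc(labels, probs))             # 1.0
print(roc_auc(labels, hard))              # 0.875
print(balanced_accuracy(labels, hard))    # 0.875
```

The probabilities rank every positive above every negative, but thresholding them first discards that information. Note that `tf.metrics.auc` approximates the curve with a fixed number of thresholds (`num_thresholds`, 200 by default), so its value can differ slightly from this exact pairwise computation.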
But I get this error:
Traceback (most recent call last):
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 1135
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 911, in main
    ground_truth_input: train_ground_truth})
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value AUC/AUC/auc/false_positives
     [[Node: AUC/AUC/auc/false_positives/read = Identity[T=DT_FLOAT, _class=["loc:@AUC/AUC/auc/false_positives"], _device="/job:localhost/replica:0/task:0/cpu:0"](AUC/AUC/auc/false_positives)]]

Caused by op u'AUC/AUC/auc/false_positives/read', defined at:
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 1135
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 874, in main
    final_tensor, ground_truth_input)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 806, in add_evaluation_step
    auc_value, update_op = tf.metrics.auc(tf.argmax(ground_truth_tensor, 1), prediction, curve='ROC')
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 555, in auc
    labels, predictions, thresholds, weights)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 473, in _confusion_matrix_at_thresholds
    false_p = _create_local('false_positives', shape=[num_thresholds])
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 177, in _create_local
    validate_shape=validate_shape)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/variables.py", line 226, in __init__
    expected_shape=expected_shape)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/variables.py", line 344, in _init_from_args
    self._snapshot = array_ops.identity(self._variable, name="read")
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/gen_array_ops.py", line 1490, in identity
    result = _op_def_lib.apply_op("Identity", input=input, name=name)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
    op_def=op_def)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 2402, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 1264, in __init__
    self._traceback = _extract_stack()

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value AUC/AUC/auc/false_positives
     [[Node: AUC/AUC/auc/false_positives/read = Identity[T=DT_FLOAT, _class=["loc:@AUC/AUC/auc/false_positives"], _device="/job:localhost/replica:0/task:0/cpu:0"](AUC/AUC/auc/false_positives)]]
But I don't understand why, because in main I have this:
init = tf.global_variables_initializer()
sess.run(init)
Answer by Jie*_*hou (score: 23)
Try this:
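# tf.metrics.auc keeps its running counts in local variables, which
# tf.global_variables_initializer() does not initialize.
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init)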
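Why the extra initializer matters: the traceback shows the failing variable being created by `_create_local('false_positives', ...)`, so `tf.metrics.auc` stores its running counts as local variables, and `tf.global_variables_initializer()` only initializes global ones. The pure-Python analogue below is a hypothetical illustration, not TensorFlow code. `StreamingPrecision`, its `initialize`/`update`/`value` methods, and `UninitializedValueError` are made-up names that loosely mirror `tf.local_variables_initializer()`, the `update_op`, and the metric's value tensor.

```python
class UninitializedValueError(RuntimeError):
    """Stand-in for TensorFlow's FailedPreconditionError in this sketch."""


class StreamingPrecision:
    """Hypothetical analogue of a tf.metrics-style streaming metric.

    The counters play the role of the local variables a tf.metrics function
    creates: they stay uninitialized until initialize() is called, much as
    local variables stay uninitialized until tf.local_variables_initializer()
    has been run.
    """

    def __init__(self, threshold=0.5):
        self.threshold = threshold
        self.true_positives = None   # uninitialized "local variable"
        self.false_positives = None  # uninitialized "local variable"

    def initialize(self):
        """Counterpart of running tf.local_variables_initializer()."""
        self.true_positives = 0
        self.false_positives = 0

    def _check_initialized(self):
        if self.true_positives is None or self.false_positives is None:
            raise UninitializedValueError(
                "Attempting to use uninitialized value false_positives")

    def update(self, labels, scores):
        """Counterpart of the update_op: accumulate counts for one batch."""
        self._check_initialized()
        for y, s in zip(labels, scores):
            if s > self.threshold:
                if y == 1:
                    self.true_positives += 1
                else:
                    self.false_positives += 1

    def value(self):
        """Counterpart of the value tensor: read the accumulated counts."""
        self._check_initialized()
        predicted = self.true_positives + self.false_positives
        return self.true_positives / predicted if predicted else 0.0


metric = StreamingPrecision()
try:
    metric.update([1, 0], [0.9, 0.7])
except UninitializedValueError as err:
    print("error:", err)

metric.initialize()
metric.update([1, 0], [0.9, 0.7])  # batch 1: one true positive, one false positive
metric.update([1, 1], [0.8, 0.2])  # batch 2: one true positive, one missed positive
print(metric.value())              # 0.6666666666666666
```

The same split applies on the TensorFlow side: running `update_op` for each batch accumulates counts, and running `auc_value` reads the current result, so the update op has to be run before the AUC summary reflects any data.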