我正在尝试使用我自己的数据集和类进行在imagenet上预训练的Inception-resnet v2模型的传输学习.我的原始代码库是对tf.slim样本的修改,我找不到了,现在我正在尝试使用tf.estimator.*框架重写相同的代码.
但是,我正在运行从预训练检查点加载一些权重的问题,用其默认初始值设定项初始化剩余的层.
研究这个问题,我发现了这个GitHub问题和这个问题,都提到了需要tf.train.init_from_checkpoint在我的问题中使用model_fn.我试过了,但由于两者都没有例子,我想我错了.
这是我的最小例子:
import sys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
import numpy as np
import inception_resnet_v2
NUM_CLASSES = 900
IMAGE_SIZE = 299
def input_fn(mode, num_classes, batch_size=1):
# some code that loads images, reshapes them to 299x299x3 and batches them
return tf.constant(np.zeros([batch_size, 299, 299, 3], np.float32)), tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES)
def model_fn(images, labels, num_classes, mode):
with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
logits, end_points = …Run Code Online (Sandbox Code Playgroud)