相关疑难解决方法(0)

使用tf.estimator.Estimator框架转移学习

我正在尝试使用我自己的数据集和类进行在imagenet上预训练的Inception-resnet v2模型的传输学习.我的原始代码库是对tf.slim样本的修改,我找不到了,现在我正在尝试使用tf.estimator.*框架重写相同的代码.

但是,我正在运行从预训练检查点加载一些权重的问题,用其默认初始值设定项初始化剩余的层.

研究这个问题,我发现了这个GitHub问题这个问题,都提到了需要tf.train.init_from_checkpoint在我的问题中使用model_fn.我试过了,但由于两者都没有例子,我想我错了.

这是我的最小例子:

import sys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
import numpy as np

import inception_resnet_v2

NUM_CLASSES = 900
IMAGE_SIZE = 299

def input_fn(mode, num_classes, batch_size=1):
  # some code that loads images, reshapes them to 299x299x3 and batches them
  return tf.constant(np.zeros([batch_size, 299, 299, 3], np.float32)), tf.one_hot(tf.constant(np.zeros([batch_size], np.int32)), NUM_CLASSES)


def model_fn(images, labels, num_classes, mode):
  with tf.contrib.slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):
    logits, end_points = …
Run Code Online (Sandbox Code Playgroud)

python tensorflow tensorflow-estimator

17
推荐指数
1
解决办法
4694
查看次数

标签 统计

python ×1

tensorflow ×1

tensorflow-estimator ×1