Posts by Geo*_*ost

TensorFlow: logits and labels must have the same first dimension

I'm new to TensorFlow, and I want to adapt the MNIST tutorial https://www.tensorflow.org/tutorials/layers to my own data (40x40 images). Here is my model function:

def cnn_model_fn(features, labels, mode):
        # Input Layer
        input_layer = tf.reshape(features, [-1, 40, 40, 1])

        # Convolutional Layer #1
        conv1 = tf.layers.conv2d(
                inputs=input_layer,
                filters=32,
                kernel_size=[5, 5],
                # "same" makes the output tensor keep the input tensor's width and height
                # padding can be "same" or "valid"
                padding="same",
                activation=tf.nn.relu)

        # Pooling Layer #1
        pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

        # Convolutional Layer #2 and Pooling Layer #2
        conv2 …
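This error means the logits tensor and the labels tensor disagree on the batch dimension. One way that happens when adapting the tutorial: its model flattens pool2 with tf.reshape(pool2, [-1, 7 * 7 * 64]), a width that only fits 28x28 MNIST images. The sketch below is plain Python (no TensorFlow) and assumes the truncated part of the model follows the tutorial: a second conv layer with 64 filters and padding "same", each conv followed by a 2x2 max pool with stride 2. flattened_size and logits_rows are hypothetical helpers that only do the shape arithmetic.

```python
def flattened_size(image_side, conv_filters=(32, 64), pool_size=2, pool_stride=2):
    """Hypothetical helper: flattened length of one image after the conv/pool stack."""
    side = image_side
    for _ in conv_filters:
        # conv2d with padding="same" and stride 1 leaves the side unchanged.
        # max_pooling2d defaults to padding="valid":
        # out = floor((in - pool_size) / pool_stride) + 1
        side = (side - pool_size) // pool_stride + 1
    return side * side * conv_filters[-1]


def logits_rows(batch_size, image_side, flatten_width):
    """Hypothetical helper: rows left by reshape(pool2, [-1, flatten_width]).

    Returns None when the element count is not a multiple of flatten_width,
    the case where tf.reshape raises an error instead.
    """
    total = batch_size * flattened_size(image_side)
    if total % flatten_width:
        return None
    return total // flatten_width


print(flattened_size(28))         # 3136 = 7 * 7 * 64 (MNIST input)
print(flattened_size(40))         # 6400 = 10 * 10 * 64 (40x40 input)
print(logits_rows(49, 40, 3136))  # 100 rows of logits for 49 labels
print(logits_rows(49, 40, 6400))  # 49 rows, matching the labels
```

With 40x40 inputs the two poolings leave a 10x10x64 volume, so under these assumptions the flatten width would be 10 * 10 * 64. Keeping the MNIST width turns a batch of 49 images into 100 rows of logits against 49 labels, which is the kind of mismatch the error reports; for most other batch sizes tf.reshape fails outright.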

python keras tensorflow tensorflow-datasets tensorflow-estimator

18 votes · 4 answers · 50k views

TensorFlow error: Unsupported callable

I followed the tutorial https://www.tensorflow.org/tutorials/layers and I want to use it with my own dataset.

def train_input_fn_custom(filenames_array, labels_array, batch_size):
    # Reads an image from a file, decodes it into a dense tensor, and resizes it to a fixed shape.
    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_png(image_string, channels=1)
        image_resized = tf.image.resize_images(image_decoded, [40, 40])
        return image_resized, label

    filenames = tf.constant(filenames_array)
    labels = tf.constant(labels_array)

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(_parse_function)
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    return dataset.make_one_shot_iterator().get_next()


def main(self):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Get data
    filenames_train = ['blackcorner-data/1.png', 'blackcorner-data/2.png']
    labels_train = [0, 1]

    # Create the …
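A frequently reported cause of this error is passing the value returned by the input function, rather than the function itself, as input_fn to Estimator.train; the estimator needs a callable it can invoke inside its own graph. The message appears to come from Python's inspect module, which raises TypeError("unsupported callable") when asked for the argument list of something that is not a function. Whether the truncated main() above makes this mistake can't be seen. The sketch imitates the situation in plain Python: train is a hypothetical stand-in for Estimator.train and this train_input_fn_custom is a dummy that just slices lists; neither is TensorFlow code.

```python
import functools
import inspect


def train(input_fn):
    """Hypothetical stand-in for Estimator.train (not the real implementation)."""
    # Reading the argument list of a non-callable raises
    # TypeError('unsupported callable') from Python's inspect module.
    inspect.getfullargspec(input_fn)
    features, labels = input_fn()
    return features, labels


def train_input_fn_custom(filenames_array, labels_array, batch_size):
    """Dummy input function: returns (features, labels) as plain lists."""
    return filenames_array[:batch_size], labels_array[:batch_size]


filenames_train = ['blackcorner-data/1.png', 'blackcorner-data/2.png']
labels_train = [0, 1]

# Wrong: the input function is called here, so train() receives a tuple.
try:
    train(train_input_fn_custom(filenames_train, labels_train, 2))
except TypeError as err:
    print("error:", err)

# Right: pass something train() can call later.
print(train(lambda: train_input_fn_custom(filenames_train, labels_train, 2)))
print(train(functools.partial(train_input_fn_custom, filenames_train, labels_train, 2)))
```

In the real code the equivalent change would be input_fn=lambda: train_input_fn_custom(filenames_train, labels_train, batch_size), with whatever batch size main() uses; functools.partial is an alternative if you prefer to avoid a lambda.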

python python-3.x tensorflow tensorflow-datasets tensorflow-estimator

2 votes · 1 answer · 3289 views