如何检查两个张量形状是否相同,包括无?

Sun*_*Kim 3 numpy tensorflow

我想检查两个张量是否具有相同的形状。

假设我有一些这样的张量:

a = tf.placeholder(tf.float32, shape=[None, 3])
b = tf.placeholder(tf.float32, shape=[None, 3])
Run Code Online (Sandbox Code Playgroud)

我补充说assert a.shape == b.shape。但是,这失败了,可能是由于 None。事实上a.shape= (?, 1),也b.shape就是(?, 1)。他们在我看来都一样。

如果没有 None,它工作正常。

a = tf.placeholder(tf.float32, shape=[1, 3])
b = tf.placeholder(tf.float32, shape=[1, 3])
assert a.shape == b.shape  # True
Run Code Online (Sandbox Code Playgroud)

如何在形状检查中忽略 None ?

总之:

1: a = tf.placeholder(tf.float32, shape=[1, 3])
2: b = tf.placeholder(tf.float32, shape=[1, 3])
3: assert a.shape == b.shape  # True
4: 
5: a = tf.placeholder(tf.float32, shape=[None, 3])
6: b = tf.placeholder(tf.float32, shape=[None, 3])
7: assert a.shape == b.shape  # False
Run Code Online (Sandbox Code Playgroud)

我想让第 7 行中的断言为真。

mrr*_*rry 8

您可以使用a.shape.as_list() == b.shape.as_list()来比较两个tf.TensorShape对象的“相等性”。但是,这样做时应该小心,因为如果两个形状包含None在同一位置,则不能保证具有这些形状的张量在该维度上具有相同的大小。

(能够batch_size在 a 中表示“符号”维度tf.TensorShape会很有用,这将使相等性测试更有用。我们正在研究 API 的扩展,以允许在未来版本的 TensorFlow 中实现这一点。)

  • 谢谢!如果我们有像 tf.TensorShape.assert_same() 这样的东西,那就太好了。 (2认同)
  • 另一种选择是使用 `tf.assert_equal(a.shape, b.shape)`,尽管它不适用于部分已知的形状。 (2认同)