创建一个 3D 零张量,在 numpy/jax 中的每个切片上随机放置一个“1”

Atu*_*yak 5 python numpy tensor jax

例如,我需要创建一个像这样的 3D 张量 (5,3,2)

array([[[0, 0],
        [0, 1],
        [0, 0]],

       [[1, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [1, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [1, 0]],

       [[0, 0],
        [0, 1],
        [0, 0]]])
Run Code Online (Sandbox Code Playgroud)

每个切片中都应该随机放置一个“1”(如果您将张量视为一条面包)。这可以使用循环来完成,但我想向量化这部分。

Qua*_*ang 4

尝试生成一个随机数组,然后找到max

a = np.random.rand(5,3,2)
out = (a == a.max(axis=(1,2))[:,None,None]).astype(int)
Run Code Online (Sandbox Code Playgroud)