小编Val*_*ité的帖子

将 tf.data.Dataset 转换为 jax.numpy 迭代器

我对使用 JAX 训练神经网络感兴趣。我看过tf.data.Dataset,但它专门提供 tf 张量。我寻找一种将数据集更改为 JAX numpy 数组的方法,并且发现了很多用于Dataset.as_numpy_generator()将 tf 张量转换为 numpy 数组的实现。但是我想知道这是否是一个好的做法,因为 numpy 数组存储在 CPU 内存中,这不是我想要的训练(我使用 GPU)。所以我发现的最后一个想法是通过调用手动重新转换数组jnp.array,但这并不是很优雅(我担心 GPU 内存中的副本)。有人对此有更好的主意吗?

快速说明代码:

import os
import jax.numpy as jnp
import tensorflow as tf

def generator():
    for _ in range(2):
        yield tf.random.uniform((1, ))

ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
                                    output_shapes=tf.TensorShape([1]))

ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)

for i, batch in enumerate(ds1):
    print(type(batch))

for i, batch in enumerate(ds2):
    print(type(jnp.array(batch)))

# returns:

<class 'numpy.ndarray'> # not good
<class 'jaxlib.xla_extension.DeviceArray'> # good but …
Run Code Online (Sandbox Code Playgroud)

python tensorflow numpy-ndarray jax

4
推荐指数
1
解决办法
1494
查看次数

标签 统计

jax ×1

numpy-ndarray ×1

python ×1

tensorflow ×1