我是Keras初学者的一员,所以我提前为任何普遍不好的理解道歉.
我想根据存储在另一个张量中的索引手动设置我的Keras张量的一些值.我相信我理解如何使用tf.gather_nd(我在下面未经测试的尝试)访问张量的条目,我想我明白我只能设置变量的值而不是张量.
为清楚起见,这发生在GAN的生成和识别阶段之间.
gen_out = generator(inputs)
indices_to_reset = Input(shape=(1,),dtype='int32')
new_values = Input(shape=(1,), dtype='int32')
batch_size = K.shape(x)[0]
idx_0 = K.reshape(K.arange(batch_size),(1,))
indices_to_reset = K.reshape(indices_to_reset, (1,))
idx = K.stack((idx_0, indices_to_reset), axis=0)
grabbed_entries = Lambda(lambda x: tf.gather_nd(gen_out,x))(idx)
# Doesn't work
# gen_out[:,indices_to_reset] = new_values
updated_gen_out = ???
Run Code Online (Sandbox Code Playgroud)
我现在没有机会尝试,但你不能使用tf.where:
updated_gen_out = tf.where(idx_mask, gen_out, new_values)
Run Code Online (Sandbox Code Playgroud)
不过,您需要idx_mask首先为索引创建一个布尔掩码,并可能重复您的 new_values 以具有与 相同的形状gen_out。
| 归档时间: |
|
| 查看次数: |
156 次 |
| 最近记录: |