mat*_*ick (score: 6) · tags: tensorflow, tensorflow2.0
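The example below shows a simple way to clear the cache manually. Is there a more standard or stable way to manage this cache going forward, or perhaps a pattern that avoids the problem in the first place?
In some cases the batch size varies widely and I run into memory problems, because def_fun never goes out of scope and its cache is apparently never cleared.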
In [164]: @tf.function
     ...: def f(x):
     ...:     return dict(something=x ** 2)
     ...:
In [165]: f._list_all_concrete_functions_for_serialization()
Out[165]: []
In [166]: _ = f(tf.convert_to_tensor(np.random.randn(109, 3).astype(np.float32)))
In [167]: _ = f(tf.convert_to_tensor(np.random.randn(111, 3).astype(np.float32)))
In [168]: f._list_all_concrete_functions_for_serialization()
Out[168]:
[<tensorflow.python.eager.function.ConcreteFunction at 0x7fac73e0d358>,
<tensorflow.python.eager.function.ConcreteFunction at 0x7fac71d41a58>]
In [169]: f._stateful_fn._function_cache._garbage_collectors
Out[169]:
[<tensorflow.python.eager.function._FunctionGarbageCollector at 0x7fac94252390>,
<tensorflow.python.eager.function._FunctionGarbageCollector at 0x7fac7b0c6048>,
<tensorflow.python.eager.function._FunctionGarbageCollector at 0x7fac7b0c6d68>]
In [170]: f._stateful_fn._function_cache._garbage_collectors[0]
Out[170]: <tensorflow.python.eager.function._FunctionGarbageCollector at 0x7fac94252390>
In [171]: f._stateful_fn._function_cache._garbage_collectors[0]._cache
Out[171]:
OrderedDict([(CacheKey(input_signature=('UTd1s109-3-u', None), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
<tensorflow.python.eager.function.ConcreteFunction at 0x7fac7371def0>),
(CacheKey(input_signature=('UTd1s111-3-u', None), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
<tensorflow.python.eager.function.ConcreteFunction at 0x7fac77a514a8>),
(CacheKey(input_signature=('URu', (TensorSpec(shape=(111, 3), dtype=tf.float32, name='x'),)), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
<tensorflow.python.eager.function.ConcreteFunction at 0x7fac73e0d358>),
(CacheKey(input_signature=('URu', (TensorSpec(shape=(109, 3), dtype=tf.float32, name='x'),)), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
<tensorflow.python.eager.function.ConcreteFunction at 0x7fac71d41a58>)])
In [172]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
Out[172]:
(CacheKey(input_signature=('URu', (TensorSpec(shape=(109, 3), dtype=tf.float32, name='x'),)), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
<tensorflow.python.eager.function.ConcreteFunction at 0x7fac71d41a58>)
In [173]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
Out[173]:
(CacheKey(input_signature=('URu', (TensorSpec(shape=(111, 3), dtype=tf.float32, name='x'),)), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
<tensorflow.python.eager.function.ConcreteFunction at 0x7fac73e0d358>)
In [174]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
Out[174]:
(CacheKey(input_signature=('UTd1s111-3-u', None), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
<tensorflow.python.eager.function.ConcreteFunction at 0x7fac77a514a8>)
In [175]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
Out[175]:
(CacheKey(input_signature=('UTd1s109-3-u', None), parent_graph=None, device_functions=(), colocation_stack=(), in_cross_replica_context=False),
<tensorflow.python.eager.function.ConcreteFunction at 0x7fac7371def0>)
In [176]: f._stateful_fn._function_cache._garbage_collectors[0]._cache.popitem()
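The growth described in the question, one new concrete function per distinct batch shape (visible in the `_list_all_concrete_functions_for_serialization()` output above), can also be observed through the public `experimental_get_tracing_count()` method that the answer below uses, in TF 2.x versions that provide it. A minimal sketch, assuming every batch has 3 float32 columns: declaring the batch dimension as `None` in an `input_signature` should let a single trace serve every batch size. `f_relaxed` is a hypothetical name used only for this illustration.

```python
import numpy as np
import tensorflow as tf


# Same function as in the question: traced once per distinct input shape.
@tf.function
def f(x):
    return dict(something=x ** 2)


# Hypothetical variant: the batch dimension is left unknown (None), so one
# trace should cover any batch size with 3 float32 columns.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def f_relaxed(x):
    return dict(something=x ** 2)


for batch in (109, 111, 200):
    x = tf.convert_to_tensor(np.random.randn(batch, 3).astype(np.float32))
    f(x)
    f_relaxed(x)

print(f.experimental_get_tracing_count())          # one trace per batch size
print(f_relaxed.experimental_get_tracing_count())  # a single shared trace
```

The trade-off is that the traced graph no longer knows the concrete batch size, and calls whose inputs do not match the declared signature (for example float64 or a different column count) are rejected instead of triggering a new trace.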
小智 (score: -1)
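Copy the tf.function object before calling it. The copy gets its own trace cache: in the output below, test1 traces once and reuses that trace on its second call, while test2 traces again on its first call. Dropping a copy should therefore drop the traces it has accumulated.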
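import copy

import tensorflow as tf


@tf.function
def test1(a):
    print('trace')  # Python side effect: runs only while tracing
    return a * a


# Copy before test1 is ever called, so the copy keeps a separate trace cache.
test2 = copy.copy(test1)

print(test1(1))
print(test1.experimental_get_tracing_count())
print(test1(1))  # same input: reuses the existing trace, no 'trace' printed
print(test1.experimental_get_tracing_count())
print(test2(1))  # the copy traces on its own first call
print(test2.experimental_get_tracing_count())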
Output:
trace
tf.Tensor(1, shape=(), dtype=int32)
1
tf.Tensor(1, shape=(), dtype=int32)
1
trace
tf.Tensor(1, shape=(), dtype=int32)
1
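One way to turn this answer into a reset mechanism is to keep a tf.function that is never called as a template and hand out copies of it, discarding a copy (together with the traces it accumulated) when it has grown too large. This is a sketch under assumptions: that `copy.copy` on a never-called tf.function yields an independent trace cache, as the output above suggests for the TF version the answer used; that `template` and `fresh_fn` are hypothetical names; and that memory is only reclaimed if nothing else still references the discarded copy or its concrete functions.

```python
import copy

import tensorflow as tf


@tf.function
def template(a):
    return a * a
# `template` itself is never called; per the answer, copies must be made
# before any call so that each copy starts with its own trace cache.


def fresh_fn():
    """Hypothetical helper: a copy of the uncalled template, with no traces yet."""
    return copy.copy(template)


fn = fresh_fn()
for n in range(5):
    fn(tf.zeros([n + 1]))  # each new shape adds a trace to this copy only
traced = fn.experimental_get_tracing_count()

fn = fresh_fn()  # the old copy is no longer referenced by this script
print(traced)
print(fn.experimental_get_tracing_count())
print(template.experimental_get_tracing_count())
```

Unlike popping entries from `_function_cache` as in the question, this relies only on `copy.copy` and `experimental_get_tracing_count()`, although the copy behavior itself is an observed result rather than a documented guarantee.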
Views: 1581