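I am subclassing a Keras model with custom layers. Each layer wraps a dictionary of the parameters used to build it. These parameter dictionaries seem to get set only after TensorFlow creates the training checkpoint, not before, and this causes an error. I don't know how to fix it, because the ValueError that is raised gives outdated advice (tf.contrib no longer exists).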
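ValueError: Unable to save the object {'units': 32, 'activation': 'tanh', 'recurrent_initializer': 'glorot_uniform', 'dropout': 0, 'return_sequences': True} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary was modified outside the wrapper (its final value was {'units': 32, 'activation': 'tanh', 'recurrent_initializer': 'glorot_uniform', 'dropout': 0, 'return_sequences': True}, its value when a checkpoint dependency was added was None), which breaks restoration on object creation.
If you don't need this dictionary checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency object; it will be automatically un-wrapped and subsequently ignored.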
Here is an example of a layer that raises this error:
```python
from typing import Any, Dict, List

import numpy as np
from tensorflow.keras import layers

# BaseLayer is defined in base_layer.py below. ModeKeys and stack_layers come
# from elsewhere in the project and are not shown in the question.


class RecurrentConfig(BaseLayer):
    '''Basic configurable recurrent layer'''

    def __init__(self, params: Dict[Any, Any], mode: ModeKeys, layer_name: str = '', **kwargs):
        self.layer_name = layer_name
        # pop() removes these keys from the caller's dict, i.e. mutates it in place
        self.cell_name = params.pop('cell', 'GRU')
        self.num_layers = params.pop('num_layers', 1)
        kwargs['name'] = layer_name
        super().__init__(params, mode, **kwargs)
        if layer_name == '':
            self.layer_name = self.cell_name
        # Build num_layers stacked recurrent layers from the remaining params
        self.layers: List[layers.Layer] = stack_layers(self.params,
                                                       self.num_layers,
                                                       self.cell_name)

    def call(self, inputs: np.ndarray) -> layers.Layer:
        '''This function is a sequential/functional call to this layer's logic
        Args:
            inputs: Array to be processed within this layer
        Returns:
            inputs processed through this layer'''
        processed = inputs
        for layer in self.layers:
            processed = layer(processed)
        return processed

    @staticmethod
    def default_params() -> Dict[Any, Any]:
        return {
            'units': 32,
            'recurrent_initializer': 'glorot_uniform',
            'dropout': 0,
            'recurrent_dropout': 0,
            'activation': 'tanh',
            'return_sequences': True
        }
```
base_layer.py
```python
'''Basic ABC for a keras style layer'''
from typing import Dict, Any

from tensorflow.keras import layers

from mosaix_py.mosaix_learn.configurable import Configurable


class BaseLayer(Configurable, layers.Layer):
    '''Base configurable Keras layer'''

    def get_config(self) -> Dict[str, Any]:
        '''Return configuration dictionary as part of keras serialization'''
        config = super().get_config()
        config.update(self.params)
        return config

    @staticmethod
    def default_params() -> Dict[Any, Any]:
        raise NotImplementedError('Layer does not implement default params')
```
The issue I was facing was that I was popping items out of a dictionary that had been passed into the layer:
```python
self.cell_name = params.pop('cell', 'GRU')
self.num_layers = params.pop('num_layers', 1)
```
A dictionary passed into a layer is tracked, so it must remain unchanged after it has been handed over.
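Why does `params.pop(...)` count as modifying the dictionary? Python passes the dict by reference, so any other object holding the same dict sees the change. For example, a parent model that stored it as an attribute would find keys missing. The plain-Python sketch below uses hypothetical helper names and no TensorFlow. It contrasts the question's pop-in-place pattern with popping from a copy.

```python
from typing import Any, Dict, Tuple


def parse_with_pop(params: Dict[str, Any]) -> Tuple[str, int]:
    # Hypothetical helper mirroring the question: pop() mutates the caller's dict
    cell = params.pop('cell', 'GRU')
    num_layers = params.pop('num_layers', 1)
    return cell, num_layers


def parse_with_copy(params: Dict[str, Any]) -> Tuple[str, int, Dict[str, Any]]:
    # Hypothetical helper: copy first, then pop from the copy only
    remaining = dict(params)
    cell = remaining.pop('cell', 'GRU')
    num_layers = remaining.pop('num_layers', 1)
    return cell, num_layers, remaining


config = {'cell': 'LSTM', 'num_layers': 2, 'units': 32}
tracked = config  # second reference to the same object, e.g. held by a parent
parse_with_pop(config)
print(tracked)  # {'units': 32}: changed behind the holder's back

config = {'cell': 'LSTM', 'num_layers': 2, 'units': 32}
tracked = config
cell, num_layers, layer_params = parse_with_copy(config)
print(tracked)  # {'cell': 'LSTM', 'num_layers': 2, 'units': 32}: unchanged
```

A shallow `dict(params)` copy is enough here because only top-level keys are removed. Nested mutable values would still be shared.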
My solution was to move the parameter parsing one level further out and pass a finalized dictionary into the layer.
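Here is a minimal sketch of that idea. It assumes the structural keys are `cell` and `num_layers`, as in the question, and that the defaults match `default_params()`. Parsing happens once, outside the layer, and builds a brand-new dict, so the layer only ever receives a dictionary that nobody modifies afterwards. `parse_recurrent_params` and `ParsedRecurrentParams` are hypothetical names. The sketch is plain Python and makes no claims about Keras internals.

```python
from typing import Any, Dict, NamedTuple


class ParsedRecurrentParams(NamedTuple):
    # Hypothetical container for everything the layer needs up front
    cell_name: str
    num_layers: int
    layer_params: Dict[str, Any]  # finalized dict; not mutated after parsing


def parse_recurrent_params(params: Dict[str, Any],
                           defaults: Dict[str, Any]) -> ParsedRecurrentParams:
    # Build a fresh dict (defaults first, user overrides second) so neither
    # the caller's params nor the defaults are touched by the pops below
    merged = {**defaults, **params}
    cell_name = merged.pop('cell', 'GRU')
    num_layers = merged.pop('num_layers', 1)
    return ParsedRecurrentParams(cell_name, num_layers, merged)


# Same values as RecurrentConfig.default_params() in the question
DEFAULTS = {
    'units': 32,
    'recurrent_initializer': 'glorot_uniform',
    'dropout': 0,
    'recurrent_dropout': 0,
    'activation': 'tanh',
    'return_sequences': True,
}

user_params = {'cell': 'LSTM', 'num_layers': 2, 'units': 64}
parsed = parse_recurrent_params(user_params, DEFAULTS)
print(parsed.cell_name, parsed.num_layers, parsed.layer_params['units'])
```

Under this assumption, the layer's constructor would take `cell_name`, `num_layers` and `layer_params` as separate arguments and store `layer_params` as-is. Nothing is then popped from a dict that Keras is already tracking.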