小编nim*_*ics的帖子

AttributeError:模块“tensorflow.keras.mixed_ precision”没有属性“set_global_policy”

我需要在代码中添加混合精度以节省一些内存。具体来说,我尝试在https://github.com/nimRobotics/google-research/blob/master/ravens/train.py的第 27 行附近添加混合精度策略,下面是代码摘录

import argparse
import datetime
import os

import numpy as np
from ravens import agents
from ravens import Dataset
import tensorflow as tf

# tf.keras.mixed_precision.set_global_policy('mixed_float16')

# OR

# policy = tf.keras.mixed_precision.Policy('mixed_float16')
# mixed_precision.set_global_policy(policy)
Run Code Online (Sandbox Code Playgroud)

这两种方法都会导致属性错误,如下所示,我使用 Google Colab 和 TF 2.3.0

使用tf.keras.mixed_precision.set_global_policy('mixed_float16')结果于

Traceback (most recent call last):
  File "train.py", line 28, in <module>
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
AttributeError: module 'tensorflow.keras.mixed_precision' has no attribute 'set_global_policy'

Run Code Online (Sandbox Code Playgroud)

使用

policy = tf.keras.mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)
Run Code Online (Sandbox Code Playgroud)

结果是

Traceback (most recent call last):
  File "train.py", …
Run Code Online (Sandbox Code Playgroud)

python python-3.x keras tensorflow google-colaboratory

5
推荐指数
1
解决办法
7012
查看次数