Asked by Fel*_*ser · tags: memory-leaks, memory-management, python-3.x, tensorflow
I have a fairly simple TensorFlow-based function that takes a tensor of shape (1, 6, 64, 64, 64, 1) and returns a tensor of shape (1, 6, 3) containing the centre of mass of each (64, 64, 64) volume in the original tensor. It works without any problems, but every time my loop (see below) moves on to the next iteration, the RAM used on my machine increases. This limits me to about 500 samples before I run out completely. I assume I'm missing something somewhere, but I don't have enough experience to know where.

Code:
import tensorflow as tf
import pickle
import scipy.io
import scipy.ndimage
import sys
from os import listdir
from os.path import isfile, join
import numpy as np

def get_raw_centroids(lm_vol):
    # Find centres of mass for each landmark
    lm_vol *= tf.cast(tf.greater(lm_vol, 0.75), tf.float64)
    batch_size, lm_size, vol_size = lm_vol.shape[:3]
    xx, yy, zz = tf.meshgrid(tf.range(vol_size), tf.range(vol_size),
                             tf.range(vol_size), indexing='ij')
    coords = tf.stack([tf.reshape(xx, (-1,)), tf.reshape(yy, (-1,)),
                       tf.reshape(zz, (-1,))], axis=-1)
    coords = tf.cast(coords, tf.float64)
    volumes_flat = tf.reshape(lm_vol, [-1, int(lm_size),
                                       int(vol_size * vol_size * vol_size), 1])
    total_mass = tf.reduce_sum(volumes_flat, axis=2)
    raw_centroids = tf.reduce_sum(volumes_flat * coords, axis=2) / total_mass
    return raw_centroids
path = '/home/mosahle/Avg_vol_tf/'
lm_data_path = path + 'MAT_data_volumes/'

files = [f for f in listdir(lm_data_path) if isfile(join(lm_data_path, f))]
files.sort()

for i in range(10):
    sess = tf.Session()
    print("File {} of {}".format(i, len(files)))

    """
    Load file
    """
    dir = lm_data_path + files[i]
    lm_vol = scipy.io.loadmat(dir)['datavol']
    lm_vol = tf.convert_to_tensor(lm_vol, dtype=tf.float64)

    """
    Get similarity matrix
    """
    pts_raw = get_raw_centroids(lm_vol)
    print(sess.run(pts_raw))
    sess.close()

lm_vol is a (1, 6, 64, 64, 64, 1) array. These are just NumPy arrays converted to tensors.
I have also tried putting tf.Session() outside the loop, but it made no difference.
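As an aside, the centre-of-mass computation itself is easy to sanity-check in pure NumPy, independently of any TensorFlow graph issues. A minimal sketch (the function name `get_raw_centroids_np` is my own, mirroring the question's logic):

```python
import numpy as np

def get_raw_centroids_np(lm_vol, threshold=0.75):
    # lm_vol: (batch, landmarks, N, N, N, 1); zero out voxels below threshold
    lm_vol = lm_vol * (lm_vol > threshold)
    batch, lm, n = lm_vol.shape[:3]
    # voxel coordinates flattened to (N^3, 3), matching indexing='ij'
    xx, yy, zz = np.meshgrid(np.arange(n), np.arange(n), np.arange(n),
                             indexing='ij')
    coords = np.stack([xx.ravel(), yy.ravel(), zz.ravel()],
                      axis=-1).astype(np.float64)
    flat = lm_vol.reshape(batch, lm, n * n * n, 1)
    total_mass = flat.sum(axis=2)                    # (batch, lm, 1)
    return (flat * coords).sum(axis=2) / total_mass  # (batch, lm, 3)

# One bright voxel per landmark: the centroid is that voxel's index
vol = np.zeros((1, 2, 8, 8, 8, 1))
vol[0, 0, 3, 4, 5, 0] = 1.0
vol[0, 1, 1, 2, 6, 0] = 1.0
print(get_raw_centroids_np(vol))  # centroids (3, 4, 5) and (1, 2, 6)
```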
The problem with the code above is that you are adding new operations to the graph on every iteration of the loop when you call the function get_raw_centroids.

Let's consider a simpler example:
def get_raw_centroids(lm_vol):
    raw_centroids = lm_vol * 2
    return raw_centroids

for i in range(10):
    sess = tf.Session()
    lm_vol = tf.constant(3)
    pts_raw = get_raw_centroids(lm_vol)
    print(sess.run(pts_raw))

    print('****Graph: ***\n')
    print([x for x in tf.get_default_graph().get_operations()])
    sess.close()
The output of the above code is:
#6
#****Graph: ***
#[<tf.Operation 'Const' type=Const>,
# <tf.Operation 'mul/y' type=Const>,
# <tf.Operation 'mul' type=Mul>]
#6
#****Graph: ***
#[<tf.Operation 'Const' type=Const>,
# <tf.Operation 'mul/y' type=Const>,
# <tf.Operation 'mul' type=Mul>,
# <tf.Operation 'Const_1' type=Const>,
# <tf.Operation 'mul_1/y' type=Const>,
# <tf.Operation 'mul_1' type=Mul>]
#6
#****Graph: ***
#[<tf.Operation 'Const' type=Const>,
# <tf.Operation 'mul/y' type=Const>,
# <tf.Operation 'mul' type=Mul>,
# <tf.Operation 'Const_1' type=Const>,
# <tf.Operation 'mul_1/y' type=Const>,
# <tf.Operation 'mul_1' type=Mul>,
# <tf.Operation 'Const_2' type=Const>,
# <tf.Operation 'mul_2/y' type=Const>,
# <tf.Operation 'mul_2' type=Mul>]
#...
So every iteration adds a fresh copy of the operations to the default graph on top of the old ones, and that is what eats your memory.
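This growth is easy to measure directly. A small sketch that records the operation count per iteration (written against the `tf.compat.v1` API, which is my assumption for running TF1-style graph/session code under TF 2.x):

```python
import tensorflow.compat.v1 as tf  # compat.v1: TF1-style graphs/sessions on TF 2.x
tf.disable_v2_behavior()

def get_raw_centroids(lm_vol):
    return lm_vol * 2

op_counts = []
for i in range(3):
    sess = tf.Session()
    lm_vol = tf.constant(3)
    pts_raw = get_raw_centroids(lm_vol)
    sess.run(pts_raw)
    # each iteration leaves its ops behind in the default graph
    op_counts.append(len(tf.get_default_graph().get_operations()))
    sess.close()

print(op_counts)  # grows by 3 ops (Const, mul/y, mul) each time: [3, 6, 9]
```

Calling `tf.get_default_graph().finalize()` after building the graph is another way to catch this class of bug: any later attempt to add an op raises a RuntimeError instead of silently leaking.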
The correct way to handle the code above is the following:
# Create a placeholder for the input
lm_vol = tf.placeholder(dtype=tf.float32)
pts_raw = get_raw_centroids(lm_vol)

# Create the session once, outside the loop
sess = tf.Session()
for i in range(10):
    # numpy input
    lm_vol_np = 3

    # pass the input to the placeholder and get the output of the graph
    print(sess.run(pts_raw, {lm_vol: lm_vol_np}))

    print('****Graph: ***\n')
    print([x for x in tf.get_default_graph().get_operations()])
sess.close()
The output of this code will be:
#6.0
#****Graph: ***
#[<tf.Operation 'Placeholder' type=Placeholder>,
# <tf.Operation 'mul/y' type=Const>,
# <tf.Operation 'mul' type=Mul>]
#6.0
#****Graph: ***
#[<tf.Operation 'Placeholder' type=Placeholder>,
# <tf.Operation 'mul/y' type=Const>,
# <tf.Operation 'mul' type=Mul>]
#6.0
#****Graph: ***
#[<tf.Operation 'Placeholder' type=Placeholder>,
# <tf.Operation 'mul/y' type=Const>,
# <tf.Operation 'mul' type=Mul>]