Chr*_*ris 7 python deep-learning keras
我已经能够成功拟合 y=x**2 和 y=x**3,但对于 y=x**4、y=x**5 或 y=x**7 这样的方程,模型只会收敛到一条不准确的曲线,这是为什么?
我做错了什么?我能改进什么?
import numpy as np
from keras.layers import Dense, Activation
from keras.models import Sequential
import matplotlib.pyplot as plt
import math
import time
x = np.arange(-100, 100, 0.5)
y = x**4
# Dense layers take a 2-D (samples, features) input; make the single feature
# axis explicit rather than relying on Keras to expand a 1-D array.
x_in = x.reshape(-1, 1)

# NOTE(review): with x in [-100, 99.5] the targets reach ~9.8e7, so the MSE
# loss and its gradients are huge, and raw inputs of magnitude up to 100 are
# likely to saturate the sigmoid units. Rescaling x and y to roughly [-1, 1]
# before training (and undoing the y scaling on the predictions) is the
# usual remedy when higher powers fail to converge.
model = Sequential()
model.add(Dense(50, input_shape=(1,)))
model.add(Activation('sigmoid'))
model.add(Dense(50))
model.add(Activation('elu'))
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')

# time.clock() was removed in Python 3.8; perf_counter() measures elapsed time.
t1 = time.perf_counter()
for i in range(100):
    model.fit(x_in, y, epochs=1000, batch_size=len(x), verbose=0)
    # predict() returns shape (N, 1). Flatten it so that subtracting the (N,)
    # target compares element-wise instead of broadcasting to (N, N).
    predictions = model.predict(x_in).ravel()
    mse = np.mean(np.square(predictions - y))
    elapsed = time.perf_counter() - t1
    print(i, " ", mse, " t: ", elapsed)

    # plt.hold() was removed in Matplotlib 3.0; clearing the axes first gives
    # the same "redraw from scratch each iteration" effect as hold(False).
    plt.cla()
    plt.plot(x, y, 'b', x, predictions, 'r--')
    plt.ylabel('Y / Predicted Value')
    plt.xlabel('X Value')
    plt.title(f"{i}  Loss: {mse:.6g}  t: {elapsed:.1f}s")
    plt.pause(0.001)

# Uncomment to save the final figure:
# plt.savefig("fig2.png")
plt.show()
Run Code Online (Sandbox Code Playgroud)
小智 0
我认为问题在于数据的取值范围太大了(x 在 -100 到 100 之间,而 y=x**4 最高可达约 1e8)。添加 BatchNormalization 层可以改善性能。下面是加入 BatchNormalization 层后模型的结果。
这是代码:
import numpy as np
import keras
from keras.layers import Dense, Activation
from keras.models import Sequential
import matplotlib.pyplot as plt
import math
import time
x = np.arange(-100, 100, 0.5)
y = x**4
# Dense/BatchNormalization layers take a 2-D (samples, features) input; make
# the single feature axis explicit rather than relying on Keras to expand it.
x_in = x.reshape(-1, 1)

model = Sequential()
# BatchNormalization standardizes the raw x values (range about +/-100) before
# the first Dense layer. The public name is keras.layers.BatchNormalization;
# the keras.layers.normalization submodule path is not exposed by newer
# Keras releases.
# NOTE(review): only the inputs are normalized here; the targets still reach
# ~9.8e7, so scaling y as well is likely to help further.
model.add(keras.layers.BatchNormalization(input_shape=(1,)))
model.add(Dense(200))
model.add(Activation('relu'))
model.add(Dense(50))
model.add(Activation('elu'))
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')

# time.clock() was removed in Python 3.8; perf_counter() measures elapsed time.
t1 = time.perf_counter()
for i in range(100):
    model.fit(x_in, y, epochs=1000, batch_size=len(x), verbose=0)
    # predict() returns shape (N, 1). Flatten it so that subtracting the (N,)
    # target compares element-wise instead of broadcasting to (N, N).
    predictions = model.predict(x_in).ravel()
    mse = np.mean(np.square(predictions - y))
    elapsed = time.perf_counter() - t1
    print(i, " ", mse, " t: ", elapsed)

    # plt.hold() was removed in Matplotlib 3.0; clearing the axes first gives
    # the same "redraw from scratch each iteration" effect as hold(False).
    plt.cla()
    plt.plot(x, y, 'b', x, predictions, 'r--')
    plt.ylabel('Y / Predicted Value')
    plt.xlabel('X Value')
    plt.title(f"{i}  Loss: {mse:.6g}  t: {elapsed:.1f}s")
    plt.pause(0.001)
plt.show()
Run Code Online (Sandbox Code Playgroud)