Why can't I predict y = x**4 with Keras? (y = x**3 works)

Chr*_*ris 7 python deep-learning keras

I managed to predict y=x**2 and y=x**3, but equations like y=x**4, y=x**5 and y=x**7 only converge to an inaccurate line.

What am I doing wrong? What can I improve?

import numpy as np
from keras.layers import Dense, Activation
from keras.models import Sequential
import matplotlib.pyplot as plt
import math
import time

x = np.arange(-100, 100, 0.5)
y = x**4

model = Sequential()
model.add(Dense(50, input_shape=(1,)))
model.add(Activation('sigmoid'))
model.add(Dense(50))
model.add(Activation('elu'))
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')

t1 = time.perf_counter()  # time.clock() was removed in Python 3.8
for i in range(100):
    model.fit(x, y, epochs=1000, batch_size=len(x), verbose=0)
    # predict() returns shape (N, 1); flatten so it does not broadcast against y
    predictions = model.predict(x).ravel()
    mse = np.mean(np.square(predictions - y))
    elapsed = time.perf_counter() - t1
    print(i, " ", mse, " t: ", elapsed)

    plt.clf()  # plt.hold() was removed in Matplotlib 3.0; clear the figure instead
    plt.plot(x, y, 'b', x, predictions, 'r--')
    plt.ylabel('Y / Predicted Value')
    plt.xlabel('X Value')
    plt.title(f"{i} Loss: {mse} t: {elapsed}")
    plt.pause(0.001)

#plt.savefig("fig2.png")
plt.show()
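A side note on computing the MSE by hand, as the loop above does: `model.predict` returns an array of shape (N, 1), while `y` has shape (N,). Subtracting the two without flattening makes NumPy broadcast to an (N, N) matrix, so the printed loss is not the per-sample MSE. The sketch below uses a stand-in array in place of a real `model.predict(x)` call (that substitution is an assumption made so the snippet runs without Keras) to show the shape difference.

```python
import numpy as np

x = np.arange(-100, 100, 0.5)
y = x**4                          # shape (400,)

# Stand-in for model.predict(x): a Dense(1) output layer yields shape (400, 1).
# Here we fake a perfect prediction so the true MSE should be exactly 0.
predictions = y.reshape(-1, 1)

wrong = predictions - y           # broadcasts (400, 1) against (400,) -> (400, 400)
right = predictions.ravel() - y   # elementwise difference, shape (400,)

mse_wrong = np.mean(np.square(wrong))
mse_right = np.mean(np.square(right))

print(wrong.shape, right.shape)
print(mse_wrong, mse_right)
```

Flattening with `.ravel()` (or reshaping `y` to a column) keeps the subtraction elementwise.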

小智 0

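I think this is because the range of the input data is too large: x spans -100 to 100, so y = x**4 reaches 1e8. Adding a batchnorm layer improves performance. This is the result of the model with a batchnorm layer.

Here is the code: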

import numpy as np
import keras
from keras.layers import Dense, Activation
from keras.models import Sequential
import matplotlib.pyplot as plt
import math
import time


x = np.arange(-100, 100, 0.5)
y = x**4


model = Sequential()
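# keras.layers.BatchNormalization is the public path; the
# keras.layers.normalization submodule is not available in newer Keras releases
model.add(keras.layers.BatchNormalization(input_shape=(1,)))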
model.add(Dense(200))
model.add(Activation('relu'))
model.add(Dense(50))
model.add(Activation('elu'))
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')


t1 = time.perf_counter()  # time.clock() was removed in Python 3.8
for i in range(100):
    model.fit(x, y, epochs=1000, batch_size=len(x), verbose=0)
    # predict() returns shape (N, 1); flatten so it does not broadcast against y
    predictions = model.predict(x).ravel()
    mse = np.mean(np.square(predictions - y))
    elapsed = time.perf_counter() - t1
    print(i, " ", mse, " t: ", elapsed)

    plt.clf()  # plt.hold() was removed in Matplotlib 3.0; clear the figure instead
    plt.plot(x, y, 'b', x, predictions, 'r--')
    plt.ylabel('Y / Predicted Value')
    plt.xlabel('X Value')
    plt.title(f"{i} Loss: {mse} t: {elapsed}")
    plt.pause(0.001)
plt.show()
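If the large data range is indeed the cause, another way to address it (a different technique from the batchnorm layer used above) is to standardize both x and y by hand before training and undo the scaling on the predictions. The following is only a sketch under that assumption: `standardize`, `unstandardize` and `train_on_scaled_data` are hypothetical helper names, and the layer sizes and epoch count are arbitrary choices, not tuned values.

```python
import numpy as np


def standardize(a):
    """Return (scaled, mean, std), where scaled has zero mean and unit variance."""
    mean, std = a.mean(), a.std()
    return (a - mean) / std, mean, std


def unstandardize(scaled, mean, std):
    """Invert standardize(): map scaled values back to the original units."""
    return scaled * std + mean


x = np.arange(-100, 100, 0.5)
y = x**4

x_s, x_mean, x_std = standardize(x)
y_s, y_mean, y_std = standardize(y)


def train_on_scaled_data(epochs=2000):
    """Hypothetical helper: fit a small model on the scaled data and
    return its predictions converted back to the original y units."""
    # Imported lazily so the scaling helpers above work without Keras installed.
    from keras.layers import Dense
    from keras.models import Sequential

    model = Sequential()
    model.add(Dense(50, activation='relu', input_shape=(1,)))
    model.add(Dense(50, activation='relu'))
    model.add(Dense(1))
    model.compile(loss='mse', optimizer='adam')

    model.fit(x_s.reshape(-1, 1), y_s, epochs=epochs,
              batch_size=len(x_s), verbose=0)
    pred_s = model.predict(x_s.reshape(-1, 1)).ravel()
    return unstandardize(pred_s, y_mean, y_std)
```

Calling `train_on_scaled_data()` returns predictions in the original units, which can be plotted against `y` exactly as in the loop above. With this approach the loss the optimizer sees is measured in scaled units, so it is not directly comparable to the MSE values printed by the unscaled scripts.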