What is wrong with my code? The error keeps increasing with every iteration of gradient descent

Kri*_*mar 2 python machine-learning linear-regression python-3.x gradient-descent

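The code below reads a CSV (the data file for the multivariate linear regression exercise, ex1, of Andrew NG's ML course) and then tries to fit a linear model to the dataset using a learning rate of alpha = 0.01. Gradient descent updates the parameters (the theta vector) 400 times (the alpha and num_of_iterations values are given in the problem statement). I tried a vectorized implementation to obtain the optimal parameter values, but the descent does not converge: the error keeps increasing.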

# Imports


```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
```

# Model Preparation

## Gradient descent


```python
def gradient_descent(m, theta, alpha, num_of_iterations, X, Y):
#     print(m, theta, alpha, num_of_iterations)
    for i in range(num_of_iterations):
        htheta_vector = np.dot(X,theta)
#         print(X.shape, theta.shape, htheta_vector.shape)
        error_vector = htheta_vector - Y
        gradient_vector = (1/m) * (np.dot(X.T, error_vector)) # each element in gradient_vector corresponds to each theta
        theta = theta - alpha * gradient_vector

    return theta
```
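For reference, the update that the loop above applies on each iteration can be written in matrix form as follows, where $X$ is the design matrix including the bias column, $y$ is the target vector, and $m$ is the number of training examples:

$$
\theta := \theta - \frac{\alpha}{m} X^{T} (X\theta - y)
$$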

# Main


```python
def main():
    df = pd.read_csv('data2.csv', header = None) #loading data
    data = df.values # converting dataframe to numpy array

    X = data[:, 0:2]
#     print(X.shape)
    Y = data[:, -1]

    m = (X.shape)[0] # number of training examples

    Y = Y.reshape(m, 1)

    ones = np.ones(shape = (m,1))
    X_with_bias = np.concatenate([ones, X], axis = 1)

    theta = np.zeros(shape = (3,1)) # two features, so three parameters

    alpha = 0.001
    num_of_iterations = 400

    theta = gradient_descent(m, theta, alpha, num_of_iterations, X_with_bias, Y) # calling gradient descent
#     print('Parameters learned: ' + str(theta))

if __name__ == '__main__':
    main()
```

Error:

    /home/krish-thorcode/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:8: RuntimeWarning: invalid value encountered in subtract

Error values at different iterations:

Iteration 1 [[-399900.] [-329900.] [-369000.] [-232000.] [-539900.] [-299900.] [-314900.] [-198999.] [-212000.] [-242500.] [-239999.] [-347000.] [-329999.] [-699900.] [-259900.] [-449900.] [-299900.] [-199900.] [-499998.] [-599000.] [-252900.] [-255000.] [-242900.] [-259900.] [-573900.] [-249900.] [-464500.] [-469000.] [-475000.] [-299900.] [-349900.] [-169900.] [-314900.] [-579900.] [-285900.] [-249900.] [-229900.] [-345000.] [-549000.] [-287000.] [-368500.] [-329900.] [-314000.] [-299000.] [-179900.] [-299900.] [-239500.]]

Iteration 2 [[1.60749981e+09] [1.22240841e+09] [1.83373661e+09] [1.08189071e+09] [2.29209231e+09] [1.51666004e+09] [1.17198560e+09] [1.09033113e+09] [1.05440030e+09] [1.14148964e+09] [1.48233053e+09] [1.52807496e+09] [1.44402895e+09] [3.42143452e+09] [9.68760976e+08] [1.75723592e+09] [1.00845873e+09] [9.44366284e+08] [1.99332644e+09] [2.31572369e+09] [1.35010833e+09] [1.44257442e+09] [1.22555224e+09] [1.49912323e+09] [2.97220331e+09] [8.40383843e+08] [1.11375611e+09] [1.92992696e+09] [1.68078878e+09] [2.01492327e+09] [1.40503327e+09] [7.64040689e+08] [1.55867654e+09] [2.39674784e+09] [1.38370165e+09] [1.09792232e+09] [9.46628911e+08] [1.62895368e+09] [3.22059730e+09] [1.65193796e+09] [1.27127807e+09] [1.70997383e+09] [1.96141565e+09] [9.16755655e+08] [6.50928858e+08] [1.41502023e+09] [9.19107783e+08]]

Iteration 3 [[-7.42664624e+12] [-5.64764378e+12] [-8.47145714e+12] [-4.99816153e+12] [-1.05893224e+13] [-7.00660901e+12] [-5.41467917e+12] [-5.03699402e+12] [-4.87109500e+12] [-5.27348843e+12] [-6.84776945e+12] [-7.05955046e+12] [-6.67127611e+12] [-1.58063228e+13] [-4.47576119e+12] [-8.11848565e+12] [-4.65930400e+12] [-4.36280860e+12] [-9.20918360e+12] [-1.06987452e+13] [-6.23711474e+12] [-6.66421140e+12] [-5.66176276e+12] [-6.92542434e+12] [-1.37308096e+13] [-3.88276038e+12] [-5.14641706e+12] [-8.91620784e+12] [-7.76550392e+12] [-9.30801176e+12] [-6.49125293e+12] [-3.52977344e+12] [-7.20074619e+12] [-1.10728954e+13] [-6.39242960e+12] [-5.07229174e+12] [-4.37339793e+12] [-7.52548475e+12] [-1.48779889e+13] [-7.63137769e+12] [-5.87354379e+12] [-7.89963490e+12] [-9.06093321e+12] [-4.23573710e+12] [-3.00737309e+12] [-6.53715005e+12] [-4.24632634e+12]]

Iteration 4 [[3.43099835e+16] [2.60912608e+16] [3.91368523e+16] [2.30907512e+16] [4.89210695e+16] [3.23694753e+16] [2.50149995e+16] [2.32701516e+16] [2.25037231e+16] [2.43627199e+16] [3.16356608e+16] [3.26140566e+16] [3.08202877e+16] [7.30228235e+16] [2.06773403e+16] [3.75061770e+16] [2.15252802e+16] [2.01555166e+16] [4.25450367e+16] [4.94265862e+16] [2.88145280e+16] [3.07876502e+16] [2.61564888e+16] [3.19944145e+16] [6.34342666e+16] [1.79377661e+16] [2.37756683e+16] [4.11915330e+16] [3.58754545e+16] [4.30016088e+16] [2.99886077e+16] [1.63070200e+16] [3.32663597e+16] [5.11551035e+16] [2.95320591e+16] [2.34332215e+16] [2.02044376e+16] [3.47666027e+16] [6.87340617e+16] [3.52558124e+16] [2.71348846e+16] [3.64951201e+16] [4.18601431e+16] [1.95684650e+16] [1.38936092e+16] [3.02006457e+16] [1.96173860e+16]]

Iteration 5 [[-1.58506940e+20] [-1.20537683e+20] [-1.80806345e+20] [-1.06675782e+20] [-2.26007951e+20] [-1.49542086e+20] [-1.15565519e+20] [-1.07504585e+20] [-1.03963801e+20] [-1.12552086e+20] [-1.46151974e+20] [-1.50672014e+20] [-1.42385073e+20] [-3.37354413e+20] [-9.55261885e+19] [-1.73272871e+20] [-9.94435428e+19] [-9.31154420e+19] [-1.96551642e+20] [-2.28343362e+20] [-1.33118767e+20] [-1.42234293e+20] [-1.20839027e+20] [-1.47809362e+20] [-2.93056729e+20] [-8.28697695e+19] [-1.09839996e+20] [-1.90298660e+20] [-1.65739180e+20] [-1.98660937e+20] [-1.38542837e+20] [-7.53359691e+19] [-1.53685556e+20] [-2.36328850e+20] [-1.36433652e+20] [-1.08257943e+20] [-9.33414495e+19] [-1.60616452e+20] [-3.17540981e+20] [-1.62876527e+20] [-1.25359067e+20] [-1.68601941e+20] [-1.93387537e+20] [-9.04033523e+19] [-6.41863754e+19] [-1.39522421e+20] [-9.06293597e+19]]

Iteration 83 [[-1.09904300e+306] [-8.35774743e+305] [-1.25366087e+306] [-7.39660179e+305] [-1.56707622e+306] [-1.03688320e+306] [-8.01299137e+305] [-7.45406868e+305] [-7.20856058e+305] [-7.80404831e+305] [-1.01337710e+306] [-1.04471781e+306] [-9.87258464e+305] [-2.33912159e+306] [-6.62352000e+305] [-1.20142586e+306] [-6.89513844e+305] [-6.45636555e+305] [-1.36283437e+306] [-1.58326931e+306] [-9.23008472e+305] [-9.86212994e+305] [-8.37864174e+305] [-1.02486897e+306] [-2.03197378e+306] [-5.74595914e+305] [-7.61599955e+305] [-1.31947793e+306] [-1.14918934e+306] [-1.37745963e+306] [-9.60617469e+305] [-5.22358639e+305] [-1.06561287e+306] [-1.63863846e+306] [-9.45992963e+305] [-7.50630445e+305] [-6.47203628e+305] [-1.11366977e+306] [-2.20174077e+306] [-1.12934050e+306] [-8.69204879e+305] [-1.16903893e+306] [-1.34089535e+306] [-6.26831680e+305] [-4.45050460e+305] [-9.67409627e+305] [-6.28398753e+305]]

迭代 84 [[inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf]
[inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] ] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf] [inf]]

小智 5

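Try feature normalization to fix this. The feature values are simply large, and with large values the cost function (squared error) grows very quickly: each gradient step overshoots the minimum, so the error gets larger on every iteration instead of smaller. As a general rule, apply mean normalization and feature scaling when you are minimizing a nonlinear cost function.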
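A minimal sketch of what that could look like, assuming each feature column is standardized to zero mean and unit standard deviation before the bias column is added. The `standardize` helper and the synthetic data are hypothetical stand-ins (the contents of `data2.csv` are not reproduced here), and this variant of `gradient_descent` also records the cost per iteration so convergence can be checked:

```python
import numpy as np

def standardize(X):
    """Mean-normalize each column and scale it to unit standard deviation."""
    mu = X.mean(axis=0)
    sigma = X.std(axis=0)
    return (X - mu) / sigma, mu, sigma

def gradient_descent(theta, alpha, num_of_iterations, X, Y):
    """Same vectorized update as in the question, plus a cost history."""
    m = X.shape[0]
    costs = []
    for _ in range(num_of_iterations):
        error_vector = X @ theta - Y
        costs.append(float((error_vector ** 2).sum() / (2 * m)))
        theta = theta - alpha * (X.T @ error_vector) / m
    return theta, costs

# Hypothetical data on roughly the same scale as the exercise:
# one feature in the thousands, one small integer feature, targets ~1e5.
rng = np.random.default_rng(0)
m = 47
size = rng.uniform(800, 4500, size=(m, 1))
rooms = rng.integers(1, 6, size=(m, 1)).astype(float)
X = np.hstack([size, rooms])
Y = 100.0 * size + 5000.0 * rooms + 50000.0 + rng.normal(0, 10000, size=(m, 1))

X_norm, mu, sigma = standardize(X)                   # scale features first
X_with_bias = np.hstack([np.ones((m, 1)), X_norm])   # then add the bias column

theta, costs = gradient_descent(np.zeros((3, 1)), 0.01, 400, X_with_bias, Y)
print(costs[0], costs[-1])  # the cost should now shrink instead of blowing up
```

Keep `mu` and `sigma`: any new input has to be scaled with the same values before it is multiplied by the learned `theta`.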