Aks*_*rma 3 machine-learning julia logistic-regression flux-machine-learning
我有一个数据集,其中包含 2 个科目的学生成绩以及学生是否被大学录取的结果。我需要对数据执行逻辑回归并找到最佳参数 \xce\xb8 以最小化损失并预测测试数据的结果。我不想在这里构建任何复杂的非线性网络。
\n\n\n\n我为逻辑回归定义了损失函数,这样效果很好
\n\npredict(X) = sigmoid(X*\xce\xb8)\nloss(X,y) = (1 / length(y)) * sum(-y .* log.(predict(X)) .- (1 - y) .* log.(1 - predict(X)))\nRun Code Online (Sandbox Code Playgroud)\n\n我需要最小化这个损失函数并找到最佳的\xce\xb8。我想使用 Flux.jl 或任何其他使它更容易的库来完成此操作。 \n在阅读示例后我尝试使用 Flux.jl 但无法最大程度地降低成本。
\n\n我的代码片段:
\n\nfunction update!(ps, \xce\xb7 = .1)\n for w in ps\n w.data .-= w.grad .* \xce\xb7\n print(w.data)\n w.grad .= 0\n end\nend\n\nfor i = 1:400\n back!(L)\n update!((\xce\xb8, b))\n @show L\nend\nRun Code Online (Sandbox Code Playgroud)\n
您可以使用 GLM.jl(更简单)或 Flux.jl(更复杂,但总体上更强大)。\n在代码中,我生成数据,以便您可以检查结果是否正确。另外,我有一个二进制响应变量 - 如果您有目标变量的其他编码,您可能需要稍微更改代码。
\n\n这是要运行的代码(您可以调整参数以提高收敛速度 - 我选择了安全的参数):
\n\nusing GLM, DataFrames, Flux.Tracker\n\nsrand(1)\nn = 10000\ndf = DataFrame(s1=rand(n), s2=rand(n))\ndf[:y] = rand(n) .< 1 ./ (1 .+ exp.(-(1 .+ 2 .* df[1] .+ 0.5 .* df[2])))\nmodel = glm(@formula(y~s1+s2), df, Binomial(), LogitLink())\n\nx = Matrix(df[1:2])\ny = df[3]\nW = param(rand(2,1))\nb = param(rand(1))\npredict(x) = 1.0 ./ (1.0+exp.(-x*W .- b))\nloss(x,y) = -sum(log.(predict(x[y,:]))) - sum(log.(1 - predict(x[.!y,:])))\n\nfunction update!(ps, \xce\xb7 = .0001)\n for w in ps\n w.data .-= w.grad .* \xce\xb7\n w.grad .= 0\n end\nend\n\ni = 1\nwhile true\n back!(loss(x,y))\n max(maximum(abs.(W.grad)), abs(b.grad[1])) > 0.001 || break\n update!((W, b))\n i += 1\nend\nRun Code Online (Sandbox Code Playgroud)\n\n结果如下:
\n\njulia> model # GLM result\nStatsModels.DataFrameRegressionModel{GLM.GeneralizedLinearModel{GLM.GlmResp{Array{Float64,1},Distributions.Binomial{Float64},GLM.LogitLink},GLM.DensePredChol{Float64,Base.LinAlg.Cholesky{Float64,Array{Float64,2}}}},Array{Float64,2}}\n\nFormula: y ~ 1 + s1 + s2\n\nCoefficients:\n Estimate Std.Error z value Pr(>|z|)\n(Intercept) 0.910347 0.0789283 11.5338 <1e-30\ns1 2.18707 0.123487 17.7109 <1e-69\ns2 0.556293 0.115052 4.83513 <1e-5\n\n\njulia> (b, W, i) # Flux result with number of iterations needed to converge\n(param([0.910362]), param([2.18705; 0.556278]), 1946)\nRun Code Online (Sandbox Code Playgroud)\n
| 归档时间: |
|
| 查看次数: |
1051 次 |
| 最近记录: |