小编MR_*_*BGC的帖子

在 Flux (Julia) 的损失函数中使用分位数

我正在尝试在损失函数中使用分位数进行训练!(对于某些稳健性,例如最少修剪的平方),但它会改变数组并且 Zygote 会抛出一个错误Mutating arrays is not supported,来自sort!. 下面是一个简单的例子(内容当然没有意义):

using Flux, StatsBase
xdata = randn(2, 100)   
ydata = randn(100)

model = Chain(Dense(2,10), Dense(10, 1))


function trimmedLoss(x,y; trimFrac=0.f05)
        yhat = model(x)
        absRes = abs.(yhat .- y) |> vec
        trimVal = quantile(absRes, 1.f0-trimFrac) 
        s = sum(ifelse.(absRes .> trimVal,  0.f0 , absRes ))/(length(absRes)*(1.f0-trimFrac))
        #s = sum(absRes)/length(absRes)   # using this and commenting out the two above works (no surprise)    
end

println(trimmedLoss(xdata, ydata)) #works ok

Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())

println(trimmedLoss(xdata, ydata)) …
Run Code Online (Sandbox Code Playgroud)

machine-learning julia flux.jl

7
推荐指数
1
解决办法
412
查看次数

标签 统计

flux.jl ×1

julia ×1

machine-learning ×1