小编Alm*_*hty的帖子

使用 tensorflowjs 进行联合学习

我正在使用 tensorflowjs 实现联邦学习。但我有点卡在联邦平均过程中。这个想法很简单:从多个客户端获取更新的权重并在服务器中对其进行平均。

我已经在浏览器上训练了一个模型,通过 model.getWeights() 方法获得了更新的权重,并将权重发送到服务器进行平均。


//get weights from multiple clients(happens i client-side)
w1 = model.getWeights(); //weights from client 1
w2 = model.getWeights(); //weights from client 2


//calculate average of the weights(server-side)
var mean_weights= [];
let length = w1.length; // length of all weights_array is same
for(var i=0; i<length; i++){
    let sum = w1[i].add(w2[i]);
    let mean = sum.divide(2); //got confused here, how to calculate mean of tensors ??
    mean_weights.push(mean);
}

//apply updates to the model(both client-side and server-side)
model.setWeights(mean_weights);
Run Code Online (Sandbox Code Playgroud)

所以我的问题是:如何计算张量数组的平均值?另外,这是通过 …

tensorflow tensorflow.js

6
推荐指数
1
解决办法
276
查看次数

标签 统计

tensorflow ×1

tensorflow.js ×1