小编hei*_*hei的帖子

如何在pytorch中处理多重损失?

在此输入图像描述

比如这个,我想用一些辅助损失来提升我的模特表现.
哪个类型代码可以在pytorch中实现它?

#one
loss1.backward()
loss2.backward()
loss3.backward()
optimizer.step()
#two
loss1.backward()
optimizer.step() 
loss2.backward()
optimizer.step() 
loss3.backward()
optimizer.step()   
#three
loss = loss1+loss2+loss3
loss.backward()
optimizer.step()
Run Code Online (Sandbox Code Playgroud)

感谢您的回答!

python pytorch

9
推荐指数
3
解决办法
3261
查看次数

图中的重复节点名称:'conv2d_0/kernel/Adam'

我刚刚通过该代码保存了一个模型:

def train():    
with tf.Session() as sess:
    saver = tf.train.Saver(max_to_keep = 2)
    Loss = myYoloLoss([Scale1,Scale2,Scale3],[Y1, Y2 ,Y3])
    opt = tf.train.AdamOptimizer(2e-4).minimize(Loss)
    init = tf.global_variables_initializer()
    sess.run(init)
    imageNum = 0
    Num = 0
    while(1):
        #get batchInput
        batchImg,batchScale1,batchScale2,batchScale3 = getBatchImage(batchSize = BATCHSIZE)
        for epoch in range(75):
            _ , epochloss = sess.run([opt,Loss],feed_dict={X:batchImg,Y1:batchScale1,Y2:batchScale2,Y3:batchScale3})
            if(epoch%15 == 0):
                print(epochloss)
        imageNum = imageNum + BATCHSIZE
        Num = Num + 1
        if(Num%4 == 0):
            saver.save(sess,MODELPATH + 'MyModle__' + str(imageNum))            
        if(os.path.exists(STOPFLAGPATH)):
            saver.save(sess,MODELPATH + 'MyModle__Stop_' + str(imageNum))   
            print('checked stopfile,stop')
            break
return 0 …
Run Code Online (Sandbox Code Playgroud)

python tensorflow pre-trained-model

6
推荐指数
1
解决办法
3603
查看次数

标签 统计

python ×2

pre-trained-model ×1

pytorch ×1

tensorflow ×1