Early stopping in PyTorch

Tot*_*oro · 29 · Tags: python, neural-network, deep-learning, pytorch, early-stopping

I tried to implement an early stopping function to avoid my neural network model overfitting. I'm pretty sure the logic is fine, but for some reason it doesn't work. I want the early stopping function to return True when the validation loss has been greater than the training loss for some number of epochs. But it always returns False, even when the validation loss becomes much greater than the training loss. Can you see where I am going wrong?

The early stopping function:

def early_stopping(train_loss, validation_loss, min_delta, tolerance):

    counter = 0
    if (validation_loss - train_loss) > min_delta:
        counter +=1
        if counter >= tolerance:
          return True

Calling the function during training:

for i in range(epochs):
    
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)

    # validation 

    with torch.no_grad(): 
       epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
       validation_loss.append(epoch_validate_loss)
    
    # early stopping
    if early_stopping(epoch_train_loss, epoch_validate_loss, min_delta=10, tolerance = 20):
      print("We are at epoch:", i)
      break

Edit: training and validation losses: (two loss-curve images)

Edit 2:

def train_validate (model, train_dataloader, validate_dataloader, loss_func, optimiser, device, epochs):
    preds = []
    train_loss =  []
    validation_loss = []
    min_delta = 5
    

    for e in range(epochs):
        
        print(f"Epoch {e+1}")
        epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
        train_loss.append(epoch_train_loss)

        # validation 
        with torch.no_grad(): 
           epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
           validation_loss.append(epoch_validate_loss)
        
        # early stopping
        early_stopping = EarlyStopping(tolerance=2, min_delta=5)
        early_stopping(epoch_train_loss, epoch_validate_loss)
        if early_stopping.early_stop:
            print("We are at epoch:", e)
            break

    return train_loss, validation_loss

isl*_*ods 62

Although @KarelZe's answer solves your problem sufficiently and elegantly, I want to offer an alternative early stopping criterion that is arguably better.

Your early stopping criterion is based on how much (and for how long) the validation loss diverges from the training loss. This breaks down when the validation loss is indeed decreasing but is generally not close enough to the training loss. The goal of training a model is to encourage the reduction of validation loss, not to reduce the gap between training loss and validation loss.

Hence, I would argue a better early stopping criterion is to watch the trend of the validation loss alone, i.e., if training does not lead to a reduction of the validation loss, terminate it. Here is an example implementation:

class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

Here is how you would use it:

early_stopper = EarlyStopper(patience=3, min_delta=10)
for epoch in np.arange(n_epochs):
    train_loss = train_one_epoch(model, train_loader)
    validation_loss = validate_one_epoch(model, validation_loader)
    if early_stopper.early_stop(validation_loss):             
        break
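As a quick sanity check, here is a minimal sketch (reusing the EarlyStopper class above with a synthetic, made-up sequence of validation losses) showing the stop trigger once the loss has failed to improve for patience consecutive epochs:

```python
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience          # epochs to wait after the last improvement
        self.min_delta = min_delta        # worsening below this does not count against patience
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

# Synthetic validation losses: improving for three epochs, then worsening.
losses = [1.0, 0.8, 0.7, 0.75, 0.9, 1.1]
stopper = EarlyStopper(patience=2, min_delta=0)
stopped_at = None
for epoch, vl in enumerate(losses):
    if stopper.early_stop(vl):
        stopped_at = epoch
        break
print(stopped_at)  # 4: the second consecutive epoch without improving on 0.7
```

Note that the best loss seen so far (0.7 here) is remembered, so a brief plateau only starts the counter; it does not reset the baseline.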

  • Thank you so much for your answer. This is a fresh idea, really wonderful. You are so kind! (2 upvotes)

Kar*_*lZe 18

The problem with your implementation is that the counter is re-initialized to 0 every time early_stopping() is called.
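To see this concretely, here is a small sketch (using the original function body with illustrative loss values): because counter is a local variable, it cannot accumulate across epochs, so for any tolerance greater than 1 the function can never return True:

```python
def early_stopping(train_loss, validation_loss, min_delta, tolerance):
    counter = 0  # local variable: re-created (and reset to 0) on every call
    if (validation_loss - train_loss) > min_delta:
        counter += 1              # counter is now at most 1...
        if counter >= tolerance:  # ...so this is False for any tolerance > 1
            return True

# Even with the validation loss far above the training loss for 50 "epochs",
# the function never returns True (it falls through and returns None):
results = [early_stopping(100.0, 500.0, min_delta=10, tolerance=20) for _ in range(50)]
print(any(r is True for r in results))  # False
```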

Here is a working solution using an object-oriented approach, with __init__() holding the state and __call__() updating it:

class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0):

        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        if (validation_loss - train_loss) > self.min_delta:
            self.counter +=1
            if self.counter >= self.tolerance:  
                self.early_stop = True

Call it like this:

early_stopping = EarlyStopping(tolerance=5, min_delta=10)

for i in range(epochs):
    
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)

    # validation 
    with torch.no_grad(): 
       epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
       validation_loss.append(epoch_validate_loss)
    
    # early stopping
    early_stopping(epoch_train_loss, epoch_validate_loss)
    if early_stopping.early_stop:
      print("We are at epoch:", i)
      break

Example:

early_stopping = EarlyStopping(tolerance=2, min_delta=5)

train_loss = [
    642.14990234,
    601.29278564,
    561.98400879,
    530.01501465,
    497.1098938,
    466.92709351,
    438.2364502,
    413.76028442,
    391.5090332,
    370.79074097,
]
validate_loss = [
    509.13619995,
    497.3125,
    506.17315674,
    497.68960571,
    505.69918823,
    459.78610229,
    480.25592041,
    418.08630371,
    446.42675781,
    372.09902954,
]

for i in range(len(train_loss)):

    early_stopping(train_loss[i], validate_loss[i])
    print(f"loss: {train_loss[i]} : {validate_loss[i]}")
    if early_stopping.early_stop:
        print("We are at epoch:", i)
        break


Output:

loss: 642.14990234 : 509.13619995
loss: 601.29278564 : 497.3125
loss: 561.98400879 : 506.17315674
loss: 530.01501465 : 497.68960571
loss: 497.1098938 : 505.69918823
loss: 466.92709351 : 459.78610229
loss: 438.2364502 : 480.25592041
We are at epoch: 6