在训练用于图像分类的 keras 模型(来自 DOG BREED IDENTIFICATION 数据集的 120 个类,KAGGLE)时,我需要使用我在某处读到的类权重来平衡类,在示例中我看到人们使用 fit_generator 的参数 class_weight。但我在 model.compile 中发现了另一个参数,weighted_metrics,其在文档中的描述是:“在训练和测试期间由sample_weight或class_weight评估和加权的指标列表”。我要用这个吗?请用任何示例解释此参数的用途。
#Calculating Class weights
counter = Counter(train_generator.classes)
max_value = float(max(counter.values()))
CLASS_WEIGHTS = {classid: max_value / num_occurences
for classid, num_occurences in counter.items()}
# Model Compile
model.compile(optimizer=Adam(lr=LR),
loss=categorical_crossentropy,
metrics=[categorical_accuracy],
weighted_metrics=None) # <--------------- This parameter
STEPS_PER_EPOCH = train_generator.n//train_generator.batch_size
VAL_STEPS = val_generator.n//val_generator.batch_size
model.fit_generator(train_generator,
steps_per_epoch=STEPS_PER_EPOCH,
epochs=EPOCHS,
callbacks=callback_list,
verbose=1,
class_weight=CLASS_WEIGHTS,
validation_data=val_generator,
validation_steps=VAL_STEPS) # USED CLASS_WEIGHTS HERE
Run Code Online (Sandbox Code Playgroud)