Quels paramètres doivent être utilisés pour un arrêt précoce?

97

J'entraîne un réseau de neurones pour mon projet en utilisant Keras. Keras a fourni une fonction d'arrêt anticipé. Puis-je savoir quels paramètres doivent être observés pour éviter que mon réseau de neurones ne se suralimente en utilisant l'arrêt précoce?

AizuddinAzman
la source

Réponses:

157

arrêt précoce

L'arrêt précoce consiste essentiellement à arrêter l'entraînement une fois que votre perte commence à augmenter (ou en d'autres termes, la précision de la validation commence à diminuer). Selon les documents, il est utilisé comme suit;

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=0,
                              verbose=0, mode='auto')

Les valeurs dépendent de votre implémentation (problème, taille du lot etc ...) mais généralement pour éviter le surajustement que j'utiliserais;

  1. Surveillez la perte de validation (besoin d'utiliser la validation croisée ou au moins former / tester des ensembles) en définissant l' monitor argument sur 'val_loss'.
  2. min_deltaest un seuil pour déterminer si une perte à une certaine époque est une amélioration ou non. Si la différence de perte est inférieure min_delta, elle est quantifiée comme aucune amélioration. Mieux vaut laisser 0, car nous sommes intéressés par le moment où la perte s'aggrave.
  3. patienceL'argument représente le nombre d'époques avant de s'arrêter une fois que votre perte commence à augmenter (cesse de s'améliorer). Cela dépend de votre implémentation, si vous utilisez de très petits lots ou un taux d'apprentissage élevé votre perte zig-zag (la précision sera plus bruyante) donc mieux vaut définir un gros patienceargument. Si vous utilisez de gros lots et un faible taux d'apprentissage, votre perte sera plus douce et vous pourrez donc utiliser un patienceargument plus petit . Dans tous les cas, je laisserai la valeur 2 pour donner plus de chance au modèle.
  4. verbose décide quoi imprimer, laissez-le par défaut (0).
  5. modeL'argument dépend de la direction de votre quantité surveillée (est-elle censée diminuer ou augmenter), puisque nous surveillons la perte, nous pouvons utiliser min. Mais laissons keras gérer cela pour nous et définissons cela surauto

Je voudrais donc utiliser quelque chose comme ça et expérimenter en traçant la perte d'erreur avec et sans arrêt précoce.

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=2,
                              verbose=0, mode='auto')

Pour une éventuelle ambiguïté sur le fonctionnement des rappels, je vais essayer d'expliquer plus. Une fois que vous appelez fit(... callbacks=[es])votre modèle, Keras appelle des fonctions prédéterminées d'objets de rappel donnés. Ces fonctions peuvent être appelées on_train_begin, on_train_end, on_epoch_begin, on_epoch_endet on_batch_begin, on_batch_end. Le rappel d'arrêt précoce est appelé à chaque fin d'époque, compare la meilleure valeur surveillée à la valeur actuelle et s'arrête si les conditions sont remplies (combien d'époques se sont écoulées depuis l'observation de la meilleure valeur surveillée et est-ce plus qu'un argument de patience, la différence entre la dernière valeur est supérieure à min_delta etc.).

Comme indiqué par @BrentFaust dans les commentaires, l'entraînement du modèle se poursuivra jusqu'à ce que les conditions d'arrêt anticipé soient remplies ou que le epochsparamètre (par défaut = 10) fit()soit satisfait. La définition d'un rappel d'arrêt anticipé ne fera pas entraîner le modèle au-delà de son epochsparamètre. Ainsi, une fit()fonction d' appel avec une epochsvaleur plus élevée bénéficierait davantage du rappel de l'arrêt anticipé.

umutto
la source
3
@AizuddinAzman close, min_deltaest un seuil permettant de quantifier ou non le changement de valeur surveillée comme une amélioration. Alors oui, si nous donnons, monitor = 'val_loss'cela ferait référence à la différence entre la perte de validation actuelle et la perte de validation précédente. En pratique, si vous donnez min_delta=0.1une diminution de la perte de validation (actuelle - précédente) inférieure à 0,1 ne serait pas quantifiée, donc arrêterait la formation (si vous avez patience = 0).
umutto
3
Notez que cela callbacks=[EarlyStopping(patience=2)]n'a aucun effet, à moins que les époques ne soient données model.fit(..., epochs=max_epochs).
Brent Faust
1
@BrentFaust C'est aussi ce que je comprends, j'ai écrit la réponse en supposant que le modèle est entraîné avec au moins 10 époques (par défaut). Après votre commentaire, j'ai réalisé qu'il peut y avoir un cas où le programmeur appelle fit avec epoch=1dans une boucle for (pour divers cas d'utilisation) dans lequel ce rappel échouerait. S'il y a une ambiguïté dans ma réponse, j'essaierai de la formuler d'une meilleure façon.
umutto
4
@AdmiralWen Depuis que j'ai écrit la réponse, le code a un peu changé. Si vous utilisez la dernière version de Keras, vous pouvez utiliser l' restore_best_weightsargument (pas encore dans la documentation), qui charge le modèle avec les meilleurs poids après l'entraînement. Mais, pour vos besoins, j'utiliserais le ModelCheckpointrappel avec save_best_onlyargument. Vous pouvez consulter la documentation, elle est simple à utiliser mais vous devez charger manuellement les meilleurs poids après l'entraînement.
umutto
1
@umutto Bonjour merci pour la suggestion de restore_best_weights, mais je ne peux pas l'utiliser, `es = EarlyStopping (monitor = 'val_acc', min_delta = 1e-4, patience = patience_, verbose = 1, restore_best_weights = True) TypeError: __init __ () a obtenu un argument de mot-clé inattendu 'restore_best_weights'`. Des idées? keras 2.2.2, tf, 1.10 quelle est votre version?
Haramoz