Comment fonctionne le paramètre class_weight dans scikit-learn?

116

J'ai beaucoup de mal à comprendre comment fonctionne le class_weightparamètre de la régression logistique de scikit-learn.

La situation

Je souhaite utiliser la régression logistique pour effectuer une classification binaire sur un ensemble de données très déséquilibré. Les classes sont étiquetées 0 (négatif) et 1 (positif) et les données observées sont dans un rapport d'environ 19: 1 avec la majorité des échantillons ayant un résultat négatif.

Première tentative: préparation manuelle des données d'entraînement

J'ai divisé les données que j'avais en ensembles disjoints pour la formation et les tests (environ 80/20). Ensuite, j'ai échantillonné au hasard les données d'entraînement à la main pour obtenir des données d'entraînement dans des proportions différentes de 19: 1; de 2: 1 -> 16: 1.

J'ai ensuite formé la régression logistique sur ces différents sous-ensembles de données d'entraînement et tracé le rappel (= TP / (TP + FN)) en fonction des différentes proportions d'entraînement. Bien sûr, le rappel a été calculé sur les échantillons TEST disjoints qui avaient les proportions observées de 19: 1. Notez que bien que j'aie entraîné les différents modèles sur différentes données d'entraînement, j'ai calculé le rappel pour tous sur les mêmes données de test (disjointes).

Les résultats étaient comme prévu: le rappel était d'environ 60% à des proportions d'entraînement de 2: 1 et a chuté assez rapidement au moment où il est arrivé à 16: 1. Il y avait plusieurs proportions 2: 1 -> 6: 1 où le rappel était décemment supérieur à 5%.

Deuxième tentative: recherche de grille

Ensuite, je voulais tester différents paramètres de régularisation et j'ai donc utilisé GridSearchCV et fait une grille de plusieurs valeurs du Cparamètre ainsi que du class_weightparamètre. Pour traduire mes proportions n: m de négatifs: échantillons d'apprentissage positifs dans la langue du dictionnaire, class_weightj'ai pensé que je viens de spécifier plusieurs dictionnaires comme suit:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

et j'ai également inclus Noneet auto.

Cette fois, les résultats ont été totalement décevants. Tous mes rappels sont sortis minuscules (<0,05) pour chaque valeur de class_weightsauf auto. Je ne peux donc que supposer que ma compréhension de la configuration du class_weightdictionnaire est erronée. Il est intéressant de noter que la class_weightvaleur de «auto» dans la recherche de grille était d'environ 59% pour toutes les valeurs de C, et j'ai deviné que son équilibre était de 1: 1?

Mes questions

  1. Comment utilisez-vous correctement class_weightpour obtenir des équilibres différents dans les données d'entraînement par rapport à ce que vous leur donnez réellement? Plus précisément, à quel dictionnaire dois-je passer pour class_weightutiliser n: m proportions de négatifs: échantillons d'apprentissage positifs?

  2. Si vous transmettez divers class_weightdictionnaires à GridSearchCV, lors de la validation croisée, rééquilibrera-t-il les données de pli d'entraînement en fonction du dictionnaire, mais utilisera-t-il les vraies proportions d'échantillons données pour calculer ma fonction de notation sur le pli de test? Ceci est essentiel car toute métrique ne m'est utile que si elle provient de données dans les proportions observées.

  3. Que fait la autovaleur de en class_weightce qui concerne les proportions? J'ai lu la documentation et je suppose que "équilibre les données inversement proportionnelles à leur fréquence" signifie simplement qu'il est 1: 1. Est-ce correct? Sinon, quelqu'un peut-il clarifier?

kilgoretrout
la source
Quand on utilise class_weight, la fonction de perte est modifiée. Par exemple, au lieu de l'entropie croisée, elle devient une entropie croisée pondérée. versdatascience.com
prashanth

Réponses:

123

Tout d'abord, il n'est peut-être pas bon de se contenter de se souvenir uniquement. Vous pouvez simplement obtenir un rappel de 100% en classant tout comme la classe positive. Je suggère généralement d'utiliser AUC pour sélectionner les paramètres, puis de trouver un seuil pour le point de fonctionnement (par exemple un niveau de précision donné) qui vous intéresse.

Pour la façon dont cela class_weightfonctionne: Cela pénalise les erreurs dans les échantillons de class[i]avec class_weight[i]au lieu de 1. Donc, un poids de classe plus élevé signifie que vous voulez mettre davantage l'accent sur une classe. D'après ce que vous dites, il semble que la classe 0 est 19 fois plus fréquente que la classe 1. Vous devriez donc augmenter la class_weightclasse 1 par rapport à la classe 0, disons {0: .1, 1: .9}. Si la class_weightsomme ne vaut pas 1, cela changera fondamentalement le paramètre de régularisation.

Pour savoir comment class_weight="auto"fonctionne, vous pouvez jeter un œil à cette discussion . Dans la version de développement, vous pouvez utiliser class_weight="balanced", ce qui est plus facile à comprendre: cela signifie essentiellement répliquer la classe la plus petite jusqu'à ce que vous ayez autant d'échantillons que dans la plus grande, mais de manière implicite.

Andreas Mueller
la source
1
Merci! Question rapide: j'ai mentionné le rappel pour plus de clarté et en fait, j'essaie de décider quelle AUC utiliser comme mesure. Ma compréhension est que je devrais maximiser la zone sous la courbe ROC ou la zone sous la courbe de rappel par rapport à la courbe de précision pour trouver des paramètres. Après avoir choisi les paramètres de cette façon, je crois que je choisis le seuil de classification en glissant le long de la courbe. Est-ce que c'est ce que vous vouliez dire? Si oui, laquelle des deux courbes est la plus logique à regarder si mon objectif est de capturer autant de TP que possible? Merci également pour votre travail et vos contributions à scikit-learn !!!
kilgoretrout
1
Je pense que l'utilisation de ROC serait la méthode la plus standard, mais je ne pense pas qu'il y aura une énorme différence. Cependant, vous avez besoin d'un critère pour choisir le point sur la courbe.
Andreas Mueller
3
@MiNdFrEaK Je pense que ce qu'Andrew veut dire, c'est que l'estimateur réplique des échantillons dans la classe minoritaire, de sorte que l'échantillon de différentes classes soit équilibré. Il s'agit simplement de suréchantillonner de manière implicite.
Shawn TIAN
8
@MiNdFrEaK et Shawn Tian: les classificateurs basés sur SV ne produisent pas plus d'échantillons des classes plus petites lorsque vous utilisez «équilibré». Cela pénalise littéralement les erreurs commises sur les petites classes. Dire le contraire est une erreur et est trompeur, en particulier dans les grands ensembles de données lorsque vous ne pouvez pas vous permettre de créer plus d'échantillons. Cette réponse doit être modifiée.
Pablo Rivas
4
scikit-learn.org/dev/glossary.html#term-class-weight Les poids de classe seront utilisés différemment selon l'algorithme: pour les modèles linéaires (tels que SVM linéaire ou régression logistique), les poids de classe modifieront la fonction de perte en pondérer la perte de chaque échantillon par son poids de classe. Pour les algorithmes basés sur des arbres, les poids de classe seront utilisés pour repondérer le critère de division. A noter cependant que ce rééquilibrage ne prend pas en compte le poids des échantillons dans chaque classe.
prashanth
2

La première réponse est bonne pour comprendre comment cela fonctionne. Mais je voulais comprendre comment je devrais l'utiliser dans la pratique.

RÉSUMÉ

  • pour les données modérément déséquilibrées SANS bruit, il n'y a pas beaucoup de différence dans l'application des poids de classe
  • pour des données modérément déséquilibrées AVEC du bruit et fortement déséquilibrées, il est préférable d'appliquer des pondérations de classe
  • param class_weight="balanced"fonctionne correctement en l'absence de vous voulant optimiser manuellement
  • avec class_weight="balanced"vous capturez plus d'événements vrais (rappel TRUE plus élevé) mais vous êtes également plus susceptible de recevoir de fausses alertes (précision TRUE inférieure)
    • par conséquent, le total% TRUE peut être plus élevé que réel en raison de tous les faux positifs
    • L'AUC pourrait vous tromper ici si les fausses alarmes sont un problème
  • pas besoin de changer le seuil de décision au pourcentage de déséquilibre, même pour un déséquilibre fort, ok pour garder 0,5 (ou quelque part autour de cela en fonction de ce dont vous avez besoin)

NB

Le résultat peut différer lors de l'utilisation de RF ou GBM. sklearn n'a pas class_weight="balanced" pour GBM mais lightgbm aLGBMClassifier(is_unbalance=False)

CODE

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
citynorman
la source