Quelle fonction de perte utiliser pour les classes déséquilibrées (en utilisant PyTorch)?

17

J'ai un ensemble de données avec 3 classes avec les éléments suivants:

  • Classe 1: 900 éléments
  • Classe 2: 15 000 éléments
  • Classe 3: 800 éléments

Je dois prédire les classes 1 et 3, qui signalent des écarts importants par rapport à la norme. La classe 2 est le cas «normal» par défaut dont je me fiche.

Quel type de fonction de perte devrais-je utiliser ici? Je pensais utiliser CrossEntropyLoss, mais comme il y a un déséquilibre de classe, cela devrait être pondéré, je suppose? Comment cela fonctionne-t-il dans la pratique? Comme ça (en utilisant PyTorch)?

summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)

Ou faut-il inverser le poids? soit 1 / poids?

Est-ce la bonne approche pour commencer ou y a-t-il d'autres / meilleures méthodes que je pourrais utiliser?

Merci

Muppet
la source

Réponses:

13

Quel type de fonction de perte devrais-je utiliser ici?

L'entropie croisée est la fonction de perte de référence pour les tâches de classification, qu'elles soient équilibrées ou déséquilibrées. C'est le premier choix lorsqu'aucune préférence n'est encore établie à partir de la connaissance du domaine.

Cela devrait être pondéré, je suppose? Comment cela fonctionne-t-il dans la pratique?

Oui. Poids de la classec est la taille de la plus grande classe divisée par la taille de la classe c.

Par exemple, si la classe 1 a 900, la classe 2 a 15000 et la classe 3 a 800 échantillons, alors leurs poids seraient respectivement 16,67, 1,0 et 18,75.

Vous pouvez également utiliser la plus petite classe comme nominateur, ce qui donne 0,889, 0,053 et 1,0 respectivement. Il ne s'agit que d'une nouvelle mise à l'échelle, les poids relatifs sont les mêmes.

Est-ce la bonne approche pour commencer ou y a-t-il d'autres / meilleures méthodes que je pourrais utiliser?

Oui, c'est la bonne approche.

MODIFIER :

Grâce à @Muppet, nous pouvons également utiliser le suréchantillonnage de classe, ce qui équivaut à utiliser des poids de classe . Ceci est accompli par WeightedRandomSamplerdans PyTorch, en utilisant les mêmes poids mentionnés ci-dessus.

Esmailian
la source
2
Je voulais juste ajouter que l'utilisation de WeightedRandomSampler de PyTorch a également aidé, au cas où quelqu'un d'autre regarderait cela.
Muppet
0

Lorsque vous dites: vous pouvez également utiliser la plus petite classe comme nominateur, ce qui donne respectivement 0,889, 0,053 et 1,0. Il ne s'agit que d'une nouvelle mise à l'échelle, les poids relatifs sont les mêmes.

Mais cette solution est en contradiction avec la première que vous avez donnée, comment ça marche?

Georges Matar
la source