Perte KL avec une unité gaussienne

10

J'ai implémenté une VAE et j'ai remarqué deux implémentations différentes en ligne de la divergence gaussienne KL univariée simplifiée. La divergence d' origine que par ici est Si nous supposons que notre a priori est une unité gaussienne, c'est-à-dire et , cela se simplifie jusqu'à Et voici où repose ma confusion. Bien que j'ai trouvé quelques dépôts github obscurs avec l'implémentation ci-dessus, ce que je trouve le plus couramment utilisé est: μ2=0σ2=1KLloss=-log(σ1)+σ 2 1 +μ 2 1

KLloss=log(σ2σ1)+σ12+(μ1μ2)22σ2212
μ2=0σ2=1 KLloss=-1
KLloss=log(σ1)+σ12+μ12212
KLloss=12(2log(σ1)σ12μ12+1)

=12(log(σ1)σ1μ12+1)
Par exemple dans le tutoriel officiel de l' auto-encodeur Keras . Ma question est alors, qu'est-ce que je manque entre ces deux? La principale différence est de laisser tomber le facteur 2 sur le terme logarithmique et de ne pas mettre la variance au carré. Analytiquement, j'ai utilisé cette dernière avec succès, pour ce qu'elle vaut. Merci d'avance pour votre aide!
groovyDragon
la source

Réponses:

7

Notez qu'en remplaçant par dans la dernière équation, vous récupérez la précédente (c.-à-d. ). Cela m'amène à penser que dans le premier cas, l'encodeur est utilisé pour prédire la variance, tandis que dans le second, il est utilisé pour prédire l'écart type.σ 2 1 log ( σ 1 ) - σ 12 log ( σ 1 ) - σ 2 1σ1σ12log(σ1)σ12log(σ1)σ12

Les deux formulations sont équivalentes et l'objectif est inchangé.

F. Evlangeli
la source
Je ne pense pas qu'il puisse être le cas que ceux-ci sont équivalents. Oui, ils sont tous deux minimisés quand pour zéro et unit . Cependant, dans l'équation originale (présentant la variance), la pénalité pour éloigner de l'unité est beaucoup plus grande que dans la deuxième équation (basée sur l'écart type). La pénalité pour les variations de est la même pour les deux, et l'erreur de reconstruction serait la même, donc l'utilisation de la deuxième version change considérablement l'importance relative des écarts de rapport à l'unité. Qu'est-ce que je rate? σ σ μ σμσσμσ
TheBamf
0

Je pense que la réponse est plus simple. Dans la VAE, les gens utilisent généralement une distribution normale multivariée, qui a une matrice de covariance au lieu de variance . Cela semble déroutant dans un morceau de code mais a la forme souhaitée.Σσ2

Ici vous pouvez trouver la dérivation d'une divergence KL pour les distributions normales multivariées: Dériver la perte de divergence KL pour les VAE

Dmitry Grebenyuk
la source