comment pondérer la perte KLD par rapport à la perte de reconstruction dans l'auto-encodeur variationnel

26

dans presque tous les exemples de code que j'ai vus sur un VAE, les fonctions de perte sont définies comme suit (c'est du code tensorflow, mais j'ai vu des choses similaires pour theo, torch etc. C'est aussi pour un convnet, mais ce n'est pas trop pertinent non plus , affecte uniquement les axes sur lesquels les sommes sont reprises):

# latent space loss. KL divergence between latent space distribution and unit gaussian, for each batch.
# first half of eq 10. in https://arxiv.org/abs/1312.6114
kl_loss = -0.5 * tf.reduce_sum(1 + log_sigma_sq - tf.square(mu) - tf.exp(log_sigma_sq), axis=1)

# reconstruction error, using pixel-wise L2 loss, for each batch
rec_loss = tf.reduce_sum(tf.squared_difference(y, x), axis=[1,2,3])

# or binary cross entropy (assuming 0...1 values)
y = tf.clip_by_value(y, 1e-8, 1-1e-8) # prevent nan on log(0)
rec_loss = -tf.reduce_sum(x * tf.log(y) + (1-x) * tf.log(1-y), axis=[1,2,3])

# sum the two and average over batches
loss = tf.reduce_mean(kl_loss + rec_loss)

Cependant, la plage numérique de kl_loss et rec_loss est très dépendante des gradations latentes de l'espace et de la taille de la caractéristique d'entrée (par exemple la résolution en pixels) respectivement. Serait-il judicieux de remplacer la réduction de la somme par la réduction de la moyenne pour obtenir par z-dim KLD et par pixel (ou fonctionnalité) LSE ou BCE? Plus important encore, comment pondérer la perte latente avec la perte de reconstruction lors de la sommation de la perte finale? S'agit-il seulement d'essais et d'erreurs? ou existe-t-il une théorie (ou du moins une règle empirique) pour cela? Je n'ai trouvé aucune information à ce sujet (y compris le document original).


Le problème que je rencontre est que si l'équilibre entre mes dimensions d'entité d'entrée (x) et les dimensions d'espace latent (z) n'est pas «optimal», soit mes reconstructions sont très bonnes mais l'espace latent appris n'est pas structuré (si x dimensions est très élevé et l'erreur de reconstruction domine sur KLD), ou vice versa (les reconstructions ne sont pas bonnes mais l'espace latent appris est bien structuré si KLD domine).

Je me retrouve à devoir normaliser la perte de reconstruction (division par la taille de l'entité en entrée) et KLD (division par les dimensions z), puis pondérer manuellement le terme KLD avec un facteur de pondération arbitraire (la normalisation est pour que je puisse utiliser la même chose ou poids similaire indépendant des dimensions de x ou z ). Empiriquement, j'ai trouvé environ 0,1 pour fournir un bon équilibre entre la reconstruction et l'espace latent structuré qui me semble être un «point idéal». Je recherche un travail préalable dans ce domaine.


Sur demande, notation mathématique ci-dessus (en se concentrant sur la perte de L2 pour l'erreur de reconstruction)

Llatent(i)=12j=1J(1+log(σj(i))2(μj(i))2(σj(i))2)

Lrecon(i)=k=1K(yk(i)xk(i))2

L(m)=1Mi=1M(Llatent(i)+Lrecon(i))

Jzμσ2KM(i)iL(m)m

note
la source

Réponses:

17

Pour toute personne tombant sur ce post à la recherche d'une réponse, ce fil Twitter a ajouté de nombreuses informations très utiles.

À savoir:

beta-VAE: Apprentissage des concepts visuels de base avec un cadre variationnel contraint

βnorm

et lecture connexe (où des questions similaires sont discutées)

note
la source
7

Je voudrais ajouter un autre article sur cette question (je ne peux pas faire de commentaire en raison de ma mauvaise réputation pour le moment).

Dans la sous-section 3.1 du document, les auteurs ont précisé qu'ils n'avaient pas réussi à former une mise en œuvre directe de la VAE qui pondérait également la probabilité et la divergence KL. Dans leur cas, la perte de KL a été ramenée à zéro de manière indésirable, même si elle devait avoir une petite valeur. Pour surmonter cela, ils ont proposé d'utiliser le "recuit de coût KL", qui a lentement augmenté le facteur de pondération du terme de divergence KL (courbe bleue) de 0 à 1.

Figure 2. Le poids du terme de divergence KL de la borne inférieure variationnelle selon un calendrier de recuit sigmoïde typique tracé à côté de la valeur (non pondérée) du terme de divergence KL pour notre VAE sur la Penn TreeBank.

Cette solution de contournement est également appliquée dans Ladder VAE.

Papier:

Bowman, SR, Vilnis, L., Vinyals, O., Dai, AM, Jozefowicz, R. et Bengio, S., 2015. Génération de phrases à partir d'un espace continu . arXiv preprint arXiv: 1511.06349.

Cuong
la source