Comment le générateur d'un GAN est-il formé?

9

Le document sur les GAN indique que le discriminateur utilise le gradient suivant pour s'entraîner:

θd1mi=1m[logD(x(i))+log(1D(G(z(i))))]

Les valeurs sont échantillonnées, passées à travers le générateur pour générer des échantillons de données, puis le discriminateur est rétropropagé en utilisant les échantillons de données générés. Une fois que le générateur a généré les données, il ne joue plus aucun rôle dans la formation du discriminateur. En d'autres termes, le générateur peut être complètement supprimé de la métrique en le faisant générer des échantillons de données et en ne travaillant ensuite qu'avec les échantillons.z

Je suis un peu plus confus quant à la façon dont le générateur est entraîné. Il utilise le dégradé suivant:

θg1mi=1m[log(1D(G(z(i))))]

Dans ce cas, le discriminateur fait partie de la métrique. Il ne peut pas être supprimé comme le cas précédent. Des choses comme les moindres carrés ou la vraisemblance logarithmique dans les modèles discriminants réguliers peuvent facilement être différenciées car elles ont une belle définition proche. Cependant, je suis un peu confus quant à la façon dont vous rétropropagiez lorsque la métrique dépend d'un autre réseau de neurones. Fixez-vous essentiellement les sorties du générateur aux entrées du discriminateur, puis traitez-vous le tout comme un réseau géant où les poids dans la partie discriminante sont constants?

Phidias
la source

Réponses:

10

Cela aide à penser à ce processus en pseudocode. Soit generator(z)une fonction qui prend un vecteur de bruit échantillonné uniformément zet renvoie un vecteur de même taille que le vecteur d'entrée X; appelons cette longueur d. Soit discriminator(x)une fonction qui prend un dvecteur dimensionnel et renvoie une probabilité scalaire qui xappartient à la vraie distribution de données. Pour s'entraîner:

G_sample = generator(Z)
D_real = discriminator(X)
D_fake = discriminator(G_sample)

D_loss = maximize mean of (log(D_real) + log(1 - D_fake))
G_loss = maximize mean of log(D_fake)

# Only update D(X)'s parameters
D_solver = Optimizer().minimize(D_loss, theta_D)
# Only update G(X)'s parameters
G_solver = Optimizer().minimize(G_loss, theta_G)

# theta_D and theta_G are the weights and biases of D and G respectively
Repeat the above for a number of epochs

Donc, oui, vous avez raison de penser que nous considérons essentiellement le générateur et le discriminateur comme un réseau géant pour alterner des minibatches lorsque nous utilisons de fausses données. La fonction de perte du générateur prend en charge les gradients pour cette moitié. Si vous pensez à cette formation réseau de manière isolée, alors elle est formée comme vous le feriez habituellement pour un MLP, son entrée étant la sortie de la dernière couche du réseau générateur.

Vous pouvez suivre une explication détaillée avec du code dans Tensorflow ici (parmi de nombreux endroits): http://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/

Il devrait être facile à suivre une fois que vous aurez consulté le code.

tejaskhot
la source
1
Pourriez-vous élaborer sur D_losset G_loss? Maximiser sur quel espace? IIUC, D_realet D_fakesont chacun un lot, donc nous maximisons le lot ??
P i
@Pi Oui, nous maximisons sur un lot.
tejaskhot
1

Attachez-vous essentiellement les sorties du générateur aux entrées du discriminateur?> Et traitez-vous ensuite le tout comme un réseau géant où les poids dans la partie discriminante sont constants?

En bref: Oui. (J'ai creusé certaines des sources du GAN pour revérifier cela)

Il y a aussi beaucoup plus dans la formation GAN comme: devrions-nous mettre à jour D et G à chaque fois ou D sur les itérations impaires et G sur pair, et bien plus encore. Il y a aussi un très bon article sur ce sujet:

"Techniques améliorées pour la formation des GAN"

Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, Xi Chen

https://arxiv.org/abs/1606.03498

Liberus
la source
Pourriez-vous fournir des liens vers les sources que vous avez consultées? Il serait utile pour moi de les lire.
Vivek Subramanian
0

Récemment, j'ai téléchargé une collection de divers modèles GAN sur github repo. Il est basé sur torch7 et très facile à utiliser. Le code est assez simple à comprendre avec des résultats expérimentaux. J'espère que cela vous aidera

https://github.com/nashory/gans-collection.torch

nashory
la source