Implémentation de t-SNE Python: divergence Kullback-Leibler

11

t-SNE, comme dans [1], fonctionne en réduisant progressivement la divergence de Kullback-Leibler (KL), jusqu'à ce qu'une certaine condition soit remplie. Les créateurs de t-SNE suggèrent d'utiliser la divergence KL comme critère de performance pour les visualisations:

vous pouvez comparer les divergences Kullback-Leibler rapportées par t-SNE. Il est tout à fait correct d'exécuter t-SNE dix fois et de sélectionner la solution avec la divergence KL la plus faible [2]

J'ai essayé deux implémentations de t-SNE:

  • python : sklearn.manifold.TSNE ().
  • R : tsne, de la bibliothèque (tsne).

Ces deux implémentations, lorsque la verbosité est définie, affichent l'erreur (divergence Kullback-Leibler) pour chaque itération. Cependant, ils ne permettent pas à l'utilisateur d'obtenir ces informations, ce qui me semble un peu étrange.

Par exemple, le code:

import numpy as np
from sklearn.manifold import TSNE
X = np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
model = TSNE(n_components=2, verbose=2, n_iter=200)
t = model.fit_transform(X)

produit:

[t-SNE] Computing pairwise distances...
[t-SNE] Computed conditional probabilities for sample 4 / 4
[t-SNE] Mean sigma: 1125899906842624.000000
[t-SNE] Iteration 10: error = 6.7213750, gradient norm = 0.0012028
[t-SNE] Iteration 20: error = 6.7192064, gradient norm = 0.0012062
[t-SNE] Iteration 30: error = 6.7178683, gradient norm = 0.0012114
...
[t-SNE] Error after 200 iterations: 0.270186

Maintenant, si je comprends bien, 0,270186 devrait être la divergence KL. Cependant, je ne peux pas obtenir ces informations, ni à partir du modèle ni à partir de t (qui est un simple numpy.ndarray).

Pour résoudre ce problème, je pourrais: i) Calculer la divergence KL par moi-même, ii) Faire quelque chose de méchant en python pour capturer et analyser la sortie de la fonction TSNE () [3]. Cependant: i) serait assez stupide pour recalculer la divergence KL, alors que TSNE () l'a déjà calculé, ii) serait un peu inhabituel en termes de code.

Avez-vous une autre suggestion? Existe-t-il un moyen standard d'obtenir ces informations à l'aide de cette bibliothèque?

J'ai mentionné que j'avais essayé la bibliothèque tsne de R , mais je préfèrerais que les réponses se concentrent sur l' implémentation de python sklearn.


Références

[1] http://nbviewer.ipython.org/urls/gist.githubusercontent.com/AlexanderFabisch/1a0c648de22eff4a2a3e/raw/59d5bc5ed8f8bfd9ff1f7faa749d1b095aa97d5a/t-SNE.ipynb

[2] http://homepage.tudelft.nl/19j49/t-SNE.html

[3] /programming/16571150/how-to-capture-stdout-output-from-a-python-function-call

joker
la source

Réponses:

4

La source TSNE dans scikit-learn est en pur Python. La fit_transform()méthode Fit appelle en fait une _fit()fonction privée qui appelle ensuite une _tsne()fonction privée . Cette _tsne()fonction a une variable locale errorqui est imprimée à la fin de l'ajustement. Il semble que vous puissiez facilement changer une ou deux lignes de code source pour que cette valeur soit retournée fit_transform().

Trey
la source
Essentiellement, ce que je pourrais faire est de définir self.error = error à la fin de _tsne (), afin de le récupérer de l'instance TSNE par la suite. Oui, mais cela signifierait changer le code sklearn.manifold, et je me demandais si les développeurs pensaient à d'autres moyens d'obtenir les informations ou si non pourquoi (pas: est-ce que `` l'erreur '' est jugée inutile par eux?). De plus, si je modifiais ce code, j'aurais besoin que toutes les personnes qui exécutent mon code aient le même hack sur leurs installations sklearn. Est-ce ce que vous suggérez ou ai-je fait erreur?
joker
Oui, c'est ce que j'ai proposé comme solution possible. Étant donné que scikit-learn est open source, vous pouvez également soumettre votre solution sous forme de demande d'extraction et voir si les auteurs l'incluront dans les versions futures. Je ne peux pas expliquer pourquoi ils ont inclus ou non diverses choses.
Trey
2
Merci. Si quelqu'un d'autre est intéressé par cela, github.com/scikit-learn/scikit-learn/pull/3422 .
joker