Générer une carte thermique dans MatPlotLib à l'aide d'un ensemble de données scatter

187

J'ai un ensemble de points de données X, Y (environ 10k) qui sont faciles à tracer sous forme de nuage de points mais que je voudrais représenter sous forme de carte thermique.

J'ai regardé à travers les exemples dans MatPlotLib et ils semblent tous déjà commencer avec des valeurs de cellule de carte thermique pour générer l'image.

Existe-t-il une méthode qui convertit un groupe de x, y, tous différents, en une carte thermique (où les zones avec une fréquence plus élevée de x, y seraient "plus chaudes")?

Greye
la source

Réponses:

182

Si vous ne voulez pas d'hexagones, vous pouvez utiliser la histogram2dfonction de numpy :

import numpy as np
import numpy.random
import matplotlib.pyplot as plt

# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)

heatmap, xedges, yedges = np.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

plt.clf()
plt.imshow(heatmap.T, extent=extent, origin='lower')
plt.show()

Cela fait une carte thermique 50x50. Si vous voulez, par exemple, 512x384, vous pouvez bins=(512, 384)appeler histogram2d.

Exemple: Exemple de carte thermique Matplotlib

ptomate
la source
1
Je ne veux pas être un idiot, mais comment avez-vous réellement cette sortie dans un fichier PNG / PDF au lieu de l'afficher uniquement dans une session IPython interactive? J'essaie d'obtenir cela comme une sorte d' axesinstance normale , où je peux ajouter un titre, des étiquettes d'axe, etc. puis faire la normale savefig()comme je le ferais pour n'importe quel autre tracé matplotlib typique.
gotgenes
3
@gotgenes: ne plt.savefig('filename.png')fonctionne pas ? Si vous voulez obtenir une instance d'axes, utilisez l'interface orientée objet de Matplotlib:fig = plt.figure() ax = fig.gca() ax.imshow(...) fig.savefig(...)
ptomato
1
En effet, merci! Je suppose que je ne comprends pas tout à fait que imshow()c'est sur la même catégorie de fonctions que scatter(). Honnêtement, je ne comprends pas pourquoi imshow()convertit un tableau 2D de flotteurs en blocs de couleur appropriée, alors que je comprends ce qui scatter()est censé faire avec un tel tableau.
gotgenes le
14
Un avertissement sur l'utilisation de imshow pour tracer un histogramme 2D de valeurs x / y comme ceci: par défaut, imshow trace l'origine dans le coin supérieur gauche et transpose l'image. Ce que je ferais pour obtenir la même orientation qu'un nuage de points estplt.imshow(heatmap.T, extent=extent, origin = 'lower')
Jamie
7
Pour ceux qui veulent faire une barre de couleurs logarithmique voir cette question stackoverflow.com/questions/17201172/… et simplement fairefrom matplotlib.colors import LogNorm plt.imshow(heatmap, norm=LogNorm()) plt.colorbar()
tommy.carstensen
109

Dans le lexique Matplotlib , je pense que vous voulez un tracé hexbin .

Si vous n'êtes pas familier avec ce type de tracé, il ne s'agit que d'un histogramme bivarié dans lequel le plan xy est pavé par une grille régulière d'hexagones.

Ainsi, à partir d'un histogramme, vous pouvez simplement compter le nombre de points tombant dans chaque hexagone, discrétiser la région de traçage comme un ensemble de fenêtres , affecter chaque point à l'une de ces fenêtres; enfin, mappez les fenêtres sur un tableau de couleurs , et vous avez un diagramme hexbin.

Bien que moins couramment utilisés que par exemple, les cercles ou les carrés, ces hexagones sont un meilleur choix car la géométrie du conteneur de regroupement est intuitive:

  • les hexagones ont une symétrie du plus proche voisin (par exemple, les cases carrées ne le font pas, par exemple, la distance entre un point sur la bordure d'un carré et un point à l'intérieur de ce carré n'est pas partout égale) et

  • hexagone est le n-polygone le plus élevé qui donne une tessellation plane régulière (c'est-à-dire que vous pouvez modéliser en toute sécurité le sol de votre cuisine avec des carreaux de forme hexagonale car vous n'aurez pas d'espace vide entre les carreaux lorsque vous avez terminé - pas vrai pour tous les autres polygones n supérieur, n> = 7).

( Matplotlib utilise le terme hexbin plot; ainsi (AFAIK) toutes les bibliothèques de traçage pour R ; je ne sais toujours pas si c'est le terme généralement accepté pour les parcelles de ce type, même si je soupçonne que c'est probablement étant donné que hexbin est court pour le regroupement hexagonal , qui décrit l'étape essentielle de la préparation des données pour l'affichage.)


from matplotlib import pyplot as PLT
from matplotlib import cm as CM
from matplotlib import mlab as ML
import numpy as NP

n = 1e5
x = y = NP.linspace(-5, 5, 100)
X, Y = NP.meshgrid(x, y)
Z1 = ML.bivariate_normal(X, Y, 2, 2, 0, 0)
Z2 = ML.bivariate_normal(X, Y, 4, 1, 1, 1)
ZD = Z2 - Z1
x = X.ravel()
y = Y.ravel()
z = ZD.ravel()
gridsize=30
PLT.subplot(111)

# if 'bins=None', then color of each hexagon corresponds directly to its count
# 'C' is optional--it maps values to x-y coordinates; if 'C' is None (default) then 
# the result is a pure 2D histogram 

PLT.hexbin(x, y, C=z, gridsize=gridsize, cmap=CM.jet, bins=None)
PLT.axis([x.min(), x.max(), y.min(), y.max()])

cb = PLT.colorbar()
cb.set_label('mean value')
PLT.show()   

entrez la description de l'image ici

doug
la source
Que signifie «les hexagones ont la symétrie du plus proche voisin»? Vous dites que "la distance entre un point sur la frontière d'un carré et un point à l'intérieur de ce carré n'est pas partout égale" mais la distance à quoi?
Jaan
9
Pour un hexagone, la distance du centre à un sommet joignant deux côtés est également plus longue que du centre au milieu d'un côté, seul le rapport est plus petit (2 / sqrt (3) ≈ 1.15 pour hexagone vs. sqrt (2) ≈ 1.41 pour carré). La seule forme où la distance entre le centre et chaque point de la bordure est égale est le cercle.
Jaan
5
@Jaan Pour un hexagone, chaque voisin est à la même distance. Il n'y a aucun problème avec 8 quartiers ou 4 quartiers. Pas de voisins en diagonale, juste un type de voisin.
isarandi
@doug Comment choisissez-vous le gridsize=paramètre. Je voudrais le choisir tel, de sorte que les hexagones se touchent sans se chevaucher. J'ai remarqué que gridsize=100cela produirait des hexagones plus petits, mais comment choisir la bonne valeur?
Alexander Cska
40

Edit: Pour une meilleure approximation de la réponse d'Alejandro, voir ci-dessous.

Je sais que c'est une vieille question, mais je voulais ajouter quelque chose à la réponse d'Alejandro: si vous voulez une belle image lissée sans utiliser py-sphviewer, vous pouvez à la place utiliser np.histogram2det appliquer un filtre gaussien (de scipy.ndimage.filters) à la carte thermique:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.ndimage.filters import gaussian_filter


def myplot(x, y, s, bins=1000):
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=bins)
    heatmap = gaussian_filter(heatmap, sigma=s)

    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    return heatmap.T, extent


fig, axs = plt.subplots(2, 2)

# Generate some test data
x = np.random.randn(1000)
y = np.random.randn(1000)

sigmas = [0, 16, 32, 64]

for ax, s in zip(axs.flatten(), sigmas):
    if s == 0:
        ax.plot(x, y, 'k.', markersize=5)
        ax.set_title("Scatter plot")
    else:
        img, extent = myplot(x, y, s)
        ax.imshow(img, extent=extent, origin='lower', cmap=cm.jet)
        ax.set_title("Smoothing with  $\sigma$ = %d" % s)

plt.show()

Produit:

Images de sortie

Le nuage de points et s = 16 tracés l'un sur l'autre pour Agape Gal'lo (cliquez pour une meilleure vue):

L'un sur l'autre


Une différence que j'ai remarquée avec mon approche du filtre gaussien et celle d'Alejandro était que sa méthode montre bien mieux les structures locales que la mienne. Par conséquent, j'ai implémenté une méthode simple du voisin le plus proche au niveau du pixel. Cette méthode calcule pour chaque pixel la somme inverse des distances dun points plus proches dans les données. Cette méthode est à une haute résolution assez coûteuse en calcul et je pense qu'il existe un moyen plus rapide, alors faites-moi savoir si vous avez des améliorations.

Mise à jour: Comme je le soupçonnais, il existe une méthode beaucoup plus rapide utilisant Scipy scipy.cKDTree. Voir la réponse de Gabriel pour la mise en œuvre.

Bref, voici mon code:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm


def data_coord2view_coord(p, vlen, pmin, pmax):
    dp = pmax - pmin
    dv = (p - pmin) / dp * vlen
    return dv


def nearest_neighbours(xs, ys, reso, n_neighbours):
    im = np.zeros([reso, reso])
    extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]

    xv = data_coord2view_coord(xs, reso, extent[0], extent[1])
    yv = data_coord2view_coord(ys, reso, extent[2], extent[3])
    for x in range(reso):
        for y in range(reso):
            xp = (xv - x)
            yp = (yv - y)

            d = np.sqrt(xp**2 + yp**2)

            im[y][x] = 1 / np.sum(d[np.argpartition(d.ravel(), n_neighbours)[:n_neighbours]])

    return im, extent


n = 1000
xs = np.random.randn(n)
ys = np.random.randn(n)
resolution = 250

fig, axes = plt.subplots(2, 2)

for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 64]):
    if neighbours == 0:
        ax.plot(xs, ys, 'k.', markersize=2)
        ax.set_aspect('equal')
        ax.set_title("Scatter Plot")
    else:
        im, extent = nearest_neighbours(xs, ys, resolution, neighbours)
        ax.imshow(im, origin='lower', extent=extent, cmap=cm.jet)
        ax.set_title("Smoothing over %d neighbours" % neighbours)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
plt.show()

Résultat:

Lissage du voisin le plus proche

Jurgy
la source
1
Aime ça. Graph est aussi agréable que la réponse d'Alejandro, mais aucun nouveau package n'est requis.
Nathan Clement
Très agréable ! Mais vous générez un décalage avec cette méthode. Vous pouvez le voir en comparant un nuage de points normal avec le graphique coloré. Pourriez-vous ajouter quelque chose pour le corriger? Ou simplement pour déplacer le graphique de valeurs x et y?
Agape Gal'lo
1
Agape Gal'lo, qu'est-ce que tu veux dire par offset? Si vous les tracez les uns sur les autres, ils correspondent (voir la modification de mon message). Vous êtes peut-être découragé parce que la largeur de la dispersion ne correspond pas exactement aux trois autres.
Jurgy
Merci beaucoup d'avoir tracé le graphique rien que pour moi! J'ai compris mon erreur: j'avais modifié l '"étendue" pour définir les limites x et y. Je comprends maintenant qu'il a modifié l'origine du graphique. Ensuite, j'ai une dernière question: comment puis-je étendre les limites du graphique, même pour une zone où il n'y a pas de données existantes? Par exemple, entre -5 et +5 pour x et y.
Agape Gal'lo
1
Supposons que vous vouliez que l'axe x passe de -5 à 5 et l'axe y de -3 à 4; dans la myplotfonction, ajouter le rangeparamètre à np.histogram2d: np.histogram2d(x, y, bins=bins, range=[[-5, 5], [-3, 4]])et dans l'ensemble de la boucle x et y de l'axe lim: ax.set_xlim([-5, 5]) ax.set_ylim([-3, 4]). De plus, par défaut, imshowconserve le rapport hauteur / largeur identique au rapport de vos axes (donc dans mon exemple un rapport de 10: 7), mais si vous voulez qu'il corresponde à votre fenêtre de tracé, ajoutez le paramètre aspect='auto'à imshow.
Jurgy
31

Au lieu d'utiliser np.hist2d, qui en général produit des histogrammes assez laids, j'aimerais recycler py-sphviewer , un package python pour le rendu de simulations de particules à l'aide d'un noyau de lissage adaptatif et qui peut être facilement installé à partir de pip (voir la documentation de la page Web). Considérez le code suivant, basé sur l'exemple:

import numpy as np
import numpy.random
import matplotlib.pyplot as plt
import sphviewer as sph

def myplot(x, y, nb=32, xsize=500, ysize=500):   
    xmin = np.min(x)
    xmax = np.max(x)
    ymin = np.min(y)
    ymax = np.max(y)

    x0 = (xmin+xmax)/2.
    y0 = (ymin+ymax)/2.

    pos = np.zeros([3, len(x)])
    pos[0,:] = x
    pos[1,:] = y
    w = np.ones(len(x))

    P = sph.Particles(pos, w, nb=nb)
    S = sph.Scene(P)
    S.update_camera(r='infinity', x=x0, y=y0, z=0, 
                    xsize=xsize, ysize=ysize)
    R = sph.Render(S)
    R.set_logscale()
    img = R.get_image()
    extent = R.get_extent()
    for i, j in zip(xrange(4), [x0,x0,y0,y0]):
        extent[i] += j
    print extent
    return img, extent

fig = plt.figure(1, figsize=(10,10))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)


# Generate some test data
x = np.random.randn(1000)
y = np.random.randn(1000)

#Plotting a regular scatter plot
ax1.plot(x,y,'k.', markersize=5)
ax1.set_xlim(-3,3)
ax1.set_ylim(-3,3)

heatmap_16, extent_16 = myplot(x,y, nb=16)
heatmap_32, extent_32 = myplot(x,y, nb=32)
heatmap_64, extent_64 = myplot(x,y, nb=64)

ax2.imshow(heatmap_16, extent=extent_16, origin='lower', aspect='auto')
ax2.set_title("Smoothing over 16 neighbors")

ax3.imshow(heatmap_32, extent=extent_32, origin='lower', aspect='auto')
ax3.set_title("Smoothing over 32 neighbors")

#Make the heatmap using a smoothing over 64 neighbors
ax4.imshow(heatmap_64, extent=extent_64, origin='lower', aspect='auto')
ax4.set_title("Smoothing over 64 neighbors")

plt.show()

qui produit l'image suivante:

entrez la description de l'image ici

Comme vous le voyez, les images sont plutôt jolies et nous pouvons y identifier différentes sous-structures. Ces images sont construites en étalant un poids donné pour chaque point dans un certain domaine, défini par la longueur de lissage, qui à son tour est donnée par la distance au nb voisin le plus proche (j'ai choisi 16, 32 et 64 pour les exemples). Ainsi, les régions à densité plus élevée sont généralement réparties sur des régions plus petites par rapport aux régions à plus faible densité.

La fonction myplot est juste une fonction très simple que j'ai écrite pour donner les données x, y à py-sphviewer pour faire la magie.

Alejandro
la source
2
Un commentaire pour tous ceux qui essaient d'installer py-sphviewer sur OSX: J'ai eu beaucoup de difficultés, voir: github.com/alejandrobll/py-sphviewer/issues/3
Sam Finnigan
Dommage que cela ne fonctionne pas avec python3. Il s'installe, mais se bloque lorsque vous essayez de l'utiliser ...
Fábio Dias
1
@Fabio Dias, La dernière version (1.1.x) fonctionne maintenant avec Python 3.
Alejandro
29

Si vous utilisez 1.2.x

import numpy as np
import matplotlib.pyplot as plt

x = np.random.randn(100000)
y = np.random.randn(100000)
plt.hist2d(x,y,bins=100)
plt.show()

gaussian_2d_heat_map

Piti Ongmongkolkul
la source
17

Seaborn a maintenant le fonction jointplot qui devrait bien fonctionner ici:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)

sns.jointplot(x=x, y=y, kind='hex')
plt.show()

image de démonstration

mots pour
la source
Simple, joli et analytiquement utile.
ryanjdillon
@wordsforthewise comment rendre une donnée de 600k lisible visuellement en utilisant ceci? (comment redimensionner)
nrmb
Je ne sais pas trop ce que vous voulez dire; il vaut peut-être mieux poser une question distincte et la lier ici. Vous voulez dire redimensionner toute la figue? Créez d'abord la figure avec fig = plt.figure(figsize=(12, 12)), puis obtenez l'axe actuel avec ax=plt.gca(), puis ajoutez l'argument ax=axà la jointplotfonction.
Wordsforthewise
@wordsforthewise pourriez-vous s'il vous plaît répondre à cette question: stackoverflow.com/questions/50997662/… merci
ebrahimi
4

et la question initiale était ... comment convertir les valeurs de dispersion en valeurs de grille, non? histogram2dcompte la fréquence par cellule, cependant, si vous avez d'autres données par cellule que la fréquence, vous aurez besoin d'un travail supplémentaire.

x = data_x # between -10 and 4, log-gamma of an svc
y = data_y # between -4 and 11, log-C of an svc
z = data_z #between 0 and 0.78, f1-values from a difficult dataset

Donc, j'ai un jeu de données avec des résultats Z pour les coordonnées X et Y. Cependant, je calculais quelques points en dehors de la zone d'intérêt (grands écarts) et des tas de points dans une petite zone d'intérêt.

Oui ici ça devient plus difficile mais aussi plus amusant. Certaines bibliothèques (désolé):

from matplotlib import pyplot as plt
from matplotlib import cm
import numpy as np
from scipy.interpolate import griddata

pyplot est mon moteur graphique aujourd'hui, cm est une gamme de cartes de couleurs avec quelques choix initeresting. numpy pour les calculs et griddata pour attacher des valeurs à une grille fixe.

Le dernier est important surtout parce que la fréquence des points xy n'est pas également distribuée dans mes données. Tout d'abord, commençons par quelques limites adaptées à mes données et une taille de grille arbitraire. Les données d'origine ont des points de données également en dehors de ces limites x et y.

#determine grid boundaries
gridsize = 500
x_min = -8
x_max = 2.5
y_min = -2
y_max = 7

Nous avons donc défini une grille de 500 pixels entre les valeurs min et max de x et y.

Dans mes données, il y a beaucoup plus que les 500 valeurs disponibles dans la zone de grand intérêt; considérant que dans la zone à faible intérêt, il n'y a même pas 200 valeurs dans la grille totale; entre les limites graphiques de x_minetx_max il y en a encore moins.

Donc, pour obtenir une belle image, la tâche est d'obtenir une moyenne des valeurs d'intérêt élevé et de combler les lacunes ailleurs.

Je définis ma grille maintenant. Pour chaque paire xx-yy, je veux avoir une couleur.

xx = np.linspace(x_min, x_max, gridsize) # array of x values
yy = np.linspace(y_min, y_max, gridsize) # array of y values
grid = np.array(np.meshgrid(xx, yy.T))
grid = grid.reshape(2, grid.shape[1]*grid.shape[2]).T

Pourquoi cette forme étrange? scipy.griddata veut une forme de (n, D).

Griddata calcule une valeur par point de la grille, par une méthode prédéfinie. Je choisis "le plus proche" - les points de grille vides seront remplis avec les valeurs du voisin le plus proche. On dirait que les zones avec moins d'informations ont des cellules plus grandes (même si ce n'est pas le cas). On pourrait choisir d'interpoler "linéaire", alors les zones avec moins d'informations semblent moins nettes. Question de goût, vraiment.

points = np.array([x, y]).T # because griddata wants it that way
z_grid2 = griddata(points, z, grid, method='nearest')
# you get a 1D vector as result. Reshape to picture format!
z_grid2 = z_grid2.reshape(xx.shape[0], yy.shape[0])

Et hop, on passe à matplotlib pour afficher l'intrigue

fig = plt.figure(1, figsize=(10, 10))
ax1 = fig.add_subplot(111)
ax1.imshow(z_grid2, extent=[x_min, x_max,y_min, y_max,  ],
            origin='lower', cmap=cm.magma)
ax1.set_title("SVC: empty spots filled by nearest neighbours")
ax1.set_xlabel('log gamma')
ax1.set_ylabel('log C')
plt.show()

Autour de la partie pointue du V-Shape, vous voyez que j'ai fait beaucoup de calculs lors de ma recherche du sweet spot, alors que les parties les moins intéressantes presque partout ailleurs ont une résolution inférieure.

Carte thermique d'un SVC en haute résolution

Anderas
la source
Pouvez-vous améliorer votre réponse pour avoir un code complet et exécutable? C'est une méthode intéressante que vous avez fournie. J'essaye de mieux le comprendre pour le moment. Je ne comprends pas très bien pourquoi il existe une forme en V. Merci.
ldmtwo
Le V-Shape provient de mes données. C'est la valeur f1 pour un SVM entraîné: Cela va un peu dans la théorie des SVM. Si vous avez un C élevé, il inclut tous vos points dans le calcul, ce qui permet à une plage gamma plus large de fonctionner. Gamma est la rigidité de la courbe séparant le bon et le mauvais. Ces deux valeurs doivent être données au SVM (X et Y dans mon graphique); alors vous obtenez un résultat (Z dans mon graphique). Dans la meilleure zone, nous espérons atteindre des sommets significatifs.
Anderas
deuxième essai: le V-Shape est dans mes données. C'est la valeur f1 pour un SVM: si vous avez un C élevé, cela inclut tous vos points dans le calcul, ce qui permet à une plage gamma plus large de fonctionner, mais rend le calcul lent. Gamma est la rigidité de la courbe séparant le bon et le mauvais. Ces deux valeurs doivent être données au SVM (X et Y dans mon graphique); alors vous obtenez un résultat (Z dans mon graphique). Dans la zone optimisée, vous obtenez des valeurs élevées, ailleurs des valeurs faibles. Ce que j'ai montré ici est utilisable si vous avez des valeurs Z pour certains (X, Y) et de nombreuses lacunes ailleurs. Si vous avez des points de données (X, Y, Z), vous pouvez utiliser mon code.
Anderas
4

Voici l' approche du grand voisin le plus proche de Jurgy, mais implémentée à l'aide de scipy.cKDTree . Dans mes tests, c'est environ 100 fois plus rapide.

entrez la description de l'image ici

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.spatial import cKDTree


def data_coord2view_coord(p, resolution, pmin, pmax):
    dp = pmax - pmin
    dv = (p - pmin) / dp * resolution
    return dv


n = 1000
xs = np.random.randn(n)
ys = np.random.randn(n)

resolution = 250

extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]
xv = data_coord2view_coord(xs, resolution, extent[0], extent[1])
yv = data_coord2view_coord(ys, resolution, extent[2], extent[3])


def kNN2DDens(xv, yv, resolution, neighbours, dim=2):
    """
    """
    # Create the tree
    tree = cKDTree(np.array([xv, yv]).T)
    # Find the closest nnmax-1 neighbors (first entry is the point itself)
    grid = np.mgrid[0:resolution, 0:resolution].T.reshape(resolution**2, dim)
    dists = tree.query(grid, neighbours)
    # Inverse of the sum of distances to each grid point.
    inv_sum_dists = 1. / dists[0].sum(1)

    # Reshape
    im = inv_sum_dists.reshape(resolution, resolution)
    return im


fig, axes = plt.subplots(2, 2, figsize=(15, 15))
for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 63]):

    if neighbours == 0:
        ax.plot(xs, ys, 'k.', markersize=5)
        ax.set_aspect('equal')
        ax.set_title("Scatter Plot")
    else:

        im = kNN2DDens(xv, yv, resolution, neighbours)

        ax.imshow(im, origin='lower', extent=extent, cmap=cm.Blues)
        ax.set_title("Smoothing over %d neighbours" % neighbours)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])

plt.savefig('new.png', dpi=150, bbox_inches='tight')
Gabriel
la source
1
Je savais que mon implémentation était très inefficace mais je ne connaissais pas cKDTree. Bien joué! Je vais vous citer dans ma réponse.
Jurgy
2

Créez un tableau à 2 dimensions qui correspond aux cellules de votre image finale, appelé par exemple heatmap_cellset instanciez-le comme tous les zéros.

Choisissez deux facteurs d'échelle qui définissent la différence entre chaque élément du tableau en unités réelles, pour chaque dimension, disons x_scaleet y_scale. Choisissez-les de manière à ce que tous vos points de données tombent dans les limites du tableau de heatmap.

Pour chaque point de données brut avec x_valueet y_value:

heatmap_cells[floor(x_value/x_scale),floor(y_value/y_scale)]+=1

Meep meep
la source
1

entrez la description de l'image ici

En voici un que j'ai fait sur un ensemble de 1 million de points avec 3 catégories (de couleur rouge, vert et bleu). Voici un lien vers le référentiel si vous souhaitez essayer la fonction. Repo Github

histplot(
    X,
    Y,
    labels,
    bins=2000,
    range=((-3,3),(-3,3)),
    normalize_each_label=True,
    colors = [
        [1,0,0],
        [0,1,0],
        [0,0,1]],
    gain=50)
Joel Stansbury
la source
0

Très similaire à la réponse de @ Piti , mais en utilisant 1 appel au lieu de 2 pour générer les points:

import numpy as np
import matplotlib.pyplot as plt

pts = 1000000
mean = [0.0, 0.0]
cov = [[1.0,0.0],[0.0,1.0]]

x,y = np.random.multivariate_normal(mean, cov, pts).T
plt.hist2d(x, y, bins=50, cmap=plt.cm.jet)
plt.show()

Production:

2d_gaussian_heatmap

Alaa M.
la source
0

J'ai bien peur d'être un peu en retard à la fête, mais j'avais une question similaire il y a quelque temps. La réponse acceptée (par @ptomato) m'a aidé, mais je voudrais également publier ceci au cas où cela serait utile à quelqu'un.


''' I wanted to create a heatmap resembling a football pitch which would show the different actions performed '''

import numpy as np
import matplotlib.pyplot as plt
import random

#fixing random state for reproducibility
np.random.seed(1234324)

fig = plt.figure(12)
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

#Ratio of the pitch with respect to UEFA standards 
hmap= np.full((6, 10), 0)
#print(hmap)

xlist = np.random.uniform(low=0.0, high=100.0, size=(20))
ylist = np.random.uniform(low=0.0, high =100.0, size =(20))

#UEFA Pitch Standards are 105m x 68m
xlist = (xlist/100)*10.5
ylist = (ylist/100)*6.5

ax1.scatter(xlist,ylist)

#int of the co-ordinates to populate the array
xlist_int = xlist.astype (int)
ylist_int = ylist.astype (int)

#print(xlist_int, ylist_int)

for i, j in zip(xlist_int, ylist_int):
    #this populates the array according to the x,y co-ordinate values it encounters 
    hmap[j][i]= hmap[j][i] + 1   

#Reversing the rows is necessary 
hmap = hmap[::-1]

#print(hmap)
im = ax2.imshow(hmap)

Voici le résultat entrez la description de l'image ici

Abhishek
la source