étiquettes d'axes pyplot pour les sous-tracés

195

J'ai l'intrigue suivante:

import matplotlib.pyplot as plt

fig2 = plt.figure()
ax3 = fig2.add_subplot(2,1,1)
ax4 = fig2.add_subplot(2,1,2)
ax4.loglog(x1, y1)
ax3.loglog(x2, y2)
ax3.set_ylabel('hello')

Je veux pouvoir créer des étiquettes et des titres d'axes non seulement pour chacun des deux sous-graphiques, mais également des étiquettes communes qui couvrent les deux sous-graphiques. Par exemple, comme les deux graphiques ont des axes identiques, je n'ai besoin que d'un seul jeu d'étiquettes d'axes x et y. Je veux cependant des titres différents pour chaque sous-intrigue.

J'ai essayé plusieurs choses mais aucune n'a fonctionné correctement

farqwag25
la source

Réponses:

271

Vous pouvez créer un grand sous-tracé qui couvre les deux sous-tracés, puis définir les étiquettes communes.

import random
import matplotlib.pyplot as plt

x = range(1, 101)
y1 = [random.randint(1, 100) for _ in range(len(x))]
y2 = [random.randint(1, 100) for _ in range(len(x))]

fig = plt.figure()
ax = fig.add_subplot(111)    # The big subplot
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)

# Turn off axis lines and ticks of the big subplot
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_color('none')
ax.spines['left'].set_color('none')
ax.spines['right'].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

ax1.loglog(x, y1)
ax2.loglog(x, y2)

# Set common labels
ax.set_xlabel('common xlabel')
ax.set_ylabel('common ylabel')

ax1.set_title('ax1 title')
ax2.set_title('ax2 title')

plt.savefig('common_labels.png', dpi=300)

common_labels.png

Une autre méthode consiste à utiliser fig.text () pour définir directement les emplacements des étiquettes communes.

import random
import matplotlib.pyplot as plt

x = range(1, 101)
y1 = [random.randint(1, 100) for _ in range(len(x))]
y2 = [random.randint(1, 100) for _ in range(len(x))]

fig = plt.figure()
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)

ax1.loglog(x, y1)
ax2.loglog(x, y2)

# Set common labels
fig.text(0.5, 0.04, 'common xlabel', ha='center', va='center')
fig.text(0.06, 0.5, 'common ylabel', ha='center', va='center', rotation='vertical')

ax1.set_title('ax1 title')
ax2.set_title('ax2 title')

plt.savefig('common_labels_text.png', dpi=300)

common_labels_text.png

Wen-Wei Liao
la source
1
La fonction suptitle utilise la version fig.text (). Donc, cela pourrait être la façon «officielle» de le faire?
PhML
4
Il convient de souligner qu'il axfaut créer avant ax1et ax2, sinon la grande parcelle couvrira les petites parcelles.
1 ''
ax.grid (False) ou plt.grid (False) est également nécessaire si les paramètres de traçage globaux incluent une grille (visible).
Næreen
4
Il semble que la première approche ne fonctionne plus avec les versions récentes de matplotplib (j'utilise 2.0.2): les étiquettes ajoutées à la hache englobante ne sont pas visibles.
M. Toya
Comment ajouter des y_labels à chaque sous-tracé individuel?
Fardin
123

Un moyen simple en utilisant subplots:

import matplotlib.pyplot as plt

fig, axes = plt.subplots(3, 4, sharex=True, sharey=True)
# add a big axes, hide frame
fig.add_subplot(111, frameon=False)
# hide tick and tick label of the big axes
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.grid(False)
plt.xlabel("common X")
plt.ylabel("common Y")
Julian Chen
la source
1
ax.grid (False) ou plt.grid (False) est également nécessaire si les paramètres de traçage globaux incluent une grille (visible).
Næreen
1
Je fais cela pour une sous-parcelle (5, 1) et mon ylabel est bien sur le bord gauche de la fenêtre au lieu de près des sous-graphiques.
Evidlo
1
Vous avez un vote positif. mais veuillez toujours expliquer ce que fait le code, attacher une image ou montrer un exemple, car il a certainement fallu un peu de temps pour l'obtenir.
Kareem Jeiroudi
5
Changer 'off'pour Falseavec les nouvelles versions de Matplotlib (j'ai 2.2.2)
Ted
2
Et puis comment ajoutez-vous les parcelles? for ax in axes: ax.plot(x, y)ne semble pas faire de bien.
numéro d'utilisateur
17

plt.setp() fera le travail:

# plot something
fig, axs = plt.subplots(3,3, figsize=(15, 8), sharex=True, sharey=True)
for i, ax in enumerate(axs.flat):
    ax.scatter(*np.random.normal(size=(2,200)))
    ax.set_title(f'Title {i}')

# set labels
plt.setp(axs[-1, :], xlabel='x axis label')
plt.setp(axs[:, 0], ylabel='y axis label')

entrez la description de l'image ici

MohammadReza
la source
Existe-t-il un moyen de définir également la taille / l'épaisseur de la police avec cette méthode?
pfabri
16

La réponse de Wen-wei Liao est bonne si vous n'essayez pas d'exporter des graphiques vectoriels ou que vous avez configuré vos backends matplotlib pour ignorer les axes incolores; sinon les axes masqués apparaîtront dans le graphique exporté.

Ma réponse suplabelici est similaire à celle fig.suptitlequi utilise la fig.textfonction. Par conséquent, aucun artiste de haches n'est créé et rendu incolore. Cependant, si vous essayez de l'appeler plusieurs fois, vous obtiendrez du texte ajouté les uns sur les autres (comme le fig.suptitlefait aussi). La réponse de Wen-wei Liao ne le fait pas, car fig.add_subplot(111)renverra le même objet Axes s'il est déjà créé.

Ma fonction peut également être appelée après la création des tracés.

def suplabel(axis,label,label_prop=None,
             labelpad=5,
             ha='center',va='center'):
    ''' Add super ylabel or xlabel to the figure
    Similar to matplotlib.suptitle
    axis       - string: "x" or "y"
    label      - string
    label_prop - keyword dictionary for Text
    labelpad   - padding from the axis (default: 5)
    ha         - horizontal alignment (default: "center")
    va         - vertical alignment (default: "center")
    '''
    fig = pylab.gcf()
    xmin = []
    ymin = []
    for ax in fig.axes:
        xmin.append(ax.get_position().xmin)
        ymin.append(ax.get_position().ymin)
    xmin,ymin = min(xmin),min(ymin)
    dpi = fig.dpi
    if axis.lower() == "y":
        rotation=90.
        x = xmin-float(labelpad)/dpi
        y = 0.5
    elif axis.lower() == 'x':
        rotation = 0.
        x = 0.5
        y = ymin - float(labelpad)/dpi
    else:
        raise Exception("Unexpected axis: x or y")
    if label_prop is None: 
        label_prop = dict()
    pylab.text(x,y,label,rotation=rotation,
               transform=fig.transFigure,
               ha=ha,va=va,
               **label_prop)
KYC
la source
C'est la meilleure réponse imo. C'est facile à mettre en œuvre et les étiquettes ne se chevauchent pas grâce à l'option Labelpad.
Arthur Dent
9

Voici une solution où vous définissez l'étiquette y de l'un des tracés et ajustez la position de celui-ci afin qu'il soit centré verticalement. De cette façon, vous évitez les problèmes mentionnés par KYC.

import numpy as np
import matplotlib.pyplot as plt

def set_shared_ylabel(a, ylabel, labelpad = 0.01):
    """Set a y label shared by multiple axes
    Parameters
    ----------
    a: list of axes
    ylabel: string
    labelpad: float
        Sets the padding between ticklabels and axis label"""

    f = a[0].get_figure()
    f.canvas.draw() #sets f.canvas.renderer needed below

    # get the center position for all plots
    top = a[0].get_position().y1
    bottom = a[-1].get_position().y0

    # get the coordinates of the left side of the tick labels 
    x0 = 1
    for at in a:
        at.set_ylabel('') # just to make sure we don't and up with multiple labels
        bboxes, _ = at.yaxis.get_ticklabel_extents(f.canvas.renderer)
        bboxes = bboxes.inverse_transformed(f.transFigure)
        xt = bboxes.x0
        if xt < x0:
            x0 = xt
    tick_label_left = x0

    # set position of label
    a[-1].set_ylabel(ylabel)
    a[-1].yaxis.set_label_coords(tick_label_left - labelpad,(bottom + top)/2, transform=f.transFigure)

length = 100
x = np.linspace(0,100, length)
y1 = np.random.random(length) * 1000
y2 = np.random.random(length)

f,a = plt.subplots(2, sharex=True, gridspec_kw={'hspace':0})
a[0].plot(x, y1)
a[1].plot(x, y2)
set_shared_ylabel(a, 'shared y label (a. u.)')

entrez la description de l'image ici

Hagne
la source
4
# list loss and acc are your data
fig = plt.figure()
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

ax1.plot(iteration1, loss)
ax2.plot(iteration2, acc)

ax1.set_title('Training Loss')
ax2.set_title('Training Accuracy')

ax1.set_xlabel('Iteration')
ax1.set_ylabel('Loss')

ax2.set_xlabel('Iteration')
ax2.set_ylabel('Accuracy')
J.Zhao
la source
2

Les méthodes des autres réponses ne fonctionneront pas correctement lorsque les yticks sont grands. L'étiquette y sera soit superposée avec des graduations, soit coupée à gauche ou complètement invisible / à l'extérieur de la figure.

J'ai modifié la réponse de Hagne pour qu'elle fonctionne avec plus d'une colonne de sous-graphiques, à la fois pour xlabel et ylabel, et cela déplace le graphique pour garder le ylabel visible sur la figure.

def set_shared_ylabel(a, xlabel, ylabel, labelpad = 0.01, figleftpad=0.05):
    """Set a y label shared by multiple axes
    Parameters
    ----------
    a: list of axes
    ylabel: string
    labelpad: float
        Sets the padding between ticklabels and axis label"""

    f = a[0,0].get_figure()
    f.canvas.draw() #sets f.canvas.renderer needed below

    # get the center position for all plots
    top = a[0,0].get_position().y1
    bottom = a[-1,-1].get_position().y0

    # get the coordinates of the left side of the tick labels
    x0 = 1
    x1 = 1
    for at_row in a:
        at = at_row[0]
        at.set_ylabel('') # just to make sure we don't and up with multiple labels
        bboxes, _ = at.yaxis.get_ticklabel_extents(f.canvas.renderer)
        bboxes = bboxes.inverse_transformed(f.transFigure)
        xt = bboxes.x0
        if xt < x0:
            x0 = xt
            x1 = bboxes.x1
    tick_label_left = x0

    # shrink plot on left to prevent ylabel clipping
    # (x1 - tick_label_left) is the x coordinate of right end of tick label,
    # basically how much padding is needed to fit tick labels in the figure
    # figleftpad is additional padding to fit the ylabel
    plt.subplots_adjust(left=(x1 - tick_label_left) + figleftpad)

    # set position of label, 
    # note that (figleftpad-labelpad) refers to the middle of the ylabel
    a[-1,-1].set_ylabel(ylabel)
    a[-1,-1].yaxis.set_label_coords(figleftpad-labelpad,(bottom + top)/2, transform=f.transFigure)

    # set xlabel
    y0 = 1
    for at in axes[-1]:
        at.set_xlabel('')  # just to make sure we don't and up with multiple labels
        bboxes, _ = at.xaxis.get_ticklabel_extents(fig.canvas.renderer)
        bboxes = bboxes.inverse_transformed(fig.transFigure)
        yt = bboxes.y0
        if yt < y0:
            y0 = yt
    tick_label_bottom = y0

    axes[-1, -1].set_xlabel(xlabel)
    axes[-1, -1].xaxis.set_label_coords((left + right) / 2, tick_label_bottom - labelpad, transform=fig.transFigure)

Cela fonctionne pour l'exemple suivant, alors que la réponse de Hagne ne dessinera pas le ylabel (car il est en dehors du canevas) et le ylabel de KYC chevauche les étiquettes de graduation:

import matplotlib.pyplot as plt
import itertools

fig, axes = plt.subplots(3, 4, sharey='row', sharex=True, squeeze=False)
fig.subplots_adjust(hspace=.5)
for i, a in enumerate(itertools.chain(*axes)):
    a.plot([0,4**i], [0,4**i])
    a.set_title(i)
set_shared_ylabel(axes, 'common X', 'common Y')
plt.show()

Alternativement, si vous êtes d'accord avec l'axe incolore, j'ai modifié la solution de Julian Chen afin que ylabel ne chevauche pas les étiquettes de graduation.

Fondamentalement, nous devons juste définir les ylims de l'incolore afin qu'il corresponde aux plus grands ylims des sous-tracés afin que les étiquettes de graduation incolores définissent l'emplacement correct pour le ylabel.

Encore une fois, nous devons réduire l'intrigue pour éviter l'écrêtage. Ici, j'ai codé en dur le montant à réduire, mais vous pouvez jouer pour trouver un nombre qui vous convient ou le calculer comme dans la méthode ci-dessus.

import matplotlib.pyplot as plt
import itertools

fig, axes = plt.subplots(3, 4, sharey='row', sharex=True, squeeze=False)
fig.subplots_adjust(hspace=.5)
miny = maxy = 0
for i, a in enumerate(itertools.chain(*axes)):
    a.plot([0,4**i], [0,4**i])
    a.set_title(i)
    miny = min(miny, a.get_ylim()[0])
    maxy = max(maxy, a.get_ylim()[1])

# add a big axes, hide frame
# set ylim to match the largest range of any subplot
ax_invis = fig.add_subplot(111, frameon=False)
ax_invis.set_ylim([miny, maxy])

# hide tick and tick label of the big axis
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.xlabel("common X")
plt.ylabel("common Y")

# shrink plot to prevent clipping
plt.subplots_adjust(left=0.15)
plt.show()
Tim
la source