Comprendre l'einsum de NumPy

205

J'ai du mal à comprendre exactement comment ça einsummarche. J'ai regardé la documentation et quelques exemples, mais cela ne semble pas coller.

Voici un exemple que nous avons passé en classe:

C = np.einsum("ij,jk->ki", A, B)

pour deux tableaux AetB

Je pense que cela prendrait A^T * B, mais je ne suis pas sûr (cela prend la transposition de l'un d'eux, non?). Quelqu'un peut-il me expliquer exactement ce qui se passe ici (et en général lors de l'utilisation einsum)?

Détroit de Lance
la source
7
En fait, ce sera le cas (A * B)^T, ou de manière équivalente B^T * A^T.
Tigran Saluev
23
J'ai écrit un court article de blog sur les bases d' einsum ici . (Je suis heureux de transplanter les éléments les plus pertinents dans une réponse sur Stack Overflow si cela est utile).
Alex Riley
1
@ajcr - Beau lien. Merci. La numpydocumentation est malheureusement insuffisante pour expliquer les détails.
rayryeng
Merci pour le vote de confiance! Tardivement, j'ai apporté une réponse ci-dessous .
Alex Riley
Notez qu'en Python, ce *n'est pas une multiplication matricielle mais une multiplication élémentaire. Fais attention!
ComputerScientist

Réponses:

392

(Remarque: cette réponse est basée sur un court article de blog sur einsumj'ai écrit il y a quelque temps.)

Que fait einsum-on?

Imaginez que nous ayons deux tableaux multidimensionnels, Aet B. Supposons maintenant que nous voulions ...

  • multipliez A avec Bd'une manière particulière pour créer une nouvelle gamme de produits; et puis peut-être
  • additionner ce nouveau tableau le long d'axes particuliers; et puis peut-être
  • transposez les axes du nouveau tableau dans un ordre particulier.

Il y a de fortes chances que einsumcela nous aide à le faire plus rapidement et avec une mémoire plus efficace que les combinaisons des fonctions NumPy comme multiply, sumet transposele permettront.

Comment ça einsummarche?

Voici un exemple simple (mais pas complètement trivial). Prenez les deux tableaux suivants:

A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

Nous allons multiplier Aet par Bélément, puis additionner le long des lignes du nouveau tableau. Dans NumPy "normal", nous écrivions:

>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

Donc ici, l'opération d'indexation sur Aaligne les premiers axes des deux matrices pour que la multiplication puisse être diffusée. Les lignes du tableau de produits sont ensuite additionnées pour renvoyer la réponse.

Maintenant, si nous voulions utiliser à la einsumplace, nous pourrions écrire:

>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

La chaîne de signature'i,ij->i' est la clé ici et a besoin d'un peu d'explication. Vous pouvez y penser en deux moitiés. Sur le côté gauche (à gauche du ->), nous avons étiqueté les deux tableaux d'entrée. À droite de ->, nous avons étiqueté le tableau avec lequel nous voulons finir.

Voici ce qui se passe ensuite:

  • Aa un axe; nous l'avons étiqueté i. Et Ba deux axes; nous avons étiqueté l'axe 0 comme iet l'axe 1 comme j.

  • En répétant l'étiquette idans les deux tableaux d'entrée, nous disons einsumque ces deux axes doivent être multipliés ensemble. En d'autres termes, nous multiplions le tableau Aavec chaque colonne du tableau B, tout comme le A[:, np.newaxis] * Bfait.

  • Notez que cela jn'apparaît pas comme une étiquette dans la sortie souhaitée; nous venons d'utiliser i(nous voulons finir avec un tableau 1D). En omettant l'étiquette, nous disons de einsumfaire la somme le long de cet axe. En d'autres termes, nous additionnons les lignes des produits, tout comme le .sum(axis=1)fait.

C'est essentiellement tout ce que vous devez savoir pour l'utiliser einsum. Cela aide à jouer un peu; si nous laissons les deux étiquettes dans la sortie,, 'i,ij->ij'nous récupérons un tableau 2D de produits (identique à A[:, np.newaxis] * B). Si nous disons aucune étiquette de sortie 'i,ij->, nous récupérons un seul numéro (comme pour le faire (A[:, np.newaxis] * B).sum()).

La grande chose à propos de einsumcependant, c'est que cela ne crée pas d'abord une gamme temporaire de produits; il additionne simplement les produits au fur et à mesure. Cela peut conduire à de grandes économies dans l'utilisation de la mémoire.

Un exemple légèrement plus grand

Pour expliquer le produit scalaire, voici deux nouveaux tableaux:

A = array([[1, 1, 1],
           [2, 2, 2],
           [5, 5, 5]])

B = array([[0, 1, 0],
           [1, 1, 0],
           [1, 1, 1]])

Nous calculerons le produit scalaire en utilisant np.einsum('ij,jk->ik', A, B). Voici une image montrant l'étiquetage du Aet Bet du tableau de sortie que nous obtenons de la fonction:

entrez la description de l'image ici

Vous pouvez voir que l'étiquette jest répétée - cela signifie que nous multiplions les lignes de Aavec les colonnes de B. De plus, l'étiquette jn'est pas incluse dans la sortie - nous additionnons ces produits. Les étiquettes iet ksont conservées pour la sortie, nous récupérons donc un tableau 2D.

Il peut être encore plus clair de comparer ce résultat avec le tableau où l'étiquette jn'est pas additionnée. Ci-dessous, à gauche, vous pouvez voir le tableau 3D qui résulte de l'écriture np.einsum('ij,jk->ijk', A, B)(c'est-à-dire que nous avons conservé l'étiquette j):

entrez la description de l'image ici

L'axe de somme jdonne le produit scalaire attendu, illustré à droite.

Quelques exercices

Pour en savoir plus einsum, il peut être utile d'implémenter des opérations de tableau NumPy familières à l'aide de la notation en indice. Tout ce qui implique des combinaisons d'axes de multiplication et de sommation peut être écrit en utilisant einsum.

Soit A et B deux tableaux 1D de même longueur. Par exemple, A = np.arange(10)et B = np.arange(5, 15).

  • La somme de Apeut s'écrire:

    np.einsum('i->', A)
    
  • La multiplication élément par élément,, A * Bpeut s'écrire:

    np.einsum('i,i->i', A, B)
    
  • Le produit interne ou produit scalaire, np.inner(A, B)ou np.dot(A, B), peut être écrit:

    np.einsum('i,i->', A, B) # or just use 'i,i'
    
  • Le produit extérieur,, np.outer(A, B)peut être écrit:

    np.einsum('i,j->ij', A, B)
    

Pour les tableaux 2D, Cet D, à condition que les axes soient des longueurs compatibles (les deux de même longueur ou l'un d'eux de longueur 1), voici quelques exemples:

  • La trace de C(somme de la diagonale principale),, np.trace(C)peut s'écrire:

    np.einsum('ii', C)
    
  • Multiplication élément par élément de Cet la transposition de D, C * D.Tpeut être écrit:

    np.einsum('ij,ji->ij', C, D)
    
  • Multiplier chaque élément de Cpar le tableau D(pour faire un tableau 4D),, C[:, :, None, None] * Dpeut s'écrire:

    np.einsum('ij,kl->ijkl', C, D)  
    
Alex Riley
la source
1
Très belle explication, merci. "Notez que je n'apparaît pas comme une étiquette dans la sortie souhaitée" - n'est-ce pas?
Ian Hincks
Merci @IanHincks! Cela ressemble à une faute de frappe; Je l'ai corrigé maintenant.
Alex Riley
1
Très bonne réponse. Il convient également de noter que cela ij,jkpourrait fonctionner par lui-même (sans les flèches) pour former la multiplication matricielle. Mais il semble que pour plus de clarté, il est préférable de mettre les flèches, puis les dimensions de sortie. C'est dans le billet de blog.
ComputerScientist
1
@Peaceful: c'est une de ces occasions où il est difficile de choisir le bon mot! Je pense que "colonne" convient un peu mieux ici car il Aest de longueur 3, la même que la longueur des colonnes dans B(alors que les lignes de Blongueur 4 et ne peuvent pas être multipliées par élément par A).
Alex Riley
1
Notez que l'omission de ->affecte la sémantique: "En mode implicite, les indices choisis sont importants car les axes de la sortie sont réordonnés par ordre alphabétique. Cela signifie que cela np.einsum('ij', a)n'affecte pas un tableau 2D, alors qu'il np.einsum('ji', a)prend sa transposition."
BallpointBen
45

Il numpy.einsum()est très facile de saisir l'idée de si vous la comprenez intuitivement. À titre d'exemple, commençons par une description simple impliquant la multiplication matricielle .


Pour l'utiliser numpy.einsum(), tout ce que vous avez à faire est de passer la soi-disant chaîne d'indices en argument, suivie de vos tableaux d'entrée .

Disons que vous avez deux tableaux 2D, Aet B, et que vous voulez faire la multiplication de matrices. Alors, vous faites:

np.einsum("ij, jk -> ik", A, B)

Ici, la chaîne de l' indiceij correspond au tableau Atandis que la chaîne de l' indicejk correspond au tableau B. En outre, la chose la plus importante à noter ici est que le nombre de caractères dans chaque chaîne d'indice doit correspondre aux dimensions du tableau. (c'est-à-dire deux caractères pour les tableaux 2D, trois caractères pour les tableaux 3D, et ainsi de suite.) Et si vous répétez les caractères entre les chaînes d'indice ( jdans notre cas), cela signifie que vous voulez que la einsomme se produise le long de ces dimensions. Ainsi, ils seront réduits en somme. (c'est-à-dire que cette dimension aura disparu )

La chaîne d'indice après cela ->, sera notre tableau résultant. Si vous le laissez vide, tout sera additionné et une valeur scalaire est renvoyée en résultat. Sinon, le tableau résultant aura des dimensions en fonction de la chaîne d'indice . Dans notre exemple, ce sera le cas ik. C'est intuitif car nous savons que pour la multiplication matricielle, le nombre de colonnes dans le tableau Adoit correspondre au nombre de lignes dans le tableau, Bce qui se passe ici (c'est-à-dire que nous encodons cette connaissance en répétant le caractère jdans la chaîne d'indice )


Voici quelques exemples supplémentaires illustrant l'utilisation / la puissance de np.einsum()dans la mise en œuvre de certaines opérations courantes de tenseurs ou de nd-tableaux , de manière succincte.

Contributions

# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])

# an array
In [198]: A
Out[198]: 
array([[11, 12, 13, 14],
       [21, 22, 23, 24],
       [31, 32, 33, 34],
       [41, 42, 43, 44]])

# another array
In [199]: B
Out[199]: 
array([[1, 1, 1, 1],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [4, 4, 4, 4]])

1) Multiplication matricielle (similaire à np.matmul(arr1, arr2))

In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]: 
array([[130, 130, 130, 130],
       [230, 230, 230, 230],
       [330, 330, 330, 330],
       [430, 430, 430, 430]])

2) Extraire les éléments le long de la diagonale principale (similaire à np.diag(arr))

In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])

3) Produit Hadamard (c.-à-d. Produit élément par élément de deux tableaux) (similaire à arr1 * arr2)

In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]: 
array([[ 11,  12,  13,  14],
       [ 42,  44,  46,  48],
       [ 93,  96,  99, 102],
       [164, 168, 172, 176]])

4) Mise au carré élément par élément (similaire à np.square(arr)ou arr ** 2)

In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]: 
array([[ 1,  1,  1,  1],
       [ 4,  4,  4,  4],
       [ 9,  9,  9,  9],
       [16, 16, 16, 16]])

5) Trace (c'est-à-dire somme des éléments de la diagonale principale) (similaire à np.trace(arr))

In [217]: np.einsum("ii -> ", A)
Out[217]: 110

6) Transposition de matrice (similaire à np.transpose(arr))

In [221]: np.einsum("ij -> ji", A)
Out[221]: 
array([[11, 21, 31, 41],
       [12, 22, 32, 42],
       [13, 23, 33, 43],
       [14, 24, 34, 44]])

7) Produit extérieur (de vecteurs) (similaire à np.outer(vec1, vec2))

In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]: 
array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]])

8) Produit intérieur (de vecteurs) (similaire à np.inner(vec1, vec2))

In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14

9) Somme le long de l'axe 0 (similaire à np.sum(arr, axis=0))

In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])

10) Somme le long de l'axe 1 (similaire à np.sum(arr, axis=1))

In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4,  8, 12, 16])

11) Multiplication de la matrice par lots

In [287]: BM = np.stack((A, B), axis=0)

In [288]: BM
Out[288]: 
array([[[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]],

       [[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3],
        [ 4,  4,  4,  4]]])

In [289]: BM.shape
Out[289]: (2, 4, 4)

# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)

In [293]: BMM
Out[293]: 
array([[[1350, 1400, 1450, 1500],
        [2390, 2480, 2570, 2660],
        [3430, 3560, 3690, 3820],
        [4470, 4640, 4810, 4980]],

       [[  10,   10,   10,   10],
        [  20,   20,   20,   20],
        [  30,   30,   30,   30],
        [  40,   40,   40,   40]]])

In [294]: BMM.shape
Out[294]: (2, 4, 4)

12) Somme le long de l'axe 2 (similaire à np.sum(arr, axis=2))

In [330]: np.einsum("ijk -> ij", BM)
Out[330]: 
array([[ 50,  90, 130, 170],
       [  4,   8,  12,  16]])

13) Somme tous les éléments du tableau (similaire à np.sum(arr))

In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480

14) Somme sur plusieurs axes (c.-à-d. Marginalisation)
(similaire à np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7)))

# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))

# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))

In [365]: np.allclose(esum, nsum)
Out[365]: True

15) Double Dot Products (similaire à np.sum (hadamard-product) cf. 3 )

In [772]: A
Out[772]: 
array([[1, 2, 3],
       [4, 2, 2],
       [2, 3, 4]])

In [773]: B
Out[773]: 
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124

16) Multiplication de tableaux 2D et 3D

Une telle multiplication pourrait être très utile lors de la résolution d'un système d'équations linéaires ( Ax = b ) où vous souhaitez vérifier le résultat.

# inputs
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)

# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)

# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)

# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True

Au contraire, si l'on doit utiliser np.matmul()pour cette vérification, nous devons faire quelques reshapeopérations pour obtenir le même résultat comme:

# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)

# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True

Bonus : Lisez plus de maths ici: Einstein-Summation et certainement ici: Tensor-Notation

kmario23
la source
7

Permet de créer 2 tableaux, avec des dimensions différentes mais compatibles pour mettre en évidence leur interaction

In [43]: A=np.arange(6).reshape(2,3)
Out[43]: 
array([[0, 1, 2],
       [3, 4, 5]])


In [44]: B=np.arange(12).reshape(3,4)
Out[44]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

Votre calcul prend un «point» (somme des produits) de a (2,3) avec un (3,4) pour produire un tableau (4,2). iest le 1er dim de A, le dernier de C; kle dernier de B, 1er de C. jest «consommé» par la sommation.

In [45]: C=np.einsum('ij,jk->ki',A,B)
Out[45]: 
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])

C'est la même chose que np.dot(A,B).T- c'est la sortie finale qui est transposée.

Pour en savoir plus sur ce qui se passe j, remplacez les Cindices par ijk:

In [46]: np.einsum('ij,jk->ijk',A,B)
Out[46]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])

Cela peut également être produit avec:

A[:,:,None]*B[None,:,:]

Autrement dit, ajoutez une kdimension à la fin de A, et une ià l'avant de B, ce qui donne un tableau (2,3,4).

0 + 4 + 16 = 20, 9 + 28 + 55 = 92, Etc; Sommez jet transposez pour obtenir le résultat antérieur:

np.sum(A[:,:,None] * B[None,:,:], axis=1).T

# C[k,i] = sum(j) A[i,j (,k) ] * B[(i,)  j,k]
hpaulj
la source
7

J'ai trouvé NumPy: les astuces du métier (partie II) instructif

Nous utilisons -> pour indiquer l'ordre du tableau de sortie. Pensez donc à «ij, i-> j» comme ayant le côté gauche (LHS) et le côté droit (RHS). Toute répétition d'étiquettes sur le LHS calcule l'élément de produit par élément, puis additionne. En changeant l'étiquette du côté RHS (sortie), on peut définir l'axe dans lequel on veut procéder par rapport au tableau d'entrée, c'est-à-dire la sommation le long de l'axe 0, 1 et ainsi de suite.

import numpy as np

>>> a
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
>>> b
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> d = np.einsum('ij, jk->ki', a, b)

Notez qu'il y a trois axes, i, j, k, et que j est répété (sur le côté gauche). i,jreprésentent des lignes et des colonnes pour a. j,kpour b.

Afin de calculer le produit et d'aligner l' jaxe, nous devons ajouter un axe a. ( bsera diffusé le long (?) du premier axe)

a[i, j, k]
   b[j, k]

>>> c = a[:,:,np.newaxis] * b
>>> c
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 0,  2,  4],
        [ 6,  8, 10],
        [12, 14, 16]],

       [[ 0,  3,  6],
        [ 9, 12, 15],
        [18, 21, 24]]])

jest absent du côté droit donc nous additionnons sur jquel est le deuxième axe du tableau 3x3x3

>>> c = c.sum(1)
>>> c
array([[ 9, 12, 15],
       [18, 24, 30],
       [27, 36, 45]])

Enfin, les indices sont inversés (par ordre alphabétique) sur le côté droit donc nous les transposons.

>>> c.T
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])

>>> np.einsum('ij, jk->ki', a, b)
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])
>>>
la Seconde Guerre mondiale
la source
NumPy: Les astuces du métier (Partie II) semblent nécessiter une invitation du propriétaire du site ainsi qu'un compte Wordpress
Tejas Shetty
... lien mis à jour, heureusement je l'ai trouvé avec une recherche. - Thnx.
wwii
@TejasShetty Beaucoup de meilleures réponses ici maintenant - peut-être que je devrais supprimer celle-ci.
seconde guerre mondiale du
2
Veuillez ne pas supprimer votre réponse.
Tejas Shetty
5

Lors de la lecture des équations einsum, j'ai trouvé le plus utile de pouvoir les résumer mentalement à leurs versions impératives.

Commençons par la déclaration (imposante) suivante:

C = np.einsum('bhwi,bhwj->bij', A, B)

En travaillant d'abord sur la ponctuation, nous voyons que nous avons deux blobs séparés par des virgules de 4 lettres - bhwiet bhwj, avant la flèche, et un seul blob de 3 lettres bijaprès. Par conséquent, l'équation produit un résultat tenseur de rang 3 à partir de deux entrées de tenseur de rang 4.

Maintenant, laissez chaque lettre de chaque objet blob être le nom d'une variable de plage. La position à laquelle la lettre apparaît dans le blob est l'indice de l'axe sur lequel elle se situe dans ce tenseur. La sommation impérative qui produit chaque élément de C doit donc commencer par trois boucles for imbriquées, une pour chaque indice de C.

for b in range(...):
    for i in range(...):
        for j in range(...):
            # the variables b, i and j index C in the order of their appearance in the equation
            C[b, i, j] = ...

Donc, essentiellement, vous avez une forboucle pour chaque index de sortie de C. Nous laisserons les plages indéterminées pour le moment.

Ensuite , nous regardons le côté gauche - sont là toutes les variables de portée là- bas qui ne font pas apparaître sur la droite côté? Dans notre cas - oui, het w. Ajoutez une forboucle imbriquée interne pour chaque variable de ce type:

for b in range(...):
    for i in range(...):
        for j in range(...):
            C[b, i, j] = 0
            for h in range(...):
                for w in range(...):
                    ...

À l'intérieur de la boucle la plus interne, nous avons maintenant tous les indices définis, nous pouvons donc écrire la sommation réelle et la traduction est terminée:

# three nested for-loops that index the elements of C
for b in range(...):
    for i in range(...):
        for j in range(...):

            # prepare to sum
            C[b, i, j] = 0

            # two nested for-loops for the two indexes that don't appear on the right-hand side
            for h in range(...):
                for w in range(...):
                    # Sum! Compare the statement below with the original einsum formula
                    # 'bhwi,bhwj->bij'

                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]

Si vous avez pu suivre le code jusqu'à présent, félicitations! C'est tout ce dont vous avez besoin pour lire les équations einsum. Remarquez en particulier comment la formule d'origine einsum correspond à l'instruction de somme finale dans l'extrait ci-dessus. Les boucles for et les limites de plage ne sont que des peluches et cette déclaration finale est tout ce dont vous avez vraiment besoin pour comprendre ce qui se passe.

Par souci d'exhaustivité, voyons comment déterminer les plages pour chaque variable de plage. Eh bien, la plage de chaque variable est simplement la longueur de la ou des dimensions qu'elle indexe. Évidemment, si une variable indexe plus d'une dimension dans un ou plusieurs tenseurs, alors les longueurs de chacune de ces dimensions doivent être égales. Voici le code ci-dessus avec les plages complètes:

# C's shape is determined by the shapes of the inputs
# b indexes both A and B, so its range can come from either A.shape or B.shape
# i indexes only A, so its range can only come from A.shape, the same is true for j and B
assert A.shape[0] == B.shape[0]
assert A.shape[1] == B.shape[1]
assert A.shape[2] == B.shape[2]
C = np.zeros((A.shape[0], A.shape[3], B.shape[3]))
for b in range(A.shape[0]): # b indexes both A and B, or B.shape[0], which must be the same
    for i in range(A.shape[3]):
        for j in range(B.shape[3]):
            # h and w can come from either A or B
            for h in range(A.shape[1]):
                for w in range(A.shape[2]):
                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
Stefan Dragnev
la source
0

Je pense que l'exemple le plus simple est dans les documents tensorflow

Il y a quatre étapes pour convertir votre équation en notation einsum. Prenons cette équation comme exempleC[i,k] = sum_j A[i,j] * B[j,k]

  1. Nous supprimons d'abord les noms des variables. On aik = sum_j ij * jk
  2. Nous abandonnons le sum_jterme car il est implicite. On aik = ij * jk
  3. Nous remplaçons *par ,. On aik = ij, jk
  4. La sortie est sur le RHS et est séparée par un ->signe. On aij, jk -> ik

L'interpréteur einsum exécute simplement ces 4 étapes en sens inverse. Tous les indices manquants dans le résultat sont additionnés.

Voici quelques autres exemples tirés de la documentation

# Matrix multiplication
einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

# Dot product
einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

# Outer product
einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

# Transpose
einsum('ij->ji', m)  # output[j,i] = m[i,j]

# Trace
einsum('ii', m)  # output[j,i] = trace(m) = sum_i m[i, i]

# Batch matrix multiplication
einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
Souradeep Nanda
la source