Qu'est-ce que l'activation GELU?

18

Je parcourais le papier BERT qui utilise GELU (Gaussian Error Linear Unit) qui énonce l'équation comme qui à son tour est approximativement à

gELU(X)=XP(XX)=XΦ(X).
0,5X(1+tunenh[2/π(X+0,044715X3)])

Pourriez-vous simplifier l'équation et expliquer comment elle a été approximée.

thanatoz
la source

Réponses:

19

Fonction GELU

Nous pouvons étendre la distribution cumulative de N(0,1) , c'est-à-dire Φ(X) , comme suit:

GELU(X): =XP(XX)=XΦ(X)=0,5X(1+erf(X2))

Notez que c'est une définition , pas une équation (ou une relation). Les auteurs ont fourni quelques justifications à cette proposition, par exemple une analogie stochastique , mais mathématiquement, ce n'est qu'une définition.

Voici l'intrigue de GELU:

Approximation de Tanh

Pour ce type d'approximations numériques, l'idée clé est de trouver une fonction similaire (principalement basée sur l'expérience), de la paramétrer, puis de l'adapter à un ensemble de points de la fonction d'origine.

Sachant que est très proche deerf(X)tanh(X)

et la première dérivée de coïncide avec celle de à , qui est , nous procédons à l'ajustement de (ou avec plus de termes) à un ensemble de points .erf(X2)tanh(2πX)X=02π

tanh(2π(X+uneX2+bX3+cX4+X5))
(xi,erf(xi2))

J'ai ajusté cette fonction à 20 échantillons entre (en utilisant ce site ), et voici les coefficients:(1.5,1.5)

En définissant , été estimé à . Avec plus d'échantillons d'une gamme plus large (ce site ne permettait que 20), le coefficient sera plus proche de du papier . Enfin, nous obtenonsa=c=d=0b0.04495641b0.044715

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

avec une erreur quadratique moyenne pour .108x[10,10]

Notez que si nous n'avions pas utilisé la relation entre les premières dérivées, le terme aurait été inclus dans les paramètres comme suit ce qui est moins beau (moins analytique, plus numérique)!2π

0.5x(1+tanh(0.797885x+0.035677x3))

Utiliser la parité

Comme suggéré par @BookYourLuck , nous pouvons utiliser la parité des fonctions pour restreindre l'espace des polynômes dans lesquels nous recherchons. C'est-à-dire que est une fonction impaire, c'est-à-dire , et est également une fonction impaire, la fonction polynomiale intérieur doit également être impair (ne doit avoir que des pouvoirs impairs de ) pour avoir erff(-X)=-F(X)tanhpol(X)tanhX

erf(-X)tanh(pol(-X))=tanh(-pol(X))=-tanh(pol(X))-erf(X)

Auparavant, nous avons eu la chance de nous retrouver avec des coefficients (presque) nuls pour les puissances paires et , mais en général, cela pourrait conduire à des approximations de faible qualité qui, par exemple, ont un terme comme qui est annulé par des termes supplémentaires (pairs ou impairs) au lieu d'opter simplement pour .X2X40,23X20X2

Approximation sigmoïde

Une relation similaire existe entre et (sigmoid), qui est proposé dans l'article comme une autre approximation, avec erreur quadratique moyenne pour .erf(X)2(σ(x)12)104x[10,10]

Voici un code Python pour générer des points de données, ajuster les fonctions et calculer les erreurs quadratiques moyennes:

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

Production:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05
Esmailian
la source
2
Pourquoi l'approximation est-elle nécessaire? Ne pourraient-ils pas simplement utiliser la fonction erf?
SebiSebi
8

Φ(X)=12erFc(-X2)=12(1+erF(X2))
erF
erF(X2)tanh(2π(X+uneX3))
une0,044715

X[-1,1]X

tanh(X)=X-X33+o(X3)
erF(X)=2π(X-X33)+o(X3).
tanh(2π(X+uneX3))=2π(X+(une-23π)X3)+o(X3)
erF(X2)=2π(X-X36)+o(X3).
X3
une0,04553992412
0,044715

BookYourLuck
la source