Quelles sont la moyenne et la variance d'une normale multivariée sans censure?

9

ZN(μ,Σ)RdZ+=max(0,Z)

Cela se produit par exemple parce que, si nous utilisons la fonction d'activation ReLU à l'intérieur d'un réseau profond, et supposons via le CLT que les entrées d'une couche donnée sont approximativement normales, alors c'est la distribution des sorties.

(Je suis sûr que beaucoup de gens l'ont déjà calculé, mais je n'ai pas pu trouver le résultat répertorié n'importe où de manière raisonnablement lisible.)

Dougal
la source
Cela simplifierait votre réponse - peut-être grandement - pour observer que vous pouvez l'obtenir en combinant les résultats de deux questions distinctes: (1) quels sont les moments d'une distribution normale tronquée et (2) quels sont les moments d'un mélange ? Ce dernier est simple et tout ce que vous avez à faire est de citer des résultats pour le premier.
whuber
@whuber Hmm. Bien que je ne l'ai pas dit explicitement, c'est essentiellement ce que je fais dans ma réponse, sauf que je n'ai pas trouvé de résultats pour une distribution bivariée tronquée avec une moyenne et une variance générales et j'ai donc dû faire une mise à l'échelle et un décalage de toute façon. Existe-t-il un moyen de dériver, par exemple, la covariance sans faire la quantité d'algèbre que je devais faire? Je ne prétends certainement pas que quelque chose dans cette réponse est nouveau, juste que l'algèbre était fastidieuse et sujette aux erreurs, et peut-être que quelqu'un d'autre trouvera la solution utile.
Dougal
À droite: je suis sûr que votre algèbre équivaut à ce que j'ai décrit, il semble donc que nous partageons une appréciation pour la simplification de l'algèbre. Un moyen simple de réduire l'algèbre consiste à normaliser les éléments diagonaux de à l'unité, car tout ce que cela fait est d'établir une unité de mesure pour chaque variable. À ce stade, vous pouvez directement connecter les résultats de Rosenbaum dans les expressions (simples, évidentes) pour les moments de mélanges. Que cela vaille la peine d'une simplification algébrique peut être une question de goût: sans simplification, cela conduit à un programme informatique simple et modulaire. Σ
whuber
1
Je suppose que l'on pourrait écrire un programme qui calcule des moments directement avec les résultats de Rosenbaum et en les mélangeant de manière appropriée, puis les décale et les redimensionne dans l'espace d'origine. Cela aurait probablement été plus rapide que la façon dont je l'ai fait.
Dougal

Réponses:

7

Nous pouvons d'abord réduire cela pour ne dépendre que de certains moments de distributions normales tronquées univariées / bivariées: notons bien sûr que

E[Z+]=[E[(Zi)+]]iCov(Z+)=[Cov((Zi)+,(Zj)+)]ij,
et parce que nous faisons des transformations coordonnées de certaines dimensions d'une distribution normale, nous ne besoin de s'inquiéter de la moyenne et de la variance d'une normale 1d censurée et de la covariance de deux normales 1d censurées.

Nous utiliserons quelques résultats de

S Rosenbaum (1961). Moments d'une distribution normale bivariée tronquée . JRSS B, vol 23 pp 405-408. ( jstor )

Rosenbaum considère et considère la troncature à l'événement .V={˜XaX,˜YaY}

[X~Y~]N([00],[1ρρ1]),
V={X~aX,Y~aY}

Plus précisément, nous utiliserons les trois résultats suivants, ses (1), (3) et (5). Définissez d'abord les éléments suivants:

qx=ϕ(ax)qy=ϕ(ay)Qx=Φ(ax)Qy=Φ(ay)Rxy=Φ(ρaxay1ρ2)Ryx=Φ(ρayax1ρ2)rxy=1ρ22πϕ(h22ρhk+k21ρ2)

Maintenant, Rosenbaum montre que:

(1)Pr(V)E[X~V]=qxRxy+ρqyRyx(3)Pr(V)E[X~2V]=Pr(V)+axqxRxy+ρ2ayqyRyx+ρrxy(5)Pr(V)E[X~Y~V]=ρPr(V)+ρaxqxRxy+ρayqyRyx+rxy.

Il sera utile de considérer également le cas particulier de (1) et (3) avec , c'est-à-dire une troncature 1d: ay=

(*)Pr(V)E[X~V]=qx(**)Pr(V)E[X~2V]=Pr(V)=Qx.

Nous voulons maintenant considérer

[XY]=[μxμy]+[σx00σy][X~Y~]N([μXμY],[σx2ρσxσyρσxσyσy2])=N(μ,Σ).

Nous utiliserons qui sont les valeurs de et lorsque , .

ax=μxσxay=μyσy,
X~Y~X=0Y=0

Maintenant, en utilisant (*), nous obtenons et en utilisant à la fois (*) et (**) donne afin que

E[X+]=Pr(X+>0)E[XX>0]+Pr(X+=0)0=Pr(X>0)(μx+σxE[X~X~ax])=Qxμx+qxσx,
E[X+2]=Pr(X+>0)E[X2X>0]+Pr(X+=0)0=Pr(X~ax)E[(μx+σxX~)2X~ax]=Pr(X~ax)E[μx2+μxσxX~+σx2X~2X~ax]=Qxμx2+qxμxσx+Qxσx2
Var[X+]=E[X+2]E[X+]2=Qxμx2+qxμxσx+Qxσx2Qx2μx2qx2σx22qxQxμxσx=Qx(1Qx)μx2+(12Qx)qxμxσx+(Qxqx2)σx2.

Pour trouver , nous aurons besoin Cov(X+,Y+)

E[X+Y+]=Pr(V)E[XYV]+Pr(¬V)0=Pr(V)E[(μx+σxX~)(μy+σyY~)V]=μxμyPr(V)+μyσxPr(V)E[X~V]+μxσyPr(V)E[Y~V]+σxσyPr(V)E[X~Y~V]=μxμyPr(V)+μyσx(qxRxy+ρqyRyx)+μxσy(ρqxRxy+qyRyx)+σxσy(ρPr(V)ρμxqxRxy/σxρμyqyRyx/σy+rxy)=(μxμy+σxσyρ)Pr(V)+(μyσx+μxσyρρμxσy)qxRxy+(μyσxρ+μxσyρμyσx)qyRyx+σxσyrxy=(μxμy+Σxy)Pr(V)+μyσxqxRxy+μxσyqyRyx+σxσyrxy,
puis en soustrayant nous obtenons E[X+]E[Y+]
Cov(X+,Y+)=(μxμy+Σxy)Pr(V)+μyσxqxRxy+μxσyqyRyx+σxσyrxy(Qxμx+qxσx)(Qyμy+qyσy).

Voici du code Python pour calculer les moments:

import numpy as np
from scipy import stats

def relu_mvn_mean_cov(mu, Sigma):
    mu = np.asarray(mu, dtype=float)
    Sigma = np.asarray(Sigma, dtype=float)
    d, = mu.shape
    assert Sigma.shape == (d, d)

    x = (slice(None), np.newaxis)
    y = (np.newaxis, slice(None))

    sigma2s = np.diagonal(Sigma)
    sigmas = np.sqrt(sigma2s)
    rhos = Sigma / sigmas[x] / sigmas[y]

    prob = np.empty((d, d))  # prob[i, j] = Pr(X_i > 0, X_j > 0)
    zero = np.zeros(d)
    for i in range(d):
        prob[i, i] = np.nan
        for j in range(i + 1, d):
            # Pr(X > 0) = Pr(-X < 0); X ~ N(mu, S) => -X ~ N(-mu, S)
            s = [i, j]
            prob[i, j] = prob[j, i] = stats.multivariate_normal.cdf(
                zero[s], mean=-mu[s], cov=Sigma[np.ix_(s, s)])

    mu_sigs = mu / sigmas

    Q = stats.norm.cdf(mu_sigs)
    q = stats.norm.pdf(mu_sigs)
    mean = Q * mu + q * sigmas

    # rho_cs is sqrt(1 - rhos**2); but don't calculate diagonal, because
    # it'll just be zero and we're dividing by it (but not using result)
    # use inf instead of nan; stats.norm.cdf doesn't like nan inputs
    rho_cs = 1 - rhos**2
    np.fill_diagonal(rho_cs, np.inf)
    np.sqrt(rho_cs, out=rho_cs)

    R = stats.norm.cdf((mu_sigs[y] - rhos * mu_sigs[x]) / rho_cs)

    mu_sigs_sq = mu_sigs ** 2
    r_num = mu_sigs_sq[x] + mu_sigs_sq[y] - 2 * rhos * mu_sigs[x] * mu_sigs[y]
    np.fill_diagonal(r_num, 1)  # don't want slightly negative numerator here
    r = rho_cs / np.sqrt(2 * np.pi) * stats.norm.pdf(np.sqrt(r_num) / rho_cs)

    bit = mu[y] * sigmas[x] * q[x] * R
    cov = (
        (mu[x] * mu[y] + Sigma) * prob
        + bit + bit.T
        + sigmas[x] * sigmas[y] * r
        - mean[x] * mean[y])

    cov[range(d), range(d)] = (
        Q * (1 - Q) * mu**2 + (1 - 2 * Q) * q * mu * sigmas
        + (Q - q**2) * sigma2s)

    return mean, cov

et un test de Monte Carlo que cela fonctionne:

np.random.seed(12)
d = 4
mu = np.random.randn(d)
L = np.random.randn(d, d)
Sigma = L.T.dot(L)
dist = stats.multivariate_normal(mu, Sigma)

mn, cov = relu_mvn_mean_cov(mu, Sigma)

samps = dist.rvs(10**7)
mn_est = samps.mean(axis=0)
cov_est = np.cov(samps, rowvar=False)
print(np.max(np.abs(mn - mn_est)), np.max(np.abs(cov - cov_est)))

ce qui donne 0.000572145310512 0.00298692620286, indiquant que l'espérance et la covariance revendiquées correspondent aux estimations de Monte Carlo (sur la base de échantillons).10,000,000

Dougal
la source
pouvez-vous résumer quelles sont ces valeurs finales? S'agit-il d'estimations des paramètres mu et L que vous avez générés? Peut-être imprimer ces valeurs cibles?
AdamO
Non, les valeurs de retour sont et ; ce que j'ai imprimé était la distance entre les estimateurs de Monte Carlo de ces quantités et la valeur calculée. Vous pouvez peut-être inverser ces expressions pour obtenir un estimateur de correspondance de moment pour et - Rosenbaum le fait en fait dans sa section 3 dans le cas tronqué - mais ce n'est pas ce que je voulais ici. \ Cov ( Z + ) L μ Σ\E(Z+)\Cov(Z+)LμΣ
Dougal