Choisir l'alpha optimal dans la régression logistique nette élastique

22

J'effectue une régression logistique net élastique sur un ensemble de données de soins de santé en utilisant le glmnetpackage dans R en sélectionnant les valeurs lambda sur une grille de de 0 à 1. Mon code abrégé est ci-dessous:α

alphalist <- seq(0,1,by=0.1)
elasticnet <- lapply(alphalist, function(a){
  cv.glmnet(x, y, alpha=a, family="binomial", lambda.min.ratio=.001)
})
for (i in 1:11) {print(min(elasticnet[[i]]$cvm))}

qui génère l'erreur moyenne de validation croisée pour chaque valeur de alpha de à avec un incrément de :1,0 0,10.01.00.1

[1] 0.2080167
[1] 0.1947478
[1] 0.1949832
[1] 0.1946211
[1] 0.1947906
[1] 0.1953286
[1] 0.194827
[1] 0.1944735
[1] 0.1942612
[1] 0.1944079
[1] 0.1948874

Sur la base de ce que j'ai lu dans la littérature, le choix optimal de est l'endroit où l'erreur cv est minimisée. Mais il y a beaucoup de variation dans les erreurs sur la gamme des alphas. Je vois plusieurs minimums locaux, avec une erreur minimum globale de for .α0.1942612alpha=0.8

Est-il sûr d'y aller alpha=0.8? Ou, étant donné la variation, dois-je réexécuter cv.glmnetavec plus de plis de validation croisée (par exemple au lieu de ) ou peut-être un plus grand nombre d' incréments entre et pour obtenir une image claire du chemin d'erreur cv?10 α2010αalpha=0.01.0

RobertF
la source
5
Vous voudriez jeter un coup d'œil au caretpaquetage qui peut faire des cv et des réglages répétés pour alpha et lambda (prend en charge le traitement multicœur!). De mémoire, je pense que la glmnetdocumentation déconseille de régler l'alpha comme vous le faites ici. Il recommande de conserver les plis fixes si l'utilisateur ajuste l'alpha en plus de l'ajustement de lambda fourni par cv.glmnet.
1
Ah, j'ai
RobertF
2
n'oubliez pas de corriger le foldid lorsque vous essayez différentsα
user4581
1
Pour la reproductibilité, ne courez jamais cv.glmnet()sans passer par foldidscréé à partir d'une graine aléatoire connue.
smci
1
@amoeba jetez un œil à ma réponse - les commentaires sur les compromis entre l1 et l2 sont les bienvenus!
Xavier Bourret Sicotte

Réponses:

7

Clarification de la signification des paramètres α et Elastic Net

Différentes terminologies et paramètres sont utilisés par différents packages, mais la signification est généralement la même:

Le package R Glmnet utilise la définition suivante

minβ0,β1Nje=1Nwjel(yje,β0+βTXje)+λ[(1-α)||β||22/2+α||β||1]

Sklearn utilise

minw12Nje=1N||y-Xw||22+α×l1rapport||w||1+0,5×α×(1-l1rapport)×||w||22

Il existe également des paramétrisations alternatives utilisant et ..uneb

Pour éviter toute confusion, je vais appeler

  • λ le paramètre de force de pénalité
  • L1ratioL 1 L 2 le rapport entre la pénalité et , allant de 0 (crête) à 1 (lasso)L1L2

Visualiser l'impact des paramètres

Considérons un ensemble de données simulées où compose d'une courbe sinusoïdale bruyante et est une caractéristique bidimensionnelle composée de et . En raison de la corrélation entre et la fonction de coût est une vallée étroite.yXX1=xX2=x2X1X2

Les graphiques ci-dessous illustrent le chemin de solution de la régression élastique avec deux paramètres de rapport différents , en fonction de λ le paramètre de résistance.L1λ

  • Pour les deux simulations: lorsque λ=0 la solution est la solution OLS en bas à droite, avec la fonction de coût en forme de vallée associée.
  • À mesure que λ augmente, la régularisation entre en jeu et la solution tend à (0,0)
  • La principale différence entre les deux simulations est le paramètre de rapport L1 .
  • LHS : pour un petit rapport L1 , la fonction de coût régularisé ressemble beaucoup à la régression de Ridge avec des contours ronds.
  • RHS : pour un rapport L1 élevé, la fonction de coût ressemble beaucoup à la régression de Lasso avec les contours de forme de diamant typiques.
  • Pour le ratio intermédiaire (non illustré), la fonction de coût est un mélange des deuxL1

entrez la description de l'image ici


Comprendre l'effet des paramètres

L'ElasticNet a été introduit pour contrer certaines des limites du Lasso qui sont:

  • S'il y a plus de variables que de points de données , , le lasso sélectionne au plus variables.pnp>nn
  • Lasso ne parvient pas à effectuer une sélection groupée, en particulier en présence de variables corrélées. Il aura tendance à sélectionner une variable dans un groupe et à ignorer les autres

En combinant une pénalité et une pénalité quadratique , nous obtenons les avantages des deux:L1L2

  • L1 génère un modèle clairsemé
  • L2 supprime la limitation du nombre de variables sélectionnées, encourage le regroupement et stabilise lechemin de régularisationL1 .

Vous pouvez le voir visuellement sur le diagramme ci-dessus, les singularités au niveau des sommets encouragent la rareté , tandis que les bords convexes stricts encouragent le regroupement .

Voici une visualisation tirée de Hastie (l'inventeur d'ElasticNet)

entrez la description de l'image ici

Lectures complémentaires

Xavier Bourret Sicotte
la source
2

Permettez-moi d'ajouter quelques remarques très pratiques malgré l'âge de la question. Comme je ne suis pas un utilisateur R, je ne peux pas laisser le code parler, mais il doit néanmoins être compréhensible.

  1. αkF1,...,FkF(X)=1kjeFje(X)F(X)=je=1kFje(X)k

  2. Un des avantages du rééchantillonnage est que vous pouvez inspecter la séquence des résultats des tests, qui sont les scores du cv. Vous devez toujours non seulement regarder la moyenne mais la déviation std (ce n'est pas une distribution normale, mais vous agissez comme si). Habituellement, vous affichez ce chiffre à 65,5% (± 2,57%) pour la précision. De cette façon, vous pouvez dire si les "petites déviations" sont plus susceptibles d'être dues au hasard ou structurellement. Mieux serait même d'inspecter les séquences complètes . S'il y a toujours un repli pour une raison quelconque, vous voudrez peut-être repenser la façon dont vous effectuez votre fractionnement (cela suggère une conception expérimentale défectueuse, également: avez-vous mélangé?). Dans scikit, découvrez les GridSearchCVdétails des magasins sur les délais de pliage cv_results_( voir ici ).

  3. αL1αL2

uberwach
la source