J'expérimente l'algorithme de la machine de renforcement de gradient via le caret
package en R.
À l'aide d'un petit ensemble de données d'admission à l'université, j'ai exécuté le code suivant:
library(caret)
### Load admissions dataset. ###
mydata <- read.csv("http://www.ats.ucla.edu/stat/data/binary.csv")
### Create yes/no levels for admission. ###
mydata$admit_factor[mydata$admit==0] <- "no"
mydata$admit_factor[mydata$admit==1] <- "yes"
### Gradient boosting machine algorithm. ###
set.seed(123)
fitControl <- trainControl(method = 'cv', number = 5, summaryFunction=defaultSummary)
grid <- expand.grid(n.trees = seq(5000,1000000,5000), interaction.depth = 2, shrinkage = .001, n.minobsinnode = 20)
fit.gbm <- train(as.factor(admit_factor) ~ . - admit, data=mydata, method = 'gbm', trControl=fitControl, tuneGrid=grid, metric='Accuracy')
plot(fit.gbm)
et j'ai constaté à ma grande surprise que la précision de validation croisée du modèle a diminué plutôt qu'augmenté à mesure que le nombre d'itérations de renforcement augmentait, atteignant une précision minimale d'environ 0,59 à ~ 450 000 itérations.
Ai-je mal implémenté l'algorithme GBM?
EDIT: Suite à la suggestion d'Underminer, j'ai réexécuté le caret
code ci-dessus mais concentré sur l'exécution de 100 à 5 000 itérations de boost:
set.seed(123)
fitControl <- trainControl(method = 'cv', number = 5, summaryFunction=defaultSummary)
grid <- expand.grid(n.trees = seq(100,5000,100), interaction.depth = 2, shrinkage = .001, n.minobsinnode = 20)
fit.gbm <- train(as.factor(admit_factor) ~ . - admit, data=mydata, method = 'gbm', trControl=fitControl, tuneGrid=grid, metric='Accuracy')
plot(fit.gbm)
Le graphique résultant montre que la précision culmine à près de 0,705 à environ 1800 itérations:
Ce qui est curieux, c'est que la précision n'a pas atteint un plateau à ~ 0,70 mais a plutôt diminué après 5 000 itérations.
la source
Codes pour reproduire un résultat similaire, sans recherche de grille,
la source
Le paquet gbm a une fonction pour estimer le nombre optimal d'itérations (= nombre d'arbres ou nombre de fonctions de base),
Vous n'avez pas besoin du train de caret pour cela.
la source