Comment extraire les règles de décision de l'arbre de décision scikit-learn?

157

Puis-je extraire les règles de décision sous-jacentes (ou «chemins de décision») à partir d'un arbre formé dans un arbre de décision sous forme de liste textuelle?

Quelque chose comme:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Merci de votre aide.

Dror Hilman
la source
Avez-vous déjà trouvé une réponse à ce problème? Je dois exporter les règles d'arbre de décision dans un format d'étape de données SAS qui est presque exactement comme vous l'avez répertorié.
Zelazny
1
Vous pouvez utiliser le package sklearn-porter pour exporter et transpiler des arbres de décision (également des forêts aléatoires et des arbres boostés) vers C, Java, JavaScript et autres.
Darius
Vous pouvez consulter ce lien- kdnuggets.com/2017/05/…
yogesh agrawal

Réponses:

139

Je pense que cette réponse est plus correcte que les autres réponses ici:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Cela imprime une fonction Python valide. Voici un exemple de sortie pour un arbre qui tente de renvoyer son entrée, un nombre compris entre 0 et 10.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

Voici quelques pierres d'achoppement que je vois dans d'autres réponses:

  1. Utiliser tree_.threshold == -2pour décider si un nœud est une feuille n'est pas une bonne idée. Et si c'était un vrai nœud de décision avec un seuil de -2? Au lieu de cela, vous devriez regarder tree.featureou tree.children_*.
  2. La ligne features = [feature_names[i] for i in tree_.feature]plante avec ma version de sklearn, car certaines valeurs de tree.tree_.featuresont -2 (spécifiquement pour les nœuds feuilles).
  3. Il n'est pas nécessaire d'avoir plusieurs instructions if dans la fonction récursive, une seule suffit.
Paulkernfeld
la source
1
Ce code fonctionne très bien pour moi. Cependant, j'ai plus de 500 noms de fonctionnalité, donc le code de sortie est presque impossible à comprendre pour un humain. Existe-t-il un moyen de me laisser entrer uniquement les noms de fonctionnalité qui me intéressent dans la fonction?
user3768495
1
Je suis d'accord avec le commentaire précédent. IIUC, print "{}return {}".format(indent, tree_.value[node])doit être remplacé par print "{}return {}".format(indent, np.argmax(tree_.value[node][0]))pour que la fonction renvoie l'index de classe.
soupault
1
@paulkernfeld Ah oui, je vois que vous pouvez boucler RandomForestClassifier.estimators_, mais je n'ai pas pu trouver comment combiner les résultats des estimateurs.
Nathan Lloyd
6
Je n'ai pas pu faire fonctionner cela dans python 3, les bits _tree ne semblent pas fonctionner et le TREE_UNDEFINED n'a pas été défini. Ce lien m'a aidé. Bien que le code exporté ne soit pas directement exécutable en python, il est similaire à C et assez facile à traduire dans d'autres langues: web.archive.org/web/20171005203850/http://www.kdnuggets.com/…
Josiah
1
@Josiah, ajoutez () aux instructions d'impression pour que cela fonctionne en python3. eg print "bla"=>print("bla")
Nir
48

J'ai créé ma propre fonction pour extraire les règles des arbres de décision créés par sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Cette fonction commence par les nœuds (identifiés par -1 dans les tableaux enfants), puis trouve récursivement les parents. J'appelle cela la «lignée» d'un nœud. En cours de route, je saisis les valeurs dont j'ai besoin pour créer une logique SAS if / then / else:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Les ensembles de tuples ci-dessous contiennent tout ce dont j'ai besoin pour créer des instructions SAS if / then / else. Je n'aime pas utiliser des doblocs dans SAS, c'est pourquoi je crée une logique décrivant le chemin complet d'un nœud. Le seul entier après les tuples est l'ID du nœud terminal dans un chemin. Tous les tuples précédents se combinent pour créer ce nœud.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

Sortie GraphViz de l'arborescence d'exemples

Zelazny7
la source
est-ce que ce type d'arbre est correct car col1 revient un est col1 <= 0.50000 et un col1 <= 2.5000 si oui, est-ce n'importe quel type de récursivité qui est utilisé dans la bibliothèque
jayant singh
la bonne branche aurait des enregistrements entre (0.5, 2.5]. Les arbres sont réalisés avec un partitionnement récursif. Rien n'empêche une variable d'être sélectionnée plusieurs fois.
Zelazny7
ok pouvez-vous expliquer la partie récursivité ce qui se passe xactly parce que je l'ai utilisé dans mon code et un résultat similaire est vu
jayant singh
38

J'ai modifié le code soumis par Zelazny7 pour imprimer un pseudocode:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

si vous appelez get_code(dt, df.columns)le même exemple vous obtiendrez:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}
Daniele
la source
1
Pouvez-vous dire ce que signifie exactement [[1. 0.]] dans l'instruction return dans la sortie ci-dessus. Je ne suis pas un gars de Python, mais je travaille sur le même genre de chose. Ce sera donc bien pour moi si vous prouvez s'il vous plaît quelques détails afin que ce soit plus facile pour moi.
Subhradip Bose
1
@ user3156186 Cela signifie qu'il y a un objet dans la classe '0' et zéro objet dans la classe '1'
Daniele
1
@Daniele, savez-vous comment les cours sont classés? Je suppose alphanumérique, mais je n'ai trouvé aucune confirmation.
IanS
Merci! Pour le scénario de cas de pointe où la valeur de seuil est en fait de -2, nous devrons peut-être passer (threshold[node] != -2)à ( left[node] != -1)(similaire à la méthode ci-dessous pour obtenir les identifiants des nœuds enfants)
tlingf
@Daniele, une idée comment faire de votre fonction "get_code" "return" une valeur et non pas "imprimer", parce que j'ai besoin de l'envoyer à une autre fonction?
RoyaumeIX
17

Scikit learn a introduit une nouvelle méthode délicieuse appelée export_textdans la version 0.21 (mai 2019) pour extraire les règles d'un arbre. Documentation ici . Il n'est plus nécessaire de créer une fonction personnalisée.

Une fois que vous avez adapté votre modèle, il vous suffit de deux lignes de code. Tout d'abord, importez export_text:

from sklearn.tree.export import export_text

Deuxièmement, créez un objet qui contiendra vos règles. Pour rendre les règles plus lisibles, utilisez l' feature_namesargument et transmettez une liste de vos noms de fonctionnalités. Par exemple, si votre modèle est appelé modelet que vos entités sont nommées dans un dataframe appelé X_train, vous pouvez créer un objet appelé tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Ensuite, imprimez ou enregistrez simplement tree_rules. Votre sortie ressemblera à ceci:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
Yzerman
la source
14

Il existe une nouvelle DecisionTreeClassifierméthode ,,decision_path dans la version 0.18.0 . Les développeurs fournissent une procédure pas à pas détaillée (bien documentée) .

La première section de code de la procédure pas à pas qui imprime la structure arborescente semble être correcte. Cependant, j'ai modifié le code dans la deuxième section pour interroger un échantillon. Mes changements indiqués par# <--

Modifier Les modifications marquées par # <--dans le code ci-dessous ont depuis été mises à jour dans le lien de visite virtuelle après que les erreurs aient été signalées dans les demandes d'extraction # 8653 et # 10951 . C'est beaucoup plus facile à suivre maintenant.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Modifiez le sample_idpour afficher les chemins de décision des autres échantillons. Je n'ai pas interrogé les développeurs sur ces changements, je me suis simplement senti plus intuitif en travaillant sur l'exemple.

Kevin
la source
tu es mon ami une légende! des idées comment tracer l'arbre de décision pour cet échantillon spécifique? beaucoup d'aide est appréciée
1
Merci Victor, il est probablement préférable de poser cette question séparément, car le traçage des exigences peut être spécifique aux besoins d'un utilisateur. Vous obtiendrez probablement une bonne réponse si vous donnez une idée de ce à quoi vous voulez que la sortie ressemble.
Kevin le
hey kevin, j'ai créé la question stackoverflow.com/questions/48888893/…
seriez-vous si gentil de jeter un oeil à: stackoverflow.com/questions/52654280/…
Alexander Chervov
Pouvez-vous s'il vous plaît expliquer la partie appelée node_index, sans obtenir cette partie. Qu'est ce que ça fait?
Anindya Sankar Dey
12
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Vous pouvez voir un arbre digraph. Ensuite, clf.tree_.featureet clf.tree_.valuesont respectivement un tableau de fonctions de division de nœuds et un tableau de valeurs de nœuds. Vous pouvez vous référer à plus de détails à partir de cette source github .

lennon310
la source
1
Oui, je sais dessiner l'arbre - mais j'ai besoin de la version plus textuelle - les règles. quelque chose comme: orange.biolab.si/docs/latest/reference/rst/…
Dror Hilman
4

Juste parce que tout le monde a été si utile, je vais juste ajouter une modification aux belles solutions de Zelazny7 et Daniele. Celui-ci est pour python 2.7, avec des onglets pour le rendre plus lisible:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)
Ruslan
la source
3

Les codes ci-dessous sont mon approche sous anaconda python 2.7 plus un nom de package "pydot-ng" pour créer un fichier PDF avec des règles de décision. J'espère que c'est utile.

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

une représentation graphique d'arbre ici

TED Zhao
la source
3

J'ai vécu cela, mais j'avais besoin que les règles soient écrites dans ce format

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

J'ai donc adapté la réponse de @paulkernfeld (merci) que vous pouvez personnaliser à votre besoin

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)
Jambon Ala
la source
3

Voici un moyen de traduire l'ensemble de l'arbre en une seule expression python (pas nécessairement trop lisible par l'homme) à l'aide de la bibliothèque SKompiler :

from skompiler import skompile
skompile(dtree.predict).to('python/code')
KT.
la source
3

Cela s'appuie sur la réponse de @paulkernfeld. Si vous avez une dataframe X avec vos fonctionnalités et une dataframe cible y avec vos réponses et que vous souhaitez avoir une idée de la valeur y terminée dans quel nœud (et aussi de la tracer en conséquence), vous pouvez faire ce qui suit:

    def tree_to_code(tree, feature_names):
        from sklearn.tree import _tree
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

pas la version la plus élégante mais elle fait le travail ...

fer à cheval
la source
1
C'est une bonne approche lorsque vous souhaitez renvoyer les lignes de code au lieu de simplement les imprimer.
Hajar Homayouni
3

C'est le code dont vous avez besoin

J'ai modifié le code préféré pour indenter correctement dans un notebook jupyter python 3

import numpy as np
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [feature_names[i] 
                    if i != _tree.TREE_UNDEFINED else "undefined!" 
                    for i in tree_.feature]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, np.argmax(tree_.value[node])))

    recurse(0, 1)
Cameron Sorensen
la source
2

Voici une fonction, imprimant les règles d'un arbre de décision scikit-learn sous python 3 et avec des décalages pour les blocs conditionnels pour rendre la structure plus lisible:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)
Apogentus
la source
2

Vous pouvez également le rendre plus informatif en le distinguant à quelle classe il appartient ou même en mentionnant sa valeur de sortie.

def print_decision_tree(tree, feature_names, offset_unit='    '):    
left      = tree.tree_.children_left
right     = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
    features  = ['f%d'%i for i in tree.tree_.feature]
else:
    features  = [feature_names[i] for i in tree.tree_.feature]        

def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit*depth
        if (threshold[node] != -2):
                print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node],depth+1)
                print(offset+"} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node],depth+1)
                print(offset+"}")
        else:
                #print(offset,value[node]) 

                #To remove values from node
                temp=str(value[node])
                mid=len(temp)//2
                tempx=[]
                tempy=[]
                cnt=0
                for i in temp:
                    if cnt<=mid:
                        tempx.append(i)
                        cnt+=1
                    else:
                        tempy.append(i)
                        cnt+=1
                val_yes=[]
                val_no=[]
                res=[]
                for j in tempx:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_no.append(j)
                for j in tempy:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_yes.append(j)
                val_yes = int("".join(map(str, val_yes)))
                val_no = int("".join(map(str, val_no)))

                if val_yes>val_no:
                    print(offset,'\033[1m',"YES")
                    print('\033[0m')
                elif val_no>val_yes:
                    print(offset,'\033[1m',"NO")
                    print('\033[0m')
                else:
                    print(offset,'\033[1m',"Tie")
                    print('\033[0m')

recurse(left, right, threshold, features, 0,0)

entrez la description de l'image ici

Amit Rautray
la source
2

Voici mon approche pour extraire les règles de décision sous une forme qui peut être utilisée directement dans SQL, afin que les données puissent être regroupées par nœud. (Basé sur les approches des affiches précédentes.)

Le résultat sera des CASEclauses ultérieures qui peuvent être copiées dans une instruction SQL, ex.

SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN <conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>


import numpy as np

import pickle
feature_names=.............
features  = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""

#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]

def print_decision_tree(tree, feature_names, offset_unit=''    ''):    
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value

    if feature_names is None:
        features  = [''f%d''%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
        global Conts
        global ContsNode
        global Path
        global Results
        global LeftParents
        LeftParents=[]
        global RightParents
        RightParents=[]
        for i in range(len(left)): # This is just to tell you how to create a list.
            LeftParents.append(-1)
            RightParents.append(-1)
            ContsNode.append("")
            Path.append("")


        for i in range(len(left)): # i is node
            if (left[i]==-1 and right[i]==-1):      
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not " +ContsNode[RightParents[i]]                     
                Results.append(" case when  " +Path[i]+"  then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")

            else:       
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not "+ContsNode[RightParents[i]]                      
                if (left[i]!=-1):
                    LeftParents[left[i]]=i
                if (right[i]!=-1):
                    RightParents[right[i]]=i
                ContsNode[i]=   "( "+ features[i] + " <= " + str(threshold[i])   + " ) "

    recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)): 
    SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)
Flingue
la source
1

Vous pouvez maintenant utiliser export_text.

from sklearn.tree import export_text

r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)

Un exemple complet de [sklearn] [1]

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)
Kevin
la source
0

Modification du code de Zelazny7 pour récupérer SQL à partir de l'arbre de décision.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'
Arslán
la source
0

Apparemment, il y a longtemps, quelqu'un a déjà décidé d'essayer d'ajouter la fonction suivante aux fonctions d'exportation d'arborescence officielle de scikit (qui ne supporte fondamentalement que export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

Voici son engagement complet:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

Je ne sais pas exactement ce qui est arrivé à ce commentaire. Mais vous pouvez également essayer d'utiliser cette fonction.

Je pense que cela justifie une demande de documentation sérieuse aux bonnes personnes de scikit-learn pour documenter correctement le sklearn.tree.Tree API qui est l'arborescence sous-jacente qui DecisionTreeClassifierexpose comme son attribut tree_.

Aris Koning
la source
0

Utilisez simplement la fonction de sklearn.tree comme ceci

from sklearn.tree import export_graphviz
    export_graphviz(tree,
                out_file = "tree.dot",
                feature_names = tree.columns) //or just ["petal length", "petal width"]

Et puis cherchez dans votre dossier de projet le fichier tree.dot , copiez TOUT le contenu et collez-le ici http://www.webgraphviz.com/ et générez votre graphique :)

escalier
la source
0

Merci pour la merveilleuse solution de @paulkerfeld. En plus de sa solution, pour tous ceux qui veulent avoir une version sérialisée d'arbres, utilisez simplement tree.threshold, tree.children_left, tree.children_right, tree.featureet tree.value. Étant donné que les feuilles n'ont pas de fractionnement et donc pas de noms de caractéristiques et d'enfants, leur espace réservé dans tree.featureet tree.children_***sont _tree.TREE_UNDEFINEDet _tree.TREE_LEAF. Chaque division se voit attribuer un index unique par depth first search.
Notez que le tree.valueest de forme[n, 1, 1]

Yanqi Huang
la source
0

Voici une fonction qui génère du code Python à partir d'un arbre de décision en convertissant la sortie de export_text:

import string
from sklearn.tree import export_text

def export_py_code(tree, feature_names, max_depth=100, spacing=4):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    # Second: generate Python code from the text
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree({}):\n'.format(', '.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in res.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '\n'

    return code

Exemple d'utilisation:

res = export_py_code(tree, feature_names=names, spacing=4)
print (res)

Exemple de sortie:

def decision_tree(f1, f2, f3):
    # DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    #                        max_features=None, max_leaf_nodes=None,
    #                        min_impurity_decrease=0.0, min_impurity_split=None,
    #                        min_samples_leaf=1, min_samples_split=2,
    #                        min_weight_fraction_leaf=0.0, presort=False,
    #                        random_state=42, splitter='best')
    if f1 <= 12.5:
        if f2 <= 17.5:
            if f1 <= 10.5:
                return 2
            if f1 > 10.5:
                return 3
        if f2 > 17.5:
            if f2 <= 22.5:
                return 1
            if f2 > 22.5:
                return 1
    if f1 > 12.5:
        if f1 <= 17.5:
            if f3 <= 23.5:
                return 2
            if f3 > 23.5:
                return 3
        if f1 > 17.5:
            if f1 <= 25:
                return 1
            if f1 > 25:
                return 2

L'exemple ci-dessus est généré avec names = ['f'+str(j+1) for j in range(NUM_FEATURES)] .

Une caractéristique pratique est qu'il peut générer une taille de fichier plus petite avec un espacement réduit. Juste réglé spacing=2.

Andriy Makukha
la source