# How to extract the decision rules from scikit-learn decision-tree?


### Question :


Can I extract the underlying decision rules (or ‘decision paths’) from a trained decision tree as a textual list?

Something like:

`if A>0.4 then if B<0.2 then if C>0.8 then class='X'`

I believe that this answer is more correct than the other answers here:

```python
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    # leaves store TREE_UNDEFINED (-2) in tree_.feature, so never index feature_names with it
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # split node: emit an if/else pair and recurse into both children
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # leaf node: return the value stored at this leaf
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)
```

This prints out a valid Python function. Here’s an example output for a tree that is trying to return its input, a number between 0 and 10.

```python
def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]
```
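One way to check that the printed text really is a valid Python function is to capture it with `contextlib.redirect_stdout`, `exec` it, and compare the result with `predict`. This is a sketch under some assumptions: the function is a condensed copy of `tree_to_code` above, and the regressor (`DecisionTreeRegressor(max_depth=3)` fitted on the integers 0 to 10) is chosen only to mimic the example. The trick only works when each leaf holds a single value, because `str()` of a multi-column NumPy array such as `[[1. 0.]]` has no commas and is not a valid Python literal.

```python
import io
from contextlib import redirect_stdout

import numpy as np
from sklearn.tree import DecisionTreeRegressor, _tree

def tree_to_code(tree, feature_names):
    # condensed copy of the answer's function
    tree_ = tree.tree_
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)

# a tree that tries to return its input, a number between 0 and 10
X = np.arange(11, dtype=float).reshape(-1, 1)
reg = DecisionTreeRegressor(max_depth=3).fit(X, X.ravel())

buf = io.StringIO()
with redirect_stdout(buf):          # capture the printed source instead of showing it
    tree_to_code(reg, ["f0"])

namespace = {}
exec(buf.getvalue(), namespace)     # defines tree(f0) from the printed source
generated = namespace["tree"]

# the generated function returns [[value]]; compare it with the model's own prediction
matches = all(np.isclose(generated(x)[0][0], reg.predict([[x]])[0]) for x in X.ravel())
print("generated function agrees with predict:", matches)
```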

Here are some stumbling blocks that I see in other answers:

1. Using `tree_.threshold == -2` to decide whether a node is a leaf isn’t a good idea. What if it’s a real decision node with a threshold of -2? Instead, you should look at `tree_.feature` or `tree_.children_*`.
2. The line `features = [feature_names[i] for i in tree_.feature]` crashes with my version of sklearn, because some values of `tree.tree_.feature` are -2 (specifically for leaf nodes).
3. There is no need to have multiple if statements in the recursive function; just one is enough.
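The two safer leaf tests from point 1 can be checked against each other on a fitted tree. This sketch assumes the iris data and a depth-3 `DecisionTreeClassifier` purely for illustration; `is_leaf` is a hypothetical helper, and `TREE_LEAF` (-1) and `TREE_UNDEFINED` (-2) are constants from the private `sklearn.tree._tree` module, so they are not a guaranteed public API.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, _tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)
t = clf.tree_

def is_leaf(tree_, node):
    # a leaf has no children: its child pointers hold TREE_LEAF (-1)
    return tree_.children_left[node] == _tree.TREE_LEAF

# the structural test and the feature test agree on every node
for node in range(t.node_count):
    assert is_leaf(t, node) == (t.feature[node] == _tree.TREE_UNDEFINED)

# safe name lookup: never index feature_names with the -2 placeholder
names = [iris.feature_names[i] if i != _tree.TREE_UNDEFINED else None for i in t.feature]
print(sum(n is None for n in names), "leaves out of", t.node_count, "nodes")
```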

I created my own function to extract the rules from the decision trees created by sklearn:

```python
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1': [0, 1, 2, 3], 'col2': [3, 4, 5, 6], 'dv': [0, 1, 0, 1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
# .ix has been removed from pandas; .iloc selects the first two columns by position
dt.fit(df.iloc[:, :2], df.dv)
```

This function starts with the leaf nodes (identified by -1 in the child arrays) and then recursively finds their parents. I call this a node’s ‘lineage’. Along the way, I grab the values I need to create if/then/else SAS logic:

```python
def get_lineage(tree, feature_names):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # leaves store -2 in tree_.feature; map them to None instead of indexing with -2
    features  = [feature_names[i] if i >= 0 else None for i in tree.tree_.feature]

    # get ids of the leaf (child) nodes
    idx = np.argwhere(left == -1)[:, 0]

    def recurse(left, right, child, lineage=None):
        if lineage is None:
            lineage = [child]
        if child in left:
            parent = np.where(left == child)[0].item()
            split = 'l'
        else:
            parent = np.where(right == child)[0].item()
            split = 'r'

        # float() keeps the printed tuples free of NumPy scalar reprs
        lineage.append((parent, split, float(threshold[parent]), features[parent]))

        if parent == 0:
            lineage.reverse()
            return lineage
        else:
            return recurse(left, right, parent, lineage)

    for child in idx:
        for node in recurse(left, right, child):
            print(node)
```

The sets of tuples below contain everything I need to create SAS if/then/else statements. I do not like using `do` blocks in SAS, which is why I create logic describing a node’s entire path. The single integer after each group of tuples is the ID of the terminal node of that path; the preceding tuples together define the path to that node.

```
In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6
```
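`get_lineage` prints each path; if you instead collect the list returned by its inner `recurse` for one leaf, a hypothetical helper like `lineage_to_sas` below can turn it into a single SAS-style IF/THEN statement. Reading `'l'` as `feature <= threshold` and `'r'` as `feature > threshold` follows scikit-learn's convention that samples at or below the threshold go to the left child. The `node = <id>` assignment is an illustrative assumption; substitute whatever SAS action you need.

```python
def lineage_to_sas(lineage):
    """Turn one path (split tuples followed by a leaf id) into a SAS-style IF/THEN line."""
    *splits, leaf_id = lineage
    conditions = []
    for _parent, split, threshold, feature in splits:
        # 'l' = went to the left child (<=), 'r' = went to the right child (>)
        op = "<=" if split == "l" else ">"
        conditions.append(f"{feature} {op} {threshold}")
    return f"if {' and '.join(conditions)} then node = {leaf_id};"

# one of the paths printed above: root -> node 2 -> leaf 3
path = [(0, 'r', 0.5, 'col1'), (2, 'l', 4.5, 'col2'), 3]
print(lineage_to_sas(path))
```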

I modified the code submitted by Zelazny7 to print some pseudocode:

```python
def get_code(tree, feature_names):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # leaves store -2 in tree_.feature; don't index feature_names with it
    features  = [feature_names[i] if i >= 0 else None for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node):
        # a node is a split node iff it has children (-1 means "no child")
        if left[node] != -1:
            print("if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
            recurse(left, right, threshold, features, left[node])
            print("} else {")
            recurse(left, right, threshold, features, right[node])
            print("}")
        else:
            print("return " + str(value[node]))

    recurse(left, right, threshold, features, 0)
```

If you call `get_code(dt, df.columns)` on the same example, you will obtain:

```
if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}
```

Scikit-learn introduced a delicious new function called `export_text` in version 0.21 (May 2019) to extract the rules from a tree (see the `sklearn.tree.export_text` documentation). It’s no longer necessary to create a custom function.

Once you’ve fit your model, you just need two lines of code. First, import `export_text`:

```python
from sklearn.tree import export_text
```

Second, create an object that will contain your rules. To make the rules look more readable, use the `feature_names` argument and pass a list of your feature names. For example, if your model is called `model` and your features are named in a dataframe called `X_train`, you could create an object called `tree_rules`:

```python
tree_rules = export_text(model, feature_names=list(X_train.columns))
```

Then just print or save `tree_rules`. Your output will look like this:

```
|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
```
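A self-contained version of the two steps above, assuming the iris dataset stands in for your own `X_train` and model. `max_depth`, `decimals` and `show_weights` are optional `export_text` parameters: they limit how many levels are printed, set the threshold precision, and add per-class weights at each leaf. The result is a plain string, so saving it is ordinary file I/O.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
model = DecisionTreeClassifier(max_depth=3, random_state=0)
model.fit(iris.data, iris.target)

tree_rules = export_text(
    model,
    feature_names=list(iris.feature_names),  # one name per column, in column order
    max_depth=3,        # only print the first 3 levels
    decimals=2,         # thresholds rounded to 2 decimals
    show_weights=True,  # show class weights at each leaf
)
print(tree_rules)
```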

There is a new `DecisionTreeClassifier` method, `decision_path`, in the 0.18.0 release. The developers provide an extensive (well-documented) walkthrough.

The first section of code in the walkthrough, which prints the tree structure, seems to be OK. However, I modified the code in the second section to interrogate one sample. My changes are denoted with `# <--`.

Edit: The changes marked by `# <--` in the code below have since been incorporated into the walkthrough, after the errors were pointed out in pull requests #8653 and #10951. It’s much easier to follow along now.

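```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# setup, following the scikit-learn decision tree structure walkthrough
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold
node_indicator = estimator.decision_path(X_test)  # sparse CSR node-indicator matrix
leave_id = estimator.apply(X_test)                # id of the leaf each sample reaches

sample_id = 0
# node ids visited by this sample, read from its row of the CSR matrix
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id]))  # <--

    else:  # <-- added else to iterate through decision nodes
        if X_test[sample_id, feature[node_id]] <= threshold[node_id]:
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]],  # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))
```

```
Rules used to predict sample 0:
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here
```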

Change `sample_id` to see the decision paths for other samples. I haven’t asked the developers about these changes; they just seemed more intuitive when working through the example.
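To inspect several samples, it can help to wrap the loop above in a function. The sketch below is a hypothetical helper (`rules_for_sample` is not a scikit-learn API) that calls `decision_path` and `apply` on a single row and returns the conditions as strings. It assumes a fitted `DecisionTreeClassifier`, a NumPy feature matrix, and the iris data for the usage lines.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

def rules_for_sample(clf, X, sample_id, feature_names):
    """Return the split conditions one sample passes through, ending with its leaf."""
    x = X[sample_id:sample_id + 1]        # keep the 2-D shape sklearn expects
    path = clf.decision_path(x)           # sparse (1, n_nodes) indicator matrix
    leaf = clf.apply(x)[0]
    tree_ = clf.tree_
    rules = []
    for node_id in path.indices:          # node ids on the path, root first
        if node_id == leaf:
            rules.append("leaf {}".format(node_id))
            continue
        f, t = tree_.feature[node_id], tree_.threshold[node_id]
        sign = "<=" if x[0, f] <= t else ">"
        rules.append("{} {} {:.2f}".format(feature_names[f], sign, t))
    return rules

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
for sid in (0, 50, 100):
    print(sid, rules_for_sample(clf, iris.data, sid, iris.feature_names))
```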

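```python
from io import StringIO
from sklearn import tree

out = StringIO()
# export_graphviz writes the DOT source into the file-like buffer (and returns None)
tree.export_graphviz(clf, out_file=out)
print(out.getvalue())
```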

The output is a Graphviz digraph named `Tree`. In addition, `clf.tree_.feature` and `clf.tree_.value` are the array of each node's splitting feature and the array of each node's values, respectively. You can find more details in this GitHub source.
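Those arrays can also be read directly, without going through Graphviz. This sketch assumes the iris data and a depth-2 classifier for illustration. Depending on the scikit-learn version, `value` holds per-class sample counts or per-class fractions for a classifier, so only its relative sizes are relied on here.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)
t = clf.tree_

rows = []
for node_id in range(t.node_count):
    if t.children_left[node_id] == -1:    # leaf: no split feature to report
        split = "leaf"
    else:
        split = "{} <= {:.2f}".format(iris.feature_names[t.feature[node_id]],
                                      t.threshold[node_id])
    # value[node_id] has shape (n_outputs, n_classes); take the single output row
    rows.append((node_id, split, t.value[node_id][0].round(3).tolist()))

for row in rows:
    print(row)
```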

Just because everyone was so helpful, I’ll add a modification to Zelazny7’s and Daniele’s beautiful solutions. This one indents with tabs to make the output more readable:

```python
def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # leaves store -2 in tree_.feature; don't index feature_names with it
    features  = [feature_names[i] if i >= 0 else None for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
        indent = "\t" * tabdepth
        # a node is a split node iff it has children (-1 means "no child")
        if left[node] != -1:
            print(indent + "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
            recurse(left, right, threshold, features, left[node], tabdepth + 1)
            print(indent + "} else {")
            recurse(left, right, threshold, features, right[node], tabdepth + 1)
            print(indent + "}")
        else:
            print(indent + "return " + str(value[node]))

    recurse(left, right, threshold, features, 0, tabdepth)
```

I’ve been going through this, but I needed the rules to be written in this format:

```
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
```

```python
from sklearn.tree import _tree

def tree_to_code(tree, feature_names, Y):
    # note: Y is not used below
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto = dict()  # node id -> conditions accumulated on the way into its current branch
    k = 0            # running rule number

    def recurse(node, depth, parent):
        nonlocal k
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            # left branch: name <= threshold
            s = "{} <= {} ".format(name, threshold)
            if node == 0:
                pathto[node] = s
            else:
                pathto[node] = pathto[parent] + ' & ' + s
            recurse(tree_.children_left[node], depth + 1, node)
            # right branch: name > threshold
            s = "{} > {}".format(name, threshold)
            if node == 0:
                pathto[node] = s
            else:
                pathto[node] = pathto[parent] + ' & ' + s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            # leaf: print the accumulated path and the leaf's value
            k = k + 1
            print(k, ')', pathto[parent], tree_.value[node])

    recurse(0, 1, 0)
```
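To get closer to the exact format asked for in the question (`if A>0.4 then if B<0.2 then ... class='X'`), the path can be kept as a list of conditions and joined at each leaf. This is a sketch, not a library API: `tree_to_rules` is a hypothetical helper, the class label is taken as the column with the largest `value` at the leaf mapped through `clf.classes_`, and the iris data and `max_depth=2` are only for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, _tree

def tree_to_rules(clf, feature_names):
    """Return one 'if ... then class=...' string per leaf."""
    tree_ = clf.tree_
    rules = []

    def recurse(node, conditions):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_names[tree_.feature[node]]
            thr = tree_.threshold[node]
            recurse(tree_.children_left[node], conditions + ["{} <= {:.2f}".format(name, thr)])
            recurse(tree_.children_right[node], conditions + ["{} > {:.2f}".format(name, thr)])
        else:
            # predicted class = the class column with the largest value at this leaf
            label = clf.classes_[tree_.value[node][0].argmax()]
            prefix = "".join("if {} then ".format(c) for c in conditions)
            rules.append("{}class='{}'".format(prefix, label))

    recurse(0, [])
    return rules

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)
for rule in tree_to_rules(clf, iris.feature_names):
    print(rule)
```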