Decision Trees

Adapted from Machine Learning Recipes #8 by Google Developers.

We'll write a Decision Tree Classifier from scratch.


Format: Each row is an example. The last column is the label. The first two columns are features.

Feel free to play with it by adding more features & examples.

Interesting note: The 2nd and 5th examples have the same features, but different labels - so we can see how the tree handles this case.

In [ ]:
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
In [ ]:
header = ["color", "diameter", "label"] #Column labels. These are used only to print the tree.
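
The note above about two examples sharing features but not labels can be checked mechanically. The following is a minimal sketch, not part of the original recipe; the helper name `conflicting_examples` is hypothetical, and the rows are copied from the toy dataset.

```python
from collections import defaultdict

# Same toy rows as the training data: [color, diameter, label].
rows = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]


def conflicting_examples(rows):
    """Return feature tuples that appear with more than one label.

    Hypothetical helper for illustration only.
    """
    labels_by_features = defaultdict(set)
    for *features, label in rows:
        labels_by_features[tuple(features)].add(label)
    return {feats: sorted(labels)
            for feats, labels in labels_by_features.items()
            if len(labels) > 1}


conflicts = conflicting_examples(rows)
print(conflicts)
```

No question about color or diameter can separate these two rows, so any tree built on this data has to keep them in the same leaf.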

Helper Functions

In [ ]:
def unique_vals(rows, col):
    """Find the unique values for a column in a dataset."""
    return set(row[col] for row in rows)
In [ ]:
# Demo:
unique_vals(training_data, 0)
# unique_vals(training_data, 1)
In [ ]:
def class_counts(rows):
    """Counts the number of each type of example in a dataset."""
    counts = {}  # a dictionary of label -> count.
    for row in rows:
        # in our dataset format, the label is always the last column
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    return counts
In [ ]:
# Demo:
class_counts(training_data)
In [ ]:
def is_numeric(value):
    """Test if a value is numeric."""
    return isinstance(value, (int, float))
In [ ]:
# Demo:
is_numeric(7)
# is_numeric("Red")

Creating a split

In [ ]:
class Question:
    """A Question is used to partition a dataset.

    This class just records a 'column number' (e.g., 0 for Color) and a
    'column value' (e.g., Green). The 'match' method is used to compare
    the feature value in an example to the feature value stored in the
    question.
    """

    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, example):
        # Compare the feature value in an example to the
        # feature value in this question.
        # Numeric values are compared with >= (matching __repr__ below);
        # non-numeric values are compared with ==.
        val = example[self.column]
        if isinstance(val, (int, float)):
            return val >= self.value
        return val == self.value

    def __repr__(self):
        # This is just a helper method to print
        # the question in a readable format.
        condition = "=="
        if is_numeric(self.value):
            condition = ">="
        return "Is %s %s %s?" % (
            header[self.column], condition, str(self.value))
In [ ]:
# Demo:
# Let's write a question for a numeric attribute
Question(1, 3)
In [ ]:
# How about one for a categorical attribute
q = Question(0, 'Green')
In [ ]:
# Let's pick an example from the training set...
example = training_data[0]
# ... and see if it matches the question
q.match(example) # this will be true, since the first example is Green.
In [ ]:
def partition(rows, question):
    """Partitions a dataset.

    For each row in the dataset, check if it matches the question. If
    so, add it to 'true rows', otherwise, add it to 'false rows'.
    """
    true_rows, false_rows = [], []
    for row in rows:
        if question.match(row):
            true_rows.append(row)
        else:
            false_rows.append(row)
    return true_rows, false_rows
In [ ]:
# Demo:
# Let's partition the training data based on whether rows are Red.
true_rows, false_rows = partition(training_data, Question(0, 'Red'))
# This will contain all the 'Red' rows.
true_rows
In [ ]:
# This will contain everything else.
false_rows

Calculating Impurity

In [ ]:
def gini(rows):
    """Calculate the Gini Impurity for a list of rows.

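    There are a few different ways to do this. Here, impurity is 1 minus
    the sum of the squared label probabilities.
    """
    counts = class_counts(rows)
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(len(rows))
        impurity -= prob_of_lbl ** 2
    return impurity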
In [ ]:
# Demo:
# Let's look at some examples to understand how Gini Impurity works.
# First, we'll look at a dataset with no mixing.
no_mixing = [['Apple'], ['Apple']]
# this will return 0
gini(no_mixing)
In [ ]:
# Now, we'll look at a dataset with a 50:50 apples:oranges ratio
some_mixing = [['Apple'], ['Orange']]
# this will return 0.5 - meaning, there's a 50% chance of misclassifying
# a random example we draw from the dataset.
gini(some_mixing)
In [ ]:
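# Now, we'll look at a dataset with many different labels
lots_of_mixing = [['Apple'], ['Orange'], ['Grape'],
                  ['Grapefruit'], ['Blueberry']]
# This will return 0.8 (up to floating-point rounding)
gini(lots_of_mixing)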
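
To see where the values 0, 0.5 and 0.8 come from, the arithmetic can be done with exact fractions. This is a small sketch under the assumption that impurity is defined as 1 minus the sum of squared label probabilities. The helper name `gini_exact` and the specific fruit labels are illustrative only.

```python
from collections import Counter
from fractions import Fraction


def gini_exact(labels):
    """Gini impurity, 1 - sum(p_i ** 2), computed with exact fractions.

    Illustrative helper: takes a flat list of labels rather than rows.
    """
    n = len(labels)
    return 1 - sum(Fraction(count, n) ** 2
                   for count in Counter(labels).values())


# One label only: 1 - 1**2
print(gini_exact(['Apple', 'Apple']))        # 0
# Two labels at 50:50: 1 - (1/4 + 1/4)
print(gini_exact(['Apple', 'Orange']))       # 1/2
# Five distinct labels: 1 - 5 * (1/5)**2
print(gini_exact(['Apple', 'Orange', 'Grape', 'Grapefruit', 'Blueberry']))  # 4/5
```

With floats instead of fractions the last value may print as a number very close to 0.8 rather than exactly 0.8.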

Calculating Information Gain

In [ ]:
def info_gain(left, right, current_uncertainty):
    """Information Gain.

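    The uncertainty of the starting node, minus the weighted impurity of
    two child nodes.
    """
    p = float(len(left)) / (len(left) + len(right))
    return current_uncertainty - p * gini(left) - (1 - p) * gini(right)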
In [ ]:
# Demo:
# Calculate the uncertainty of our training data.
current_uncertainty = gini(training_data)
In [ ]:
# How much information do we gain by partitioning on 'Green'?
true_rows, false_rows = partition(training_data, Question(0, 'Green'))
info_gain(true_rows, false_rows, current_uncertainty)
In [ ]:
# What about if we partitioned on 'Red' instead?
true_rows, false_rows = partition(training_data, Question(0,'Red'))
info_gain(true_rows, false_rows, current_uncertainty)
In [ ]:
# It looks like we learned more using 'Red' (0.37), than 'Green' (0.14).
# Why? Look at the different splits that result, and see which one
# looks more 'unmixed' to you.
true_rows, false_rows = partition(training_data, Question(0,'Red'))

# Here, the true_rows contain only 'Grapes'.
true_rows
In [ ]:
# And the false rows contain two types of fruit. Not too bad.
false_rows
In [ ]:
# On the other hand, partitioning by Green doesn't help so much.
true_rows, false_rows = partition(training_data, Question(0,'Green'))

# We've isolated one apple in the true rows.
true_rows
In [ ]:
# But, the false-rows are badly mixed up.
false_rows
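
The 0.37 and 0.14 figures quoted above can be reproduced by hand. The sketch below is self-contained and uses its own compact helpers (`gini` on rows, and a hypothetical `gain_for` that splits by a predicate) rather than the notebook's functions, assuming the toy dataset shown at the top.

```python
from collections import Counter

rows = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]


def gini(rows):
    """Gini impurity of the labels (last column) in rows."""
    n = len(rows)
    label_counts = Counter(row[-1] for row in rows)
    return 1 - sum((count / n) ** 2 for count in label_counts.values())


def gain_for(rows, predicate):
    """Parent impurity minus the size-weighted impurity of the two children."""
    true_rows = [row for row in rows if predicate(row)]
    false_rows = [row for row in rows if not predicate(row)]
    p = len(true_rows) / len(rows)
    weighted = p * gini(true_rows) + (1 - p) * gini(false_rows)
    return gini(rows) - weighted


# Red:   0.64 - (2/5 * 0   + 3/5 * 4/9) = 28/75
# Green: 0.64 - (1/5 * 0   + 4/5 * 5/8) = 7/50
red_gain = gain_for(rows, lambda row: row[0] == 'Red')
green_gain = gain_for(rows, lambda row: row[0] == 'Green')
print(round(red_gain, 2), round(green_gain, 2))
```

The 'Red' split wins because one of its children is perfectly pure and the other mixes only two labels, while the 'Green' split leaves four of the five rows in a child that still mixes three labels.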

Finding the best split

In [ ]:
def find_best_split(rows):
    """Find the best question to ask by iterating over every feature / value
    and calculating the information gain."""
    best_gain = 0  # keep track of the best information gain
    best_question = None  # keep track of the feature / value that produced it
    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1  # number of feature columns (last is the label)

    for col in range(n_features):  # for each feature

        values = set(row[col] for row in rows)  # unique values in the column

        for val in values:  # for each value

            question = Question(col, val)

            # try splitting the dataset
            true_rows, false_rows = partition(rows, question)

            # Skip this split if it doesn't divide the
            # dataset.
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue

            # Calculate the information gain from this split
            gain = info_gain(true_rows, false_rows, current_uncertainty)

            # You actually can use '>' instead of '>=' here
            # but I wanted the tree to look a certain way for our
            # toy dataset.
            if gain >= best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question
In [ ]:
# Demo:
# Find the best question to ask first for our toy dataset.
best_gain, best_question = find_best_split(training_data)
# FYI: 'Is color == Red?' is just as good. See the note in the code above
# where I used '>='.
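
The effect of '>=' versus '>' on tied questions can be isolated in a few lines. This is a sketch only: `pick_best` is a hypothetical stand-in for the selection loop, and the candidate gains are the hand-computed values for the toy dataset (28/75 for both 'Red' and 'diameter >= 3', which produce the same two groups, and 7/50 for 'Green').

```python
def pick_best(candidates, accept_ties):
    """Scan (question, gain) pairs in order and keep the best one.

    With accept_ties=True the comparison is '>=', so a later question
    replaces an earlier one with equal gain. With accept_ties=False the
    comparison is '>', so the first of the tied questions is kept.
    """
    best_gain, best_question = 0, None
    for question, gain in candidates:
        is_better = gain >= best_gain if accept_ties else gain > best_gain
        if is_better:
            best_gain, best_question = gain, question
    return best_question


# Color questions are tried before diameter questions (column 0, then 1).
candidates = [
    ("Is color == Green?", 7 / 50),
    ("Is color == Red?", 28 / 75),
    ("Is diameter >= 3?", 28 / 75),
]

print(pick_best(candidates, accept_ties=True))   # Is diameter >= 3?
print(pick_best(candidates, accept_ties=False))  # Is color == Red?
```

Either choice yields the same partition of the toy data; only the question printed at the root of the tree differs.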

Defining Node Objects

In [ ]:
class Leaf:
    """A Leaf node classifies data.

    This holds a dictionary of class (e.g., "Apple") -> number of times
    it appears in the rows from the training data that reach this leaf.
    """

    def __init__(self, rows):
        self.predictions = class_counts(rows)
In [ ]:
class Decision_Node:
    """A Decision Node asks a question.

    This holds a reference to the question, and to the two child nodes.
    """

    def __init__(self, question, true_branch, false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch

Building a Decision Tree

In [ ]:
def build_tree(rows):
    """Builds the tree.

    Rules of recursion: 1) Believe that it works. 2) Start by checking
    for the base case (no further information gain). 3) Prepare for
    giant stack traces.
    """

    # Try partitioning the dataset on each unique attribute value,
    # calculate the information gain,
    # and return the question that produces the highest gain.
    gain, question = find_best_split(rows)

    # Base case: no further info gain
    # Since we can ask no further questions,
    # we'll return a leaf.
    if gain == 0:
        return Leaf(rows)

    # If we reach here, we have found a useful feature / value
    # to partition on.
    true_rows, false_rows = partition(rows, question)

    # Recursively build the true branch.
    true_branch = build_tree(true_rows)

    # Recursively build the false branch.
    false_branch = build_tree(false_rows)

    # Return a Decision_Node.
    # This records the best feature / value to ask at this point,
    # as well as the branches to follow
    # depending on the answer.
    return Decision_Node(question, true_branch, false_branch)
In [ ]:
def print_tree(node, spacing=""):
    """World's most elegant tree printing function."""

    # Base case: we've reached a leaf
    if isinstance(node, Leaf):
        print(spacing + "Predict", node.predictions)
        return

    # Print the question at this node
    print(spacing + str(node.question))

    # Call this function recursively on the true branch
    print(spacing + '--> True:')
    print_tree(node.true_branch, spacing + "  ")

    # Call this function recursively on the false branch
    print(spacing + '--> False:')
    print_tree(node.false_branch, spacing + "  ")
In [ ]:
my_tree = build_tree(training_data)
In [ ]:
print_tree(my_tree)
In [ ]:
def classify(row, node):
    """See the 'rules of recursion' above."""

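    # Base case: we've reached a leaf. Return its predictions.
    if isinstance(node, Leaf):
        return node.predictions

    # Decide whether to follow the true-branch or the false-branch.
    # Compare the feature / value stored in the node,
    # to the example we're considering.
    if node.question.match(row):
        return classify(row, node.true_branch)
    return classify(row, node.false_branch)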
In [ ]:
# Demo:
# The tree predicts the 1st row of our
# training data is an apple with confidence 1.
classify(training_data[0], my_tree)
In [ ]:
def print_leaf(counts):
    """A nicer way to print the predictions at a leaf."""
    total = sum(counts.values()) * 1.0
    probs = {}
    for lbl in counts.keys():
        probs[lbl] = str(int(counts[lbl] / total * 100)) + "%"
    return probs
In [ ]:
# Demo:
# Printing that a bit nicer
print_leaf(classify(training_data[0], my_tree))
In [ ]:
# Demo:
# On the second example, the confidence is lower
print_leaf(classify(training_data[1], my_tree))
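
Why the confidence drops on the second example: the two 'Yellow', 3 rows have identical features, so they necessarily end up in the same leaf, which then holds one Apple and one Lemon. The sketch below mirrors the percentage conversion on hand-written counts; `leaf_percentages` is a hypothetical name, and the count dictionaries are assumed leaf contents rather than output captured from the tree.

```python
def leaf_percentages(counts):
    """Convert a leaf's label counts into truncated percentage strings."""
    total = sum(counts.values())
    return {label: str(int(count / total * 100)) + "%"
            for label, count in counts.items()}


# A leaf reached only by apples: full confidence.
pure_leaf = leaf_percentages({'Apple': 1})
# A leaf shared by the conflicting 'Yellow', 3 rows: confidence is split.
mixed_leaf = leaf_percentages({'Apple': 1, 'Lemon': 1})

print(pure_leaf)
print(mixed_leaf)
```

Because the conversion truncates with `int`, percentages for uneven counts may not add up to exactly 100.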
In [ ]:
# Evaluate
testing_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 4, 'Apple'],
    ['Red', 2, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]
In [ ]:
for row in testing_data:
    print ("Actual: %s. Predicted: %s" %
           (row[-1], print_leaf(classify(row, my_tree))))