What Is Information Gain?

Information Gain, or IG for short, measures the reduction in entropy or surprise achieved by splitting a dataset according to a given value of a random variable. A larger information gain indicates a lower-entropy group or groups of samples, and hence less surprise.

Entropy quantifies how much information there is in a random variable, or more specifically in its probability distribution. A skewed distribution has low entropy, whereas a distribution in which events have equal probability has higher entropy.

Information gain is commonly used in the construction of decision trees from a training dataset: the information gain is evaluated for each candidate variable, and the variable that maximizes the gain between the parent node and its split nodes is selected. This in turn minimizes the entropy and best splits the dataset into groups for effective classification (see the sketch after the list below).

  • In a binary case:
    • Entropy is 0 if all samples belong to the same class for a node (i.e., pure)
    • Entropy is 1 if the samples at a node are evenly split between the two classes (i.e., 50% for each class)
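
For instance, scikit-learn's DecisionTreeClassifier applies exactly this entropy-minimizing split selection when passed criterion='entropy'. Here is a minimal sketch; the iris dataset and the hyperparameters are illustrative choices, not taken from the discussion above:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# criterion='entropy' selects, at each node, the split that maximizes information gain
tree = DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=42)
tree.fit(X, y)

# features that produced the largest entropy reductions score highest
print(tree.feature_importances_)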

Information quantifies, in bits, how surprising an event is. Lower-probability events carry more information; higher-probability events carry less.

  • Skewed Probability Distribution (unsurprising): Low entropy.

  • Balanced Probability Distribution (surprising): High entropy.
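
As a quick sketch (the probabilities below are made up for illustration), the information content of an event with probability p is h = -log2(p) bits:

from math import log2

# information (surprisal) in bits: h(x) = -log2(p(x))
# rarer events yield more bits of information
for p in [0.9, 0.5, 0.1]:
    print(f'p={p}: {-log2(p):.3f} bits')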

Information gain can also be used for feature selection, by evaluating the gain of each variable in the context of the target variable. In this slightly different usage, the calculation is referred to as mutual information between the two random variables.

Entropy example

In a binary classification problem, we can calculate the entropy of the data sample as follows:

  • Entropy = -(p(0) * log2(p(0)) + p(1) * log2(p(1)))

A dataset with a 50/50 split of samples for the two classes would have a maximum entropy (maximum surprise) of 1 bit, whereas an imbalanced dataset with a split of 10/90 would have a smaller entropy as there would be less surprise for a randomly drawn example from the dataset.

In [1]:
# calculate the entropy for a dataset
from math import log2

# proportion of examples in each class
class0 = 10/100
class1 = 90/100

# calculate entropy
entropy = -(class0 * log2(class0) + class1 * log2(class1))

# print the result
print(f'Entropy: {round(entropy,3)} bits')
Entropy: 0.469 bits

Entropy can be used as a measure of the purity of a dataset, e.g. how balanced the distribution of classes happens to be.

Information gain provides a way to use entropy to calculate how a change to the dataset impacts the purity of the dataset, e.g. the distribution of classes. A smaller entropy suggests more purity or less surprise.

Calculating Information Gain

We can define a function to calculate the entropy of a group of samples based on the ratio of samples that belong to class 0 and class 1.

In [3]:
# calculate the entropy for the split in the dataset
# (assumes both proportions are nonzero, since log2(0) is undefined)
def entropy(class0, class1):
    return -(class0 * log2(class0) + class1 * log2(class1))

Now, consider a dataset with 20 examples: 13 for class 0 and 7 for class 1. We can calculate the entropy for this dataset, which will be less than 1 bit.

In [4]:
# split of the main dataset
class0 = 13 / 20
class1 = 7 / 20

# calculate entropy before the change
s_entropy = entropy(class0, class1)
print('Dataset Entropy: %.3f bits' % s_entropy)
Dataset Entropy: 0.934 bits

Now consider that one of the variables in the dataset has 2 unique values, say value1 and value2. We are interested in calculating the information gain of this variable.

Let’s assume that if we split the dataset by value1, we have a group of 8 samples, 7 for class 0 and 1 for class 1. We can then calculate the entropy of this group of samples.

In [5]:
# split 1 (split via value1)
s1_class0 = 7 / 8
s1_class1 = 1 / 8

# calculate the entropy of the first group
s1_entropy = entropy(s1_class0, s1_class1)
print('Group1 Entropy: %.3f bits' % s1_entropy)
Group1 Entropy: 0.544 bits

Now, let’s assume that we split the dataset by value2; we have a group of 12 samples, with 6 in each class. We would expect this group to have an entropy of 1 bit.

In [6]:
# split 2  (split via value2)
s2_class0 = 6 / 12
s2_class1 = 6 / 12

# calculate the entropy of the second group
s2_entropy = entropy(s2_class0, s2_class1)
print('Group2 Entropy: %.3f bits' % s2_entropy)
Group2 Entropy: 1.000 bits

Lastly, we can calculate the information gain for this variable based on the groups created for each value of the variable and the calculated entropy.

The first value resulted in a group of 8 examples from the dataset, and the second group had the remaining 12 samples. Therefore, we have everything we need to calculate the information gain.

In this case, information gain can be calculated as:

  • Entropy(Dataset) – (Count(Group1) / Count(Dataset) * Entropy(Group1) + Count(Group2) / Count(Dataset) * Entropy(Group2))

Or:

  • Entropy(13/20, 7/20) – (8/20 * Entropy(7/8, 1/8) + 12/20 * Entropy(6/12, 6/12))

In [7]:
# calculate the information gain
gain = s_entropy - (8/20 * s1_entropy + 12/20 * s2_entropy)
print('Information Gain: %.3f bits' % gain)
Information Gain: 0.117 bits
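
The same calculation generalizes to any number of child groups: subtract the size-weighted average of the child entropies from the parent entropy. Below is a small sketch of such a helper; the function name and interface are our own hypothetical additions, not part of the worked example above:

# hypothetical helper: information gain for an arbitrary split
def information_gain(parent_entropy, groups, total):
    # groups is a list of (size, entropy) pairs, one per child node
    weighted = sum(size / total * ent for size, ent in groups)
    return parent_entropy - weighted

# reproduces the 0.117 bits computed above
gain = information_gain(s_entropy, [(8, s1_entropy), (12, s2_entropy)], total=20)
print('Information Gain: %.3f bits' % gain)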

Using information theory to evaluate features

The mutual information (MI) between a feature and the outcome is a measure of the mutual dependence between the two variables. It extends the notion of correlation to nonlinear relationships. More specifically, it quantifies the information obtained about one random variable through the other random variable. MI determines how different the joint distribution of the pair (X, Y) is from the product of the marginal distributions of X and Y.
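
To make that definition concrete, here is a minimal sketch that estimates MI (in bits) for two discrete variables directly from their empirical joint distribution; the function and the synthetic data are hypothetical illustrations, not part of the notebook's pipeline:

import numpy as np
import pandas as pd

def mutual_information_bits(x, y):
    # empirical joint distribution p(x, y)
    joint = pd.crosstab(pd.Series(x), pd.Series(y), normalize=True).values
    px = joint.sum(axis=1, keepdims=True)  # marginal p(x)
    py = joint.sum(axis=0, keepdims=True)  # marginal p(y)
    nz = joint > 0
    # MI = sum over (x, y) of p(x, y) * log2(p(x, y) / (p(x) * p(y)))
    return float((joint[nz] * np.log2(joint[nz] / (px * py)[nz])).sum())

rng = np.random.RandomState(0)
x = rng.randint(0, 2, size=10000)
y = np.where(rng.rand(10000) < 0.9, x, 1 - x)  # y copies x 90% of the time
print(f'MI: {mutual_information_bits(x, y):.3f} bits')  # close to 1 - 0.469 = 0.531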

sklearn implements feature_selection.mutual_info_regression, which computes the mutual information between all features and a continuous outcome to select the features that are most likely to contain predictive information. There is also a classification version, feature_selection.mutual_info_classif, which is used below (see the documentation for more details).
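
A minimal usage sketch on synthetic data (the data-generating setup is our own illustration, not from this notebook): mutual_info_regression picks up a purely nonlinear dependence that linear correlation would miss.

import numpy as np
from sklearn.feature_selection import mutual_info_regression

rng = np.random.RandomState(42)
X = rng.normal(size=(1000, 3))
# the outcome depends (nonlinearly) on the first feature only
y = X[:, 0] ** 2 + 0.1 * rng.normal(size=1000)

# one MI estimate per feature; the first should dominate
# despite having roughly zero linear correlation with y
print(mutual_info_regression(X, y, random_state=42))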

In [1]:
%matplotlib inline
import warnings
from datetime import datetime
import os
from pathlib import Path
import quandl
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pandas_datareader.data as web
from pandas_datareader.famafrench import get_available_datasets
from pyfinance.ols import PandasRollingOLS
from sklearn.feature_selection import mutual_info_classif
In [2]:
warnings.filterwarnings('ignore')
sns.set(style="darkgrid", color_codes=True)

Get Data

In [3]:
# read the engineered features from the local HDF5 store
with pd.HDFStore('/home/aj/Research Notebooks_py/Hands-On-Machine-Learning-for-Algorithmic-Trading/data/assets.h5', mode='r') as store:
    data = store['engineered_features']
In [4]:
data.head()
Out[4]:
return_1m return_2m return_3m return_6m return_9m return_12m Mkt-RF SMB HML RMW ... return_1m_t-4 return_1m_t-5 return_1m_t-6 target_1m target_2m target_3m target_6m target_12m msize Sector
ticker date
AAL 2006-09-30 0.049231 -0.014995 -0.042760 0.017278 0.019858 0.064198 0.138863 2.373876 -0.901034 -0.830475 ... NaN NaN NaN 0.124746 0.131546 0.066996 0.004278 -0.042727 10 Industrials
2006-10-31 0.124746 0.086333 0.029541 0.023947 0.060689 0.060353 0.138863 2.373876 -0.901034 -0.830475 ... NaN NaN NaN 0.138387 0.039242 0.039346 -0.048759 -0.047917 10 Industrials
2006-11-30 0.138387 0.131546 0.103414 0.033049 0.061789 0.044687 0.138863 2.373876 -0.901034 -0.830475 ... NaN NaN NaN -0.051268 -0.006895 -0.026972 -0.074586 -0.080364 10 Industrials
2006-12-31 -0.051268 0.039242 0.066996 0.010629 0.033588 0.031443 0.138863 2.373876 -0.901034 -0.830475 ... NaN NaN NaN 0.039554 -0.014591 -0.054754 -0.091543 -0.102498 10 Industrials
2007-01-31 0.039554 -0.006895 0.039346 0.034432 0.029055 0.055312 0.138863 2.373876 -0.901034 -0.830475 ... 0.049231 NaN NaN -0.065916 -0.098649 -0.129395 -0.093757 -0.109928 10 Industrials

5 rows × 32 columns

In [5]:
data.info()
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 323338 entries, (AAL, 2006-09-30 00:00:00) to (ZUMZ, 2018-02-28 00:00:00)
Data columns (total 32 columns):
return_1m        323338 non-null float64
return_2m        323338 non-null float64
return_3m        323338 non-null float64
return_6m        323338 non-null float64
return_9m        323338 non-null float64
return_12m       323338 non-null float64
Mkt-RF           323338 non-null float64
SMB              323338 non-null float64
HML              323338 non-null float64
RMW              323338 non-null float64
CMA              323338 non-null float64
momentum_2       323338 non-null float64
momentum_3       323338 non-null float64
momentum_6       323338 non-null float64
momentum_9       323338 non-null float64
momentum_12      323338 non-null float64
momentum_3_12    323338 non-null float64
year             323338 non-null int64
month            323338 non-null int64
return_1m_t-1    321688 non-null float64
return_1m_t-2    320038 non-null float64
return_1m_t-3    318388 non-null float64
return_1m_t-4    316738 non-null float64
return_1m_t-5    315088 non-null float64
return_1m_t-6    313438 non-null float64
target_1m        323338 non-null float64
target_2m        321688 non-null float64
target_3m        320038 non-null float64
target_6m        315088 non-null float64
target_12m       305188 non-null float64
msize            323338 non-null int64
Sector           323338 non-null object
dtypes: float64(28), int64(3), object(1)
memory usage: 80.3+ MB
In [6]:
data.isna().sum()
Out[6]:
return_1m            0
return_2m            0
return_3m            0
return_6m            0
return_9m            0
return_12m           0
Mkt-RF               0
SMB                  0
HML                  0
RMW                  0
CMA                  0
momentum_2           0
momentum_3           0
momentum_6           0
momentum_9           0
momentum_12          0
momentum_3_12        0
year                 0
month                0
return_1m_t-1     1650
return_1m_t-2     3300
return_1m_t-3     4950
return_1m_t-4     6600
return_1m_t-5     8250
return_1m_t-6     9900
target_1m            0
target_2m         1650
target_3m         3300
target_6m         8250
target_12m       18150
msize                0
Sector               0
dtype: int64
In [7]:
data = data.dropna()
data.info()
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 295288 entries, (AAL, 2007-03-31 00:00:00) to (ZUMZ, 2017-03-31 00:00:00)
Data columns (total 32 columns):
return_1m        295288 non-null float64
return_2m        295288 non-null float64
return_3m        295288 non-null float64
return_6m        295288 non-null float64
return_9m        295288 non-null float64
return_12m       295288 non-null float64
Mkt-RF           295288 non-null float64
SMB              295288 non-null float64
HML              295288 non-null float64
RMW              295288 non-null float64
CMA              295288 non-null float64
momentum_2       295288 non-null float64
momentum_3       295288 non-null float64
momentum_6       295288 non-null float64
momentum_9       295288 non-null float64
momentum_12      295288 non-null float64
momentum_3_12    295288 non-null float64
year             295288 non-null int64
month            295288 non-null int64
return_1m_t-1    295288 non-null float64
return_1m_t-2    295288 non-null float64
return_1m_t-3    295288 non-null float64
return_1m_t-4    295288 non-null float64
return_1m_t-5    295288 non-null float64
return_1m_t-6    295288 non-null float64
target_1m        295288 non-null float64
target_2m        295288 non-null float64
target_3m        295288 non-null float64
target_6m        295288 non-null float64
target_12m       295288 non-null float64
msize            295288 non-null int64
Sector           295288 non-null object
dtypes: float64(28), int64(3), object(1)
memory usage: 73.2+ MB

Create Dummy variables

In [8]:
dummy_data = pd.get_dummies(data,
                            columns=['year','month', 'msize', 'Sector'],
                            prefix=['year','month', 'msize', ''],
                            prefix_sep=['_', '_', '_', ''])
dummy_data = dummy_data.rename(columns={c:c.replace('.0', '') for c in dummy_data.columns})
dummy_data.info()
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 295288 entries, (AAL, 2007-03-31 00:00:00) to (ZUMZ, 2017-03-31 00:00:00)
Data columns (total 79 columns):
return_1m                 295288 non-null float64
return_2m                 295288 non-null float64
return_3m                 295288 non-null float64
return_6m                 295288 non-null float64
return_9m                 295288 non-null float64
return_12m                295288 non-null float64
Mkt-RF                    295288 non-null float64
SMB                       295288 non-null float64
HML                       295288 non-null float64
RMW                       295288 non-null float64
CMA                       295288 non-null float64
momentum_2                295288 non-null float64
momentum_3                295288 non-null float64
momentum_6                295288 non-null float64
momentum_9                295288 non-null float64
momentum_12               295288 non-null float64
momentum_3_12             295288 non-null float64
return_1m_t-1             295288 non-null float64
return_1m_t-2             295288 non-null float64
return_1m_t-3             295288 non-null float64
return_1m_t-4             295288 non-null float64
return_1m_t-5             295288 non-null float64
return_1m_t-6             295288 non-null float64
target_1m                 295288 non-null float64
target_2m                 295288 non-null float64
target_3m                 295288 non-null float64
target_6m                 295288 non-null float64
target_12m                295288 non-null float64
year_2001                 295288 non-null uint8
year_2002                 295288 non-null uint8
year_2003                 295288 non-null uint8
year_2004                 295288 non-null uint8
year_2005                 295288 non-null uint8
year_2006                 295288 non-null uint8
year_2007                 295288 non-null uint8
year_2008                 295288 non-null uint8
year_2009                 295288 non-null uint8
year_2010                 295288 non-null uint8
year_2011                 295288 non-null uint8
year_2012                 295288 non-null uint8
year_2013                 295288 non-null uint8
year_2014                 295288 non-null uint8
year_2015                 295288 non-null uint8
year_2016                 295288 non-null uint8
year_2017                 295288 non-null uint8
month_1                   295288 non-null uint8
month_2                   295288 non-null uint8
month_3                   295288 non-null uint8
month_4                   295288 non-null uint8
month_5                   295288 non-null uint8
month_6                   295288 non-null uint8
month_7                   295288 non-null uint8
month_8                   295288 non-null uint8
month_9                   295288 non-null uint8
month_10                  295288 non-null uint8
month_11                  295288 non-null uint8
month_12                  295288 non-null uint8
msize_1                   295288 non-null uint8
msize_2                   295288 non-null uint8
msize_3                   295288 non-null uint8
msize_4                   295288 non-null uint8
msize_5                   295288 non-null uint8
msize_6                   295288 non-null uint8
msize_7                   295288 non-null uint8
msize_8                   295288 non-null uint8
msize_9                   295288 non-null uint8
msize_10                  295288 non-null uint8
Basic Materials           295288 non-null uint8
Communication Services    295288 non-null uint8
Consumer Cyclical         295288 non-null uint8
Consumer Defensive        295288 non-null uint8
Energy                    295288 non-null uint8
Financial Services        295288 non-null uint8
Healthcare                295288 non-null uint8
Industrial Goods          295288 non-null uint8
Industrials               295288 non-null uint8
Real Estate               295288 non-null uint8
Technology                295288 non-null uint8
Utilities                 295288 non-null uint8
dtypes: float64(28), uint8(51)
memory usage: 78.6+ MB
In [9]:
dummy_data.head()
Out[9]:
return_1m return_2m return_3m return_6m return_9m return_12m Mkt-RF SMB HML RMW ... Consumer Cyclical Consumer Defensive Energy Financial Services Healthcare Industrial Goods Industrials Real Estate Technology Utilities
ticker date
AAL 2007-03-31 -0.130235 -0.098649 -0.054754 0.004278 -0.011653 0.010757 0.138863 2.373876 -0.901034 -0.830475 ... 0 0 0 0 0 0 1 0 0 0
2007-04-30 -0.187775 -0.159497 -0.129395 -0.048759 -0.023344 -0.013075 0.138863 2.373876 -0.901034 -0.830475 ... 0 0 0 0 0 0 1 0 0 0
2007-05-31 -0.034921 -0.114641 -0.119870 -0.074586 -0.018696 -0.022248 0.138863 2.373876 -0.901034 -0.830475 ... 0 0 0 0 0 0 1 0 0 0
2007-06-30 -0.150912 -0.094772 -0.126900 -0.091543 -0.041504 -0.041818 0.138863 2.373876 -0.901034 -0.830475 ... 0 0 0 0 0 0 1 0 0 0
2007-07-31 0.024447 -0.067345 -0.056660 -0.093757 -0.051400 -0.031781 0.138863 2.373876 -0.901034 -0.830475 ... 0 0 0 0 0 0 1 0 0 0

5 rows × 79 columns

In [10]:
dummy_data.describe()
Out[10]:
return_1m return_2m return_3m return_6m return_9m return_12m Mkt-RF SMB HML RMW ... Consumer Cyclical Consumer Defensive Energy Financial Services Healthcare Industrial Goods Industrials Real Estate Technology Utilities
count 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 ... 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000 295288.000000
mean 0.012364 0.009287 0.008322 0.007270 0.006870 0.006551 0.977150 0.626099 0.129970 -0.056357 ... 0.128197 0.051231 0.054442 0.172716 0.118186 0.000640 0.171609 0.058147 0.120130 0.034597
std 0.112866 0.080281 0.065998 0.047945 0.039451 0.034315 0.863596 1.210207 1.512535 1.919062 ... 0.334309 0.220470 0.226888 0.378002 0.322829 0.025291 0.377041 0.234021 0.325114 0.182756
min -0.325956 -0.251753 -0.211078 -0.159002 -0.129920 -0.112127 -9.311845 -10.159771 -13.036170 -24.698688 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
25% -0.046016 -0.030430 -0.023723 -0.014675 -0.010906 -0.008809 0.500507 -0.074166 -0.642750 -0.929064 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
50% 0.009954 0.010063 0.010004 0.009447 0.009114 0.008774 0.932280 0.557657 0.102173 0.049917 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
75% 0.066210 0.049401 0.042274 0.032073 0.027288 0.024564 1.413691 1.257672 0.875003 0.915518 ... 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
max 0.428079 0.279401 0.219990 0.153115 0.123644 0.105678 10.407951 10.212963 13.111228 17.643528 ... 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000

8 rows × 79 columns

Mutual Information

Original Data

In [11]:
target_labels = [f'target_{i}m' for i in [1,2,3,6,12]]
targets = data.dropna().loc[:, target_labels]

features = data.dropna().drop(target_labels, axis=1)
features.Sector = pd.factorize(features.Sector)[0]

cat_cols = ['year', 'month', 'msize', 'Sector']
discrete_features = [features.columns.get_loc(c) for c in cat_cols]
In [12]:
mutual_info = pd.DataFrame()
for label in target_labels:
    mi = mutual_info_classif(X=features, 
                             y=(targets[label]>0).astype(int),
                             discrete_features=discrete_features,
                             random_state=42
                            )
    mutual_info[label] = pd.Series(mi, index=features.columns)
In [13]:
mutual_info.sum()
Out[13]:
target_1m     0.036124
target_2m     0.062342
target_3m     0.089524
target_6m     0.138216
target_12m    0.209336
dtype: float64

Normalized MI Heatmap

In [14]:
fig, ax = plt.subplots(figsize=(15, 4))
sns.heatmap(mutual_info.div(mutual_info.sum()).T, ax=ax);

Dummy Data

In [15]:
target_labels = [f'target_{i}m' for i in [1, 2, 3, 6, 12]]
dummy_targets = dummy_data.dropna().loc[:, target_labels]

dummy_features = dummy_data.dropna().drop(target_labels, axis=1)
cat_cols = [c for c in dummy_features.columns if c not in features.columns]
discrete_features = [dummy_features.columns.get_loc(c) for c in cat_cols]
In [16]:
mutual_info_dummies = pd.DataFrame()
for label in target_labels:
    mi = mutual_info_classif(X=dummy_features, 
                             y=(dummy_targets[label]> 0).astype(int),
                             discrete_features=discrete_features,
                             random_state=42
                            )    
    mutual_info_dummies[label] = pd.Series(mi, index=dummy_features.columns)
In [17]:
mutual_info_dummies.sum()
Out[17]:
target_1m     0.037281
target_2m     0.064253
target_3m     0.092183
target_6m     0.141994
target_12m    0.213704
dtype: float64
In [18]:
fig, ax = plt.subplots(figsize=(4, 20))
sns.heatmap(mutual_info_dummies.div(mutual_info_dummies.sum()), ax=ax);