Zero Inflated Regression Clearly Explained
Learn how to model when the response contains lots of zeros
Introduction to Zero Inflated Regression
When the response variable (y) in your dataset contains lots of zeros, an ordinary regression model struggles to predict those observations as exactly zero. The excess zeros also distort the fit for the non-zero observations.
Zero inflated regression offers an elegant way to model this situation, which is common in real-world ML problems (for example, insurance claim amounts, where most policyholders file no claim at all).
Let's go over the steps needed to apply zero inflated regression in Python. We will use the `scikit-lego` package (imported as `sklego`) for this.
First, let's set up the packages.
1. Setup and Data
import numpy as np
import pandas as pd
%matplotlib inline
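We will use a synthetic dataset, so let's create the x and y data: y is exactly zero for x < 100 and x >= 200, and a noisy upward ramp in between.
np.random.seed(0)  # fix the noise so reruns give the same data
df1 = pd.concat([pd.Series(np.arange(100)),
                 pd.Series(np.zeros(100))],
                axis=1)
df2 = pd.concat([pd.Series(np.arange(100, 200)),
                 pd.Series(20 + np.arange(100) + np.random.uniform(-9, 9, 100))],
                axis=1)
df3 = pd.concat([pd.Series(np.arange(200, 300)),
                 pd.Series(np.zeros(100))],
                axis=1)
# ignore_index=True gives a clean 0..299 index instead of repeating 0..99 three times
df = pd.concat([df1, df2, df3], axis=0, ignore_index=True)
df.columns = ["x", "y"]
print(df.shape)
df.head()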
Plot
df.plot(x="x", y="y", kind="scatter");
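If you prefer, the same dataset can be built in one step with NumPy, which also makes it easy to fix the random seed and confirm how many zeros there are. This is a minimal sketch: `make_zero_inflated_data` and the seed value are illustrative choices, not part of the original walkthrough, and it uses NumPy's `default_rng` rather than the global `np.random` functions, so the noise values differ from the DataFrame version above.

```python
import numpy as np


def make_zero_inflated_data(seed=0):
    """Hypothetical helper (not from the original tutorial).

    x runs from 0 to 299; y is exactly zero outside 100 <= x < 200
    and 20 + (x - 100) plus uniform noise in [-9, 9] inside that band.
    """
    rng = np.random.default_rng(seed)  # fixed seed so reruns give the same noise
    x = np.arange(300)
    y = np.zeros(300)
    band = (x >= 100) & (x < 200)
    y[band] = 20 + np.arange(100) + rng.uniform(-9, 9, 100)
    return x, y


x, y = make_zero_inflated_data()
print(x.shape, y.shape)                           # (300,) (300,)
print(f"share of zeros: {np.mean(y == 0):.2f}")   # 0.67
```

The smallest value inside the band is at least 20 - 9 = 11, so exactly 200 of the 300 responses are zero.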
2. Build Linear Regression on raw data
# Train a Linear Regression model on the raw data
from sklearn.linear_model import LinearRegression
regMod = LinearRegression()
# Pass y as a 1-D Series so predict() returns a 1-D array that fits in one column
regMod.fit(df[['x']], df['y'])
Make predictions
df['y_hat_1'] = regMod.predict(df[['x']])
df['y_hat_1']
Plot
ax1 = df.plot(x="x", y="y",
              kind="scatter",
              figsize=(5, 3),
              title="Linear Regression");
df.plot(x="x", y="y_hat_1", kind="line", ax=ax1, color='firebrick');
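The fitted line in the plot is nearly flat. A quick way to confirm this is to look at its coefficients. The sketch below rebuilds the data with an illustrative seed and fits the same one-feature least-squares line with `np.polyfit` (equivalent to `LinearRegression` on a single feature), so it runs on its own without the DataFrame.

```python
import numpy as np

# Rebuild the synthetic data (illustrative seed): zeros outside 100 <= x < 200,
# a noisy ramp inside.
rng = np.random.default_rng(0)
x = np.arange(300, dtype=float)
y = np.zeros(300)
y[100:200] = 20 + np.arange(100) + rng.uniform(-9, 9, 100)

# Ordinary least squares on all 300 points; polyfit returns [slope, intercept].
slope, intercept = np.polyfit(x, y, deg=1)
y_hat = intercept + slope * x

print(f"slope={slope:.3f}, intercept={intercept:.1f}")
# Average prediction on rows whose true value is exactly zero.
print(f"mean prediction on zero rows: {y_hat[y == 0].mean():.1f}")
```

Because two thirds of the targets are exactly zero, least squares pulls the line almost flat: the slope is a small fraction of the roughly one-unit-per-step rise inside the non-zero band, and every prediction on the zero rows stays well above zero.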
3. Build Zero Inflated Regression Model
First, install the `scikit-lego` package, which provides `ZeroInflatedRegressor` (in Python it is imported as `sklego`).
!pip install scikit-lego
Now, let’s import the packages
from sklearn.tree import DecisionTreeClassifier
from sklego.meta import ZeroInflatedRegressor
Initialize the model.
zirMod = ZeroInflatedRegressor(
    classifier=DecisionTreeClassifier(),
    regressor=LinearRegression()
)
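Conceptually, the `classifier` and `regressor` arguments split the problem in two. The sketch below is a hand-rolled, hypothetical `TwoStageZeroRegressor` (not part of `sklego`) that illustrates the idea as we understand it: the classifier is trained to predict whether y is non-zero, the regressor is trained only on the non-zero rows, and the prediction is 0 wherever the classifier says "zero". Treat it as an illustration, not a copy of `sklego`'s source.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeClassifier


class TwoStageZeroRegressor:
    """Hypothetical class for illustration; not the sklego implementation.

    Stage 1: a classifier learns whether y is non-zero.
    Stage 2: a regressor is fit only on the rows where y is non-zero.
    Prediction: 0 where the classifier says "zero", else the regressor's output.
    """

    def __init__(self, classifier, regressor):
        self.classifier = classifier
        self.regressor = regressor

    def fit(self, X, y):
        X = np.asarray(X)
        y = np.asarray(y).ravel()
        non_zero = y != 0
        # Stage 1: binary target, True for non-zero responses.
        self.classifier_ = clone(self.classifier).fit(X, non_zero)
        # Stage 2: regression on the non-zero rows only.
        self.regressor_ = clone(self.regressor).fit(X[non_zero], y[non_zero])
        return self

    def predict(self, X):
        X = np.asarray(X)
        out = np.zeros(len(X))
        is_non_zero = self.classifier_.predict(X).astype(bool)
        if is_non_zero.any():
            out[is_non_zero] = self.regressor_.predict(X[is_non_zero])
        return out


# Usage on a rebuilt copy of the data (illustrative seed).
rng = np.random.default_rng(0)
X = np.arange(300).reshape(-1, 1)
y = np.zeros(300)
y[100:200] = 20 + np.arange(100) + rng.uniform(-9, 9, 100)

model = TwoStageZeroRegressor(DecisionTreeClassifier(), LinearRegression()).fit(X, y)
pred = model.predict(X)
print(np.all(pred[y == 0] == 0))  # True
```

A decision tree is a sensible classifier here because the non-zero rows form a band in x (100 to 199): a single linear boundary on x cannot isolate that band, but two tree splits can.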
Train the model
# Use the same DataFrame input for fit and predict; y is passed as a 1-D Series
zirMod.fit(df[['x']], df['y'])
Make predictions
df['y_hat_2'] = zirMod.predict(df[['x']])
df['y_hat_2']
Plot
ax1 = df.plot(x="x", y="y",
              kind="scatter",
              figsize=(5, 3),
              title="Zero Inflated Regression");
df.plot(x="x", y="y_hat_2", kind="line", ax=ax1, color='firebrick');
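The plots suggest the zero inflated model fits far better. To put numbers on it, you can compare the mean absolute error and the share of true zeros that each model predicts as exactly 0. The sketch below is self-contained, so it rebuilds the data with an illustrative seed and writes the two-stage prediction out inline with a `DecisionTreeClassifier` and a `LinearRegression`, as an assumed stand-in for `ZeroInflatedRegressor`. The helper `summarize` is hypothetical. Like the rest of this walkthrough, it scores the models on the same data they were trained on, which flatters both; on real data, use a held-out split.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error
from sklearn.tree import DecisionTreeClassifier

# Rebuild the synthetic data (illustrative seed).
rng = np.random.default_rng(0)
X = np.arange(300).reshape(-1, 1)
y = np.zeros(300)
y[100:200] = 20 + np.arange(100) + rng.uniform(-9, 9, 100)

# Plain linear regression on all rows.
pred_lin = LinearRegression().fit(X, y).predict(X)

# Two-stage zero inflated prediction, written out inline.
nz = y != 0
clf = DecisionTreeClassifier(random_state=0).fit(X, nz)
reg = LinearRegression().fit(X[nz], y[nz])
pred_zir = np.where(clf.predict(X), reg.predict(X), 0.0)


def summarize(name, y_true, y_pred):
    """Hypothetical helper: print MAE and the share of true zeros predicted as exactly 0."""
    zeros = y_true == 0
    mae = mean_absolute_error(y_true, y_pred)
    exact = np.mean(y_pred[zeros] == 0)
    print(f"{name:>20}: MAE={mae:6.2f}  exact zeros recovered={exact:.0%}")
    return mae, exact


mae_lin, exact_lin = summarize("Linear Regression", y, pred_lin)
mae_zir, exact_zir = summarize("Zero Inflated", y, pred_zir)
```

In the notebook, the helper can be applied directly to the prediction columns, for example `summarize("Zero Inflated", df["y"].to_numpy(), df["y_hat_2"].to_numpy())`.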