Driver Churn Prediction
Data Science · Machine Learning


Aug 30, 2024  •  10 min read


Introduction

Imagine this: you work in the HR department of Ola Cabs and are responsible for hiring and retaining drivers. You've noticed an increase in the number of drivers leaving the company over the past few months, and your manager has tasked you with identifying the reasons for this trend. What do you do? You set up a meeting with the analytics department and explain the problem. This is where I come in: I am the data scientist responsible for finding the reasons for the churn and predicting which drivers will churn in the future.

A few moments later

I have completed the analysis and the report of my findings is as follows:

Business Problem

Ola Cabs provides convenience and transparency by connecting customers to a variety of vehicles. Ola's main focus is to ensure a quality driving experience and to retain efficient drivers. Recruiting and retaining drivers is challenging due to high churn rates: losing drivers hurts morale, and acquiring new drivers is costlier than retaining current ones. We will build a predictive model to determine which drivers are likely to churn, and identify which driver features are most responsible for churn.

Dataset

The dataset contains monthly information for a segment of drivers for 2019 and 2020, with the following columns:

| Feature | Description |
| --- | --- |
| MMMM-YY | Reporting date (monthly) |
| Driver_ID | Unique ID for drivers |
| Age | Age of the driver |
| Gender | Gender of the driver – Male: 0, Female: 1 |
| City | City code of the driver |
| Education_Level | Education level – 0 for 10+, 1 for 12+, 2 for graduate |
| Income | Monthly average income of the driver |
| Date Of Joining | Joining date for the driver |
| LastWorkingDate | Last date of working for the driver |
| Joining Designation | Designation of the driver at the time of joining |
| Grade | Grade of the driver at the time of reporting |
| Total Business Value | Total business value acquired by the driver in a month (negative business indicates cancellation/refund or car EMI adjustments) |
| Quarterly Rating | Quarterly rating of the driver: 1, 2, 3, 4, 5 (higher is better) |

Tech Stack

Predictive Algorithms and Metrics

Why PySpark?

Since I was already proficient in Pandas and scikit-learn, I wanted to explore a new technology. PySpark, a big data processing library capable of handling large datasets, presented a great opportunity, so I used it for data preprocessing, feature engineering, and machine learning.
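For context, here is a minimal setup sketch, not the exact notebook code: it creates the SparkSession, reads the dataset, and derives the reporting_month_year and ReportingYear_Quarter columns used in the queries below. The file name and the date format assumed for the MMMM-YY column are placeholders.

import pyspark.sql.functions as sf
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("driver-churn").getOrCreate()

# File name and date format are assumptions, not the actual values from the notebook
spark_df = spark.read.csv("ola_drivers.csv", header=True, inferSchema=True)

spark_df = (
    spark_df
    # Parse the monthly reporting date into a proper date column
    .withColumn("reporting_month_year", sf.to_date(sf.col("MMMM-YY"), "dd/MM/yy"))
    # Derive a year-quarter key, e.g. "2019-1"
    .withColumn(
        "ReportingYear_Quarter",
        sf.concat_ws(
            "-",
            sf.year("reporting_month_year").cast("string"),
            sf.quarter("reporting_month_year").cast("string"),
        ),
    )
)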

Another advantage of PySpark is that I can execute SQL queries on the data directly. This is a great feature, as I can use SQL to filter, group, and aggregate data. The only thing to keep in mind is that you have to create a temporary view of the data before executing SQL queries.

The following is an example of a SQL query that checks the change in business value for drivers whose rating decreased:

spark_df.createOrReplaceTempView("ola_driver")
 
spark.sql("""
  with Lower_Rating_Drivers as (
      select
          distinct Driver_ID
      from
          (select
              Driver_ID,
              first(`Quarterly Rating`) over(partition by Driver_ID order by reporting_month_year) as first_rating,
              last(`Quarterly Rating`) over(partition by Driver_ID order by reporting_month_year) as last_rating
          from
              ola_driver
          )
      where
          first_rating > last_rating
  ),
  Quarterly_Business_Value as (
      select
          Driver_ID,
          ReportingYear_Quarter,
          sum(`Total Business Value`) over( partition by Driver_ID, ReportingYear_Quarter) as Total_Business_Value
      from
          ola_driver
          join Lower_Rating_Drivers using(Driver_ID)
  )
  select
      Driver_ID,
      first(Total_Business_Value) as First_Business_Value,
      last(Total_Business_Value) as Last_Business_Value,
      (Last_Business_Value - First_Business_Value)/ First_Business_Value as Business_Value_Change,
      case when Last_Business_Value < First_Business_Value then 1 else 0 end as Business_Decrease
  from
      Quarterly_Business_Value
  group by
      Driver_ID
""").show()

Data Preprocessing

Imputation

This was a small dataset with some missing values. I wanted to use scikit-learn's KNN Imputer to impute these values. However, I didn't want to go through the hassle of converting my dataset to a Pandas dataframe, applying the KNN Imputer, and then converting it back to a PySpark dataframe.

Luckily, PySpark provides User Defined Functions (UDFs), which let me use scikit-learn's KNN Imputer directly on my PySpark DataFrame.

import numpy as np
import pyspark.sql.functions as sf
from pyspark.sql.types import IntegerType
from sklearn.impute import KNNImputer
 
list_of_columns = ['Driver_ID',
 'Age',
 'Gender',
 'Education_Level',
 'Income',
 'Joining Designation',
 'Grade',
 'Total Business Value',
 'Quarterly Rating']
 
# Fit the imputer on a small subset of data that can fit in driver memory
df_numpy = np.array(spark_df.select(list_of_columns).collect(), dtype=float)
 
k_imputer = KNNImputer(n_neighbors=5, weights='distance')
k_imputer.fit(df_numpy)
 
# Broadcast the fitted imputer so every executor gets its own copy
sc = spark.sparkContext
broadcast_model = sc.broadcast(k_imputer)
 
column_index_mapping = {col: idx for idx, col in enumerate(list_of_columns)}
 
def create_knn_imputer_udf(column_name):
    index = column_index_mapping[column_name]
 
    @sf.udf(IntegerType())
    def knn_impute(*cols):
        # Convert nulls to NaN so the KNN imputer can fill them
        row = np.array([float(c) if c is not None else np.nan for c in cols]).reshape(1, -1)
        imputed_row = broadcast_model.value.transform(row)
        return int(imputed_row[0][index])
    return knn_impute
 
knn_impute_Age = create_knn_imputer_udf("Age")
spark_df = spark_df.withColumn("Age_Imputed", knn_impute_Age(*list_of_columns))
 

There are better ways to impute missing values in PySpark, given the time and space complexity of the KNN algorithm; here I mainly wanted to learn how to use UDFs with scikit-learn's functions.
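For reference, a minimal sketch of one such alternative: PySpark's built-in Imputer from pyspark.ml.feature, which fills missing values with the column median without collecting data to the driver. The choice of columns and the strategy here are illustrative, not the notebook's actual approach.

import pyspark.sql.functions as sf
from pyspark.ml.feature import Imputer

# Imputer expects float/double input columns, so cast first (columns chosen for illustration)
for c in ["Age", "Gender"]:
    spark_df = spark_df.withColumn(c, sf.col(c).cast("double"))

imputer = Imputer(
    strategy="median",
    inputCols=["Age", "Gender"],
    outputCols=["Age_Imputed", "Gender_Imputed"],
)
spark_df = imputer.fit(spark_df).transform(spark_df)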

Feature Engineering

from pyspark.sql import Window
import pyspark.sql.functions as sf
 
# Window over each driver's monthly records, ordered chronologically
window_spec = Window.partitionBy("Driver_ID").orderBy("reporting_month_year")
 
spark_df = spark_df.withColumns({
    # Propagate the last working date across a driver's rows
    "LastWorkingDate":  sf.coalesce(sf.col("LastWorkingDate"), sf.first("LastWorkingDate", True).over(window_spec)),
    # A driver with a last working date has churned
    "Churned":  sf.when(sf.col("LastWorkingDate").isNotNull(), 1).otherwise(0),
    # Flag months with negative business value (cancellations/refunds/EMI adjustments)
    "Had_Negative_Business":  sf.when(sf.col("Total Business Value") < 0, 1).otherwise(0),
    "Has_Income_Increased": sf.when(sf.last("Income").over(window_spec) > sf.first("Income").over(window_spec), 1).otherwise(0),
    "Has_Rating_Increased": sf.when(sf.last("Quarterly Rating").over(window_spec) > sf.first("Quarterly Rating").over(window_spec), 1).otherwise(0),
})
 
# Collapse the monthly records into a single row per driver
agg_map = [
    sf.first("Dateofjoining").alias("Date_Of_Joining"),
    sf.sum("Total Business Value").alias("Total_Business_Value"),
    sf.sum("Had_Negative_Business").alias("Total_Had_Negative_Business"),
    sf.max("Has_Income_Increased").alias("Has_Income_Increased"),
    sf.max("Has_Rating_Increased").alias("Has_Rating_Increased"),
    sf.avg("Total Business Value").cast("int").alias("Avg_Business_Value"),
    sf.max("reporting_month_year").alias("Last_Reporting_Month"),
    sf.max("Age").alias("Age"),
    sf.mode("Gender").alias("Gender"),
    sf.last("Income").alias("Income"),
    sf.sum("Income").alias("Total_Income"),
    sf.first("Education_Level").alias("Education_Level"),
    sf.last("City").alias("City"),
    sf.first("Joining Designation").alias("Joining_Designation"),
    sf.last("Grade").alias("Grade"),
    sf.last("Quarterly Rating").alias("Quarterly_Rating"),
    sf.max("LastWorkingDate").alias("Last_Working_Date"),
    sf.max("Churned").alias("Churned"),
]
 
merged_df = spark_df.sort("reporting_month_year").groupBy("Driver_ID").agg(*agg_map)
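The feature-importance plot later in the post refers to a Tenure feature that is not shown in the aggregation above. A plausible sketch of how it could be derived is below; the exact definition used in the notebook may differ, and it assumes the date columns are already parsed as dates.

import pyspark.sql.functions as sf

# Tenure in days: joining date to last working date, or to the last reporting
# month for drivers who have not churned (this definition is an assumption)
merged_df = merged_df.withColumn(
    "Tenure",
    sf.datediff(
        sf.coalesce(sf.col("Last_Working_Date"), sf.col("Last_Reporting_Month")),
        sf.col("Date_Of_Joining"),
    ),
)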

Exploratory Data Analysis and Insights

Tableau Dashboard

Although I am not a Tableau expert, I created a dashboard to visualize the data. You can find the link to the dashboard below. A more detailed exploratory data analysis (EDA) is available in the notebook linked at the end.

Tableau

 

The following is the dashboard embedded from Tableau:

Change in rating heatmap

Observations

Effect on business value when ratings decrease

import matplotlib.pyplot as plt
 
business_decrease_df = spark.sql("""
        with Lower_Rating_Drivers as (
            select
                distinct Driver_ID
            from
                (select
                    Driver_ID,
                    first(`Quarterly Rating`) over(partition by Driver_ID order by reporting_month_year) as first_rating,
                    last(`Quarterly Rating`) over(partition by Driver_ID order by reporting_month_year) as last_rating
                from
                    ola_driver
                )
            where
                first_rating > last_rating
        ),
        Quarterly_Business_Value as (
            select
                Driver_ID,
                ReportingYear_Quarter,
                sum(`Total Business Value`) over( partition by Driver_ID, ReportingYear_Quarter) as Total_Business_Value
            from
                ola_driver
                join Lower_Rating_Drivers using(Driver_ID)
        )
        select
            Driver_ID,
            first(Total_Business_Value) as First_Business_Value,
            last(Total_Business_Value) as Last_Business_Value,
            (Last_Business_Value - First_Business_Value)/ First_Business_Value as Business_Value_Change,
            case when Last_Business_Value < First_Business_Value then 1 else 0 end as Business_Decrease
        from
            Quarterly_Business_Value
        group by
            Driver_ID
""").toPandas()
business_decrease_df.head()
business_decrease_df["Business_Decrease"].value_counts().to_frame().T.rename(columns={1:"Business Decrease Count", 0:"Business Increase Count"}).T.plot(kind="barh")
plt.ylabel(" ")
plt.title("Count of Drivers with Change in Business Value");
Change in business value

Observations

Effect of rating based on the month of the year

spark.sql("""
  with Quarterly_Agg as (
    select
        Driver_ID,
        ReportingYear_Quarter,
        case when max(`Quarterly Rating`) = 1 then 1 else 0 end as rating_1_flag,
        case when max(`Quarterly Rating`) = 2 then 1 else 0 end as rating_2_flag,
        case when max(`Quarterly Rating`) = 3 then 1 else 0 end as rating_3_flag,
        case when max(`Quarterly Rating`) = 4 then 1 else 0 end as rating_4_flag,
        reporting_month_year
    from
        ola_driver
    group by
        ReportingYear_Quarter,
        Driver_ID,
        reporting_month_year
  )
  select
      reporting_month_year,
      sum(rating_1_flag) as Total_Rating_1,
      sum(rating_2_flag) as Total_Rating_2,
      sum(rating_3_flag) as Total_Rating_3,
      sum(rating_4_flag) as Total_Rating_4
  from
      Quarterly_Agg
  group by
      reporting_month_year
  order by
      reporting_month_year
""")
Ratings based on month and year

Observations

Effect of Ratings based on City

Effect of rating based on city

Observations

Other features affecting Quarterly Rating


Business Recommendations

Predictive Model Recommendations

| Model | F1 | Accuracy | Precision | Recall | AUC |
| --- | --- | --- | --- | --- | --- |
| LightGBM | 0.940727 | 0.925659 | 0.946154 | 0.935361 | 0.969285 |
| XGBoost | 0.937743 | 0.923261 | 0.960159 | 0.916350 | 0.968273 |
| Gradient Boosting | 0.919325 | 0.896882 | 0.907407 | 0.931559 | 0.955817 |
| Random Forest | 0.904762 | 0.875300 | 0.872792 | 0.939163 | 0.936497 |

From the comparison above, we can conclude that LightGBM has the best overall performance, leading on F1, accuracy, recall, and AUC (XGBoost is slightly ahead on precision only).
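For reference, here is a sketch of how these comparison metrics could be computed for each fitted model on a held-out test set. The variable names model, X_test, and y_test are hypothetical; the actual notebook may compute them differently.

from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score, roc_auc_score)

# `model`, `X_test`, `y_test` are hypothetical names for a fitted classifier
# and its held-out evaluation data
y_pred = model.predict(X_test)
y_prob = model.predict_proba(X_test)[:, 1]  # predicted probability of churn

metrics = {
    "F1": f1_score(y_test, y_pred),
    "Accuracy": accuracy_score(y_test, y_pred),
    "Precision": precision_score(y_test, y_pred),
    "Recall": recall_score(y_test, y_pred),
    "AUC": roc_auc_score(y_test, y_prob),
}
print(metrics)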

Feature Importance

Feature Importance

From the plot above, we can see that Tenure and Income are the biggest contributors to the model's predictions.
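As a sketch, importances like these can be pulled from a fitted LightGBM classifier and plotted directly; lgbm_model and feature_names are hypothetical names rather than the notebook's exact code.

import pandas as pd
import matplotlib.pyplot as plt

# `lgbm_model` is a fitted lightgbm.LGBMClassifier; `feature_names` lists its input columns
importances = pd.Series(lgbm_model.feature_importances_, index=feature_names)
importances.sort_values().plot(kind="barh")
plt.title("LightGBM Feature Importance")
plt.tight_layout()
plt.show()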

Precision Recall Gap

Recall measures how many of the drivers who actually churned were correctly identified by the model, while precision measures how many of the drivers predicted to churn actually did churn. Assume the company decides to give a raise to drivers who are predicted to churn.

Case 1: The business wants the model to detect even a small possibility of churn. Here recall matters more: missing a driver who goes on to churn is costlier than giving an unnecessary raise.

Case 2: The business wants to be sure that a predicted driver will actually churn. Here precision matters more: every raise given to a driver who would have stayed anyway is wasted.
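A common way to move between these two cases is to adjust the model's decision threshold instead of using the default 0.5. A minimal sketch, with hypothetical model, X_test, and y_test and illustrative threshold values:

from sklearn.metrics import precision_recall_curve

# `model`, `X_test`, `y_test` are hypothetical names from the evaluation step
y_prob = model.predict_proba(X_test)[:, 1]
precision, recall, thresholds = precision_recall_curve(y_test, y_prob)

# Case 1: lower the threshold to catch more potential churners (higher recall)
y_pred_case1 = (y_prob >= 0.3).astype(int)
# Case 2: raise the threshold to act only on near-certain churners (higher precision)
y_pred_case2 = (y_prob >= 0.7).astype(int)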

Conclusion

Hosting the application

I have hosted the app on Streamlit; the link is below. Although the final model was built using PySpark, I couldn't host it on Streamlit due to resource constraints, so I rebuilt the model with scikit-learn, using the same hyperparameters and features, and hosted that instead.
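As a rough sketch of that rebuild (the hyperparameter values below are placeholders rather than the actual tuned values, and X_train, y_train are hypothetical names):

import joblib
from lightgbm import LGBMClassifier

# Placeholder hyperparameters; the real ones came from the tuning run in the notebook
sk_model = LGBMClassifier(n_estimators=200, learning_rate=0.1, num_leaves=31)
sk_model.fit(X_train, y_train)

# Save the model so the Streamlit app can load it at startup
joblib.dump(sk_model, "driver_churn_model.joblib")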

Streamlit shuts down the app after a few hours of inactivity. Please click the 'Wake up' button to restart the app.

Streamlit

Jupyter Notebook

The main notebook used for the analysis can be found below. You can view it with nbviewer, open it in Google Colab, or find it on Kaggle.

nbviewer · Open in Colab · Kaggle

GitHub Repository

The complete code can be found in the GitHub repository below.

GitHub
