Churn prediction is big business. It minimizes customer defection by predicting which customers are likely to cancel a subscription to a service. Though originally used within the telecommunications industry, it has become common practice across banks, ISPs, insurance firms, and other verticals.
The prediction process is heavily data-driven and often utilizes advanced machine learning techniques. In this post, we'll take a look at what types of customer data are typically used, do some preliminary analysis of the data, and generate churn prediction models, all with Spark and its machine learning frameworks.
Customer 360

Using data science in order to better understand and predict customer behavior is an iterative process, which involves:
In order to understand the customer, a number of factors can be analyzed, such as:
With this analysis, telecom companies can gain insights to predict and enhance the customer experience, prevent churn, and tailor marketing campaigns.
Classification is a family of supervised machine learning algorithms that identify which category an item belongs to (e.g., whether a transaction is fraud or not fraud), based on labeled examples of known items (e.g., transactions known to be fraud or not). Classification takes a set of data with known labels and pre-determined features and learns how to label new records based on that information. Features are the “if questions” that you ask. The label is the answer to those questions. In the example below, if it walks, swims, and quacks like a duck, then the label is "duck."
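To make this concrete, here is a toy sketch in Scala of what labeled examples look like; the animal attributes and values below are purely illustrative and are not part of the churn dataset.

// Toy labeled examples: the boolean attributes are the features, the last field is the label
case class Animal(walks: Boolean, swims: Boolean, quacks: Boolean, label: String)

val labeledExamples = Seq(
  Animal(walks = true, swims = true, quacks = true, label = "duck"),
  Animal(walks = true, swims = false, quacks = false, label = "not duck")
)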
Let’s go through an example of telecom customer churn:
Decision trees create a model that predicts the class or label, based on several input features. Decision trees work by evaluating an expression containing a feature at every node and selecting a branch to the next node, based on the answer. A possible decision tree for predicting credit risk is shown below. The feature questions are the nodes, and the answers “yes” or “no” are the branches in the tree to the child nodes.
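As a rough sketch of the logic such a tree encodes, here is a hand-written version in plain Scala; the feature names and thresholds are invented for illustration only.

// Illustrative only: hand-written decision logic mimicking a small credit-risk tree
def creditRisk(income: Double, hasExistingDebt: Boolean): String =
  if (income < 30000.0) {
    if (hasExistingDebt) "high risk" else "medium risk"
  } else {
    "low risk"
  }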
For this tutorial, we'll be using the Orange Telecoms Churn Dataset. It consists of cleaned customer activity data (features), along with a churn label specifying whether the customer canceled the subscription or not. The data can be fetched from BigML's S3 bucket, churn-80 and churn-20. The two sets are from the same batch but have been split by an 80/20 ratio. We'll use the larger set for training and cross-validation purposes and the smaller set for final testing and model performance evaluation. The two data sets have been included with the complete code in this repository for convenience. The data set has the following schema:
1. State: string
2. Account length: integer
3. Area code: integer
4. International plan: string
5. Voice mail plan: string
6. Number vmail messages: integer
7. Total day minutes: double
8. Total day calls: integer
9. Total day charge: double
10. Total eve minutes: double
11. Total eve calls: integer
12. Total eve charge: double
13. Total night minutes: double
14. Total night calls: integer
15. Total night charge: double
16. Total intl minutes: double
17. Total intl calls: integer
18. Total intl charge: double
19. Customer service calls: integer
20. Churn: string
The CSV file has the following format:
The image below shows the first few rows of the data set:
This tutorial will run on Spark 2.0.1 and above.
$ spark-shell --master local
First, we will import the SQL and machine learning packages.
import org.apache.spark._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.sql.Dataset
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.tuning.CrossValidator
import org.apache.spark.ml.feature.VectorAssembler
We use a Scala case class and a StructType to define the schema, corresponding to a line in the CSV data file.
// define the Churn schema
case class Account(state: String, len: Integer, acode: String, intlplan: String,
  vplan: String, numvmail: Double, tdmins: Double, tdcalls: Double, tdcharge: Double,
  temins: Double, tecalls: Double, techarge: Double, tnmins: Double, tncalls: Double,
  tncharge: Double, timins: Double, ticalls: Double, ticharge: Double, numcs: Double,
  churn: String)

val schema = StructType(Array(
  StructField("state", StringType, true),
  StructField("len", IntegerType, true),
  StructField("acode", StringType, true),
  StructField("intlplan", StringType, true),
  StructField("vplan", StringType, true),
  StructField("numvmail", DoubleType, true),
  StructField("tdmins", DoubleType, true),
  StructField("tdcalls", DoubleType, true),
  StructField("tdcharge", DoubleType, true),
  StructField("temins", DoubleType, true),
  StructField("tecalls", DoubleType, true),
  StructField("techarge", DoubleType, true),
  StructField("tnmins", DoubleType, true),
  StructField("tncalls", DoubleType, true),
  StructField("tncharge", DoubleType, true),
  StructField("timins", DoubleType, true),
  StructField("ticalls", DoubleType, true),
  StructField("ticharge", DoubleType, true),
  StructField("numcs", DoubleType, true),
  StructField("churn", StringType, true)
))
Using Spark 2.0, we specify the data source and schema to load into a Dataset. Note that with Spark 2.0, specifying the schema when loading data into a DataFrame will give better performance than schema inference. We cache the Datasets for quick, repeated access. We also print the schema of the Datasets.
val train: Dataset[Account] = spark.read.option("inferSchema", "false")
  .schema(schema).csv("/user/user01/data/churn-bigml-80.csv").as[Account]
train.cache

val test: Dataset[Account] = spark.read.option("inferSchema", "false")
  .schema(schema).csv("/user/user01/data/churn-bigml-20.csv").as[Account]
test.cache

train.printSchema()
root
 |-- state: string (nullable = true)
 |-- len: integer (nullable = true)
 |-- acode: string (nullable = true)
 |-- intlplan: string (nullable = true)
 |-- vplan: string (nullable = true)
 |-- numvmail: double (nullable = true)
 |-- tdmins: double (nullable = true)
 |-- tdcalls: double (nullable = true)
 |-- tdcharge: double (nullable = true)
 |-- temins: double (nullable = true)
 |-- tecalls: double (nullable = true)
 |-- techarge: double (nullable = true)
 |-- tnmins: double (nullable = true)
 |-- tncalls: double (nullable = true)
 |-- tncharge: double (nullable = true)
 |-- timins: double (nullable = true)
 |-- ticalls: double (nullable = true)
 |-- ticharge: double (nullable = true)
 |-- numcs: double (nullable = true)
 |-- churn: string (nullable = true)

// display the first 20 rows:
train.show
Spark DataFrames include some built-in functions for statistical processing. The describe() function performs summary statistics calculations on all numeric columns and returns them as a DataFrame.
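For example, here is a minimal sketch of how describe() might be used on a few of the numeric columns (exactly which columns to inspect is up to you):

// Summary statistics (count, mean, stddev, min, max) for selected numeric columns
train.describe("numvmail", "tdmins", "numcs").show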
We can use Spark SQL to explore the dataset. Here are some example queries using the Scala DataFrame API:
train.groupBy("churn").sum("numcs").show +-----+----------+ |churn|sum(numcs)| +-----+----------+ |False| 3310.0| | True| 856.0| +-----+----------+ train.createOrReplaceTempView("account") spark.catalog.cacheTable("account")
Total day minutes and Total day charge are highly correlated fields. Such correlated data won't be very beneficial for our model training runs, so we're going to remove them. We'll do so by dropping one column of each pair of correlated fields, along with the State and Area code columns, which we also won’t use.
val dtrain = train.drop("state").drop("acode").drop("vplan")
  .drop("tdcharge").drop("techarge")
Grouping the data by the Churn field and counting the number of instances in each group (see the snippet below) shows that there are roughly 6 times as many false churn samples as true churn samples.
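The counts shown below come from a simple grouped count, along the lines of:

// Count the number of samples in each churn class
dtrain.groupBy("churn").count.show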
+-----+-----+
|churn|count|
+-----+-----+
|False| 2278|
| True|  388|
+-----+-----+
The business will act on these predictions to retain the customers most likely to leave, not those who are likely to stay. Thus, we need to ensure that our model is sensitive to the Churn=True samples.
We can put the two sample types on the same footing using stratified sampling. The DataFrames sampleBy() function does this when provided with fractions of each sample type to be returned. Here, we're keeping all instances of the Churn=True class, but downsampling the Churn=False class to a fraction of 388/2278.
val fractions = Map("False" -> .17, "True" -> 1.0)
val strain = dtrain.stat.sampleBy("churn", fractions, 36L)
strain.groupBy("churn").count.show
+-----+-----+
|churn|count|
+-----+-----+
|False|  379|
| True|  388|
+-----+-----+
To build a classifier model, you extract the features that most contribute to the classification. The features for each item consist of the fields shown below:
In order for the features to be used by a machine learning algorithm, they are transformed and put into Feature Vectors, which are vectors of numbers representing the value for each feature.
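As a minimal sketch, Spark ML represents such a feature vector as an org.apache.spark.ml.linalg.Vector of doubles; the values below are made up purely to show the shape.

import org.apache.spark.ml.linalg.Vectors

// A dense feature vector: one numeric value per feature, in a fixed order
val exampleFeatures = Vectors.dense(128.0, 0.0, 25.0, 265.1, 110.0, 197.4)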
Reference: Learning Spark
The ML package is the newer library of machine learning routines. Spark ML provides a uniform set of high-level APIs built on top of DataFrames.
We will use an ML Pipeline to pass the data through transformers in order to extract the features and an estimator to produce the model.
The ML package needs data to be put in a (label: Double, features: Vector) DataFrame format with correspondingly named fields. We set up a pipeline to pass the data through 3 transformers in order to extract the features: 2 StringIndexers and a VectorAssembler. We use the StringIndexers to convert the string categorical feature intlplan and the churn label into numeric indices. Indexing categorical features allows decision trees to treat categorical features appropriately, improving performance.
// set up StringIndexer transformers for label and string feature
val ipindexer = new StringIndexer()
  .setInputCol("intlplan")
  .setOutputCol("iplanIndex")
val labelindexer = new StringIndexer()
  .setInputCol("churn")
  .setOutputCol("label")
The VectorAssembler combines a given list of columns into a single feature vector column.
// set up a VectorAssembler transformer
val featureCols = Array("len", "iplanIndex", "numvmail", "tdmins", "tdcalls",
  "temins", "tecalls", "tnmins", "tncalls", "timins", "ticalls", "numcs")
val assembler = new VectorAssembler()
  .setInputCols(featureCols)
  .setOutputCol("features")
The final element in our pipeline is an estimator (a decision tree classifier), which is trained on the label and feature vector columns.
// set up a DecisionTreeClassifier estimator
val dTree = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")

// Chain indexers and tree in a Pipeline
val pipeline = new Pipeline()
  .setStages(Array(ipindexer, labelindexer, assembler, dTree))
We would like to determine which parameter values of the decision tree produce the best model. A common technique for model selection is k-fold cross validation, where the data is randomly split into k partitions. Each partition is used once as the testing data set, while the rest are used for training. Models are then generated using the training sets and evaluated with the testing sets, resulting in k model performance measurements. The average of the performance scores is often taken to be the overall score of the model, given its build parameters. For model selection we can search through the model parameters, comparing their cross validation performances. The model parameters leading to the highest performance metric produce the best model.
Spark ML supports k-fold cross validation on a transformation/estimation pipeline: you try out different combinations of parameters using a process called grid search, where you set up a grid of parameters to test, and use a cross-validation evaluator to construct the model selection workflow.
Below, we use a ParamGridBuilder to construct the parameter grid.
// Search through decision tree's maxDepth parameter for best model
val paramGrid = new ParamGridBuilder()
  .addGrid(dTree.maxDepth, Array(2, 3, 4, 5, 6, 7))
  .build()
We define a BinaryClassificationEvaluator, which will score the model by comparing the test label column with the test prediction column. The default metric is the area under the ROC curve.
// Set up Evaluator (prediction, true label)
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setRawPredictionCol("prediction")
We use a CrossValidator for model selection. The CrossValidator uses the Estimator Pipeline, the Parameter Grid, and the Classification Evaluator. The CrossValidator uses the ParamGridBuilder to iterate through the maxDepth parameter of the decision tree and evaluate the models, repeating 3 times per parameter value for reliable results.
// Set up 3-fold cross validation
val crossval = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// Fit on the stratified training set created earlier
val cvModel = crossval.fit(strain)
We retrieve the best decision tree model so that we can print out the tree and its parameters.
// Fetch best model
val bestModel = cvModel.bestModel
val treeModel = bestModel.asInstanceOf[org.apache.spark.ml.PipelineModel]
  .stages(3).asInstanceOf[DecisionTreeClassificationModel]
println("Learned classification tree model:\n" + treeModel.toDebugString)
// 0-11 feature columns: len, iplanIndex, numvmail, tdmins, tdcalls, temins,
// tecalls, tnmins, tncalls, timins, ticalls, numcs
println("Feature 11:" + featureCols(11))
println("Feature 3:" + featureCols(3))

Feature 11:numcs
Feature 3:tdmins
We find that the best tree model produced by the cross-validation process has a depth of 5. The toDebugString() function prints the tree's decision nodes and the prediction outcomes at the leaves. We can see that features 11 and 3 are used for decision making and should thus be considered as having high predictive power for a customer's likelihood to churn. It's not surprising that these feature numbers map to the fields Customer service calls and Total day minutes. Decision trees are often used for feature selection because they provide an automated mechanism for determining the most important features (those closest to the tree root).
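The fitted tree also exposes a featureImportances vector that can be inspected directly, as a quick cross-check of which features dominate (a small sketch):

// Importance scores, indexed by position in featureCols; higher means more influential
println(treeModel.featureImportances)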
The actual performance of the model can be determined using the test data set that has not been used for any training or cross-validation activities. We'll transform the test set with the model pipeline, which will map the features according to the same recipe.
val predictions = cvModel.transform(test)
The evaluator will provide us with the score of the predictions, and then we'll print them along with their probabilities.
val accuracy = evaluator.evaluate(predictions)
evaluator.explainParams()
val result = predictions.select("label", "prediction", "probability")
result.show
accuracy: Double = 0.8484817813765183
metric name in evaluation (default: areaUnderROC)
In this case, the evaluation returns an area under the ROC curve of 84.8%. The prediction probabilities can be very useful in ranking customers by their likelihood to defect. This way, the limited resources available to the business for retention can be focused on the appropriate customers.
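One way to produce such a ranking is to extract the churn-class probability from the probability vector and sort on it. The sketch below assumes the Churn=True class maps to index 1 of the probability vector, which holds here because StringIndexer assigns index 0 to the more frequent "False" label.

import org.apache.spark.ml.linalg.Vector

// Rank accounts by predicted probability of churn (index 1 = "True" class, per the assumption above)
val churnProb = udf((v: Vector) => v(1))
predictions
  .withColumn("churnProbability", churnProb($"probability"))
  .orderBy($"churnProbability".desc)
  .select("churnProbability", "prediction", "label")
  .show(10)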
Below, we calculate some more metrics. The number of false/true positive and negative predictions is also useful:
val lp = predictions.select("label", "prediction")
val counttotal = predictions.count()
val correct = lp.filter($"label" === $"prediction").count()
val wrong = lp.filter(not($"label" === $"prediction")).count()
val ratioWrong = wrong.toDouble / counttotal.toDouble
val ratioCorrect = correct.toDouble / counttotal.toDouble
val truep = lp.filter($"prediction" === 0.0)
  .filter($"label" === $"prediction").count() / counttotal.toDouble
val truen = lp.filter($"prediction" === 1.0)
  .filter($"label" === $"prediction").count() / counttotal.toDouble
val falsep = lp.filter($"prediction" === 1.0)
  .filter(not($"label" === $"prediction")).count() / counttotal.toDouble
val falsen = lp.filter($"prediction" === 0.0)
  .filter(not($"label" === $"prediction")).count() / counttotal.toDouble

println("counttotal : " + counttotal)
println("correct : " + correct)
println("wrong: " + wrong)
println("ratio wrong: " + ratioWrong)
println("ratio correct: " + ratioCorrect)
println("ratio true positive : " + truep)
println("ratio false positive : " + falsep)
println("ratio true negative : " + truen)
println("ratio false negative : " + falsen)
counttotal : 667
correct : 574
wrong: 93
ratio wrong: 0.13943028485757122
ratio correct: 0.8605697151424287
ratio true positive : 0.1184407796101949
ratio false positive : 0.0239880059970015
ratio true negative : 0.7421289355322339
ratio false negative : 0.11544227886056972
In this blog post, we showed you how to get started using Apache Spark’s machine learning decision trees and ML Pipelines for classification. If you have any further questions about this tutorial, please ask them in the comments section below.