Downstream ML classification with Gretel ACTGAN and PyCaret

Learn about downstream machine learning tasks and synthetic data with Gretel’s new ACTGAN model and the PyCaret library

In a recent post we discussed how to safely avoid linkage attacks by using synthetic data. There, the synthetic data was used to train a downstream machine learning classifier. Classifiers require accurate, high-quality data before they can be usefully deployed. In this post, we dive into using synthetic data to train a machine learning model that predicts whether a customer will purchase a certain product.

When practitioners use the word “downstream,” they're typically referring to a step in the system that happens after the data has been processed and transformed. For example, if we were to train a machine learning classifier and then use it to make predictions about some customer behavior, those predictions could be used downstream to take further business action on behalf of the customer. 

However, we can also refer to the machine learning classifier itself as being downstream if we do a series of transformations to the training data. That's the case here, where we create a synthetic version of the data that we then use to train a downstream classifier which itself may have other downstream effects. We have a notebook where you can follow along.

I’ve generated synthetic data, now what? 

The dataset we use as an example here is our grocery store data, previewed below.

Table: A preview of the grocery data. Each row is one order; the first columns are order metadata (order_id, order_dow, order_hour_of_day, days_since_prior_order), and the remaining columns, from air fresheners candles through yogurt, hold per-item purchase counts for that order.

Every item column after the first few metadata columns records how many of that item a person bought in a single transaction. The column we’ll focus on predicting in this case is `frozen pizza`, but we could choose any column of interest. 
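
Before training anything, it can help to look at how that target is distributed. Here is a minimal sketch using pandas, assuming the dataset has been loaded into the all_original_data DataFrame used below and that the column name matches the lowercase form used later in the PyCaret setup:

# Inspect how many frozen pizzas appear per order (the downstream target)
print(all_original_data['frozen pizza'].value_counts())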

Since we are going to train both a synthetic data generating model and a downstream classification model, we need to hold out a small validation set. This validation set isn’t seen by the synthetic model or the classification model, and its purpose is to measure how well a classifier trained purely on synthetic data performs on unseen real data.

This is an additional step beyond the traditional machine learning pipeline, and it ensures that our classification model trained on synthetic data can be used on real-world data without data leakage.

Figure 1: Diagram showing data split into training, test, and validation sets

We can use the remaining 95% of the data to train our synthetic model. 

from sklearn.model_selection import train_test_split
train_df, validation_data = train_test_split(all_original_data, test_size=0.05)

In this instance, because we have over 100 columns, we want to use a model that handles high-dimensional data well. Additionally, we want to make sure that our model can handle any columns that might have a large number of purchases. For that reason, we’ll use our newly released Gretel ACTGAN model. This model is GAN-based and highly effective for tabular data generation. Its improved memory usage, speed, and accuracy make it an excellent choice for this use case.

Figure 2: Synthetic Data Quality Score

Once we’ve trained our synthetic model and verified its quality we can generate records to train our downstream classifier. We generate the same number of records as we had in our training data, which results in a synthetic dataset with high downstream utility. 
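
For reference, here is a minimal sketch of how training ACTGAN and generating those records could look with the Gretel Python client. The project name, the tabular-actgan config path, and the file and variable names are assumptions for illustration; consult the accompanying notebook for the exact calls.

import pandas as pd
from gretel_client import configure_session
from gretel_client.projects import create_or_get_unique_project
from gretel_client.projects.models import read_model_config
from gretel_client.helpers import poll

configure_session(api_key="prompt", cache="yes", validate=True)

# Hypothetical project name; any unique project name works
project = create_or_get_unique_project(name="downstream-ml-actgan")

# Train an ACTGAN model on the 95% training split
train_df.to_csv("train.csv", index=False)
model = project.create_model_obj(
    model_config=read_model_config("synthetics/tabular-actgan"),  # assumed blueprint path
    data_source="train.csv",
)
model.submit_cloud()
poll(model)

# Generate as many synthetic records as we had real training rows
record_handler = model.create_record_handler_obj(params={"num_records": len(train_df)})
record_handler.submit_cloud()
poll(record_handler)

synthetic_df = pd.read_csv(record_handler.get_artifact_link("data"), compression="gzip")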

We split this synthetically generated data in the same way we would split the original data if we were using that directly for our model training. 

# Split the synthetic records for downstream training and testing
synthetic_train_data, synthetic_test_data = train_test_split(synthetic_df, test_size=0.2)
# Split the real (non-held-out) data the same way for a baseline comparison
original_train_data, original_test_data = train_test_split(train_df, test_size=0.2)

We also split the remaining 95% of the original data in the same way so we can later compare against a downstream model trained directly on the real data. 

Train and evaluate downstream ML models with PyCaret

A data practitioner would typically spend a great deal of time selecting a model and validating its performance. These steps can be accelerated with an AutoML tool like PyCaret, which is what we use here.

from pycaret.classification import setup, compare_models, predict_model

# Set up the experiment on the synthetic training data and compare candidate classifiers
s = setup(synthetic_train_data, target='frozen pizza')
best = compare_models()

We first see a table that outlines information about the data and proposed training run. This table shows the size of our data, number of features, and other useful information. A subset of the information is shown here:

| Description | Value |
|---|---|
| Target | frozen pizza |
| Target Type | Multiclass |
| Original Data | (4000, 137) |
| Fold Number | 10 |

After the 14 candidate models have been trained with 10-fold cross-validation, we see the results. In this case, our best model, as measured by accuracy, is a Random Forest Classifier. The actual numbers will change between runs due to the stochastic nature of synthetic data generation.


|  | Model | Accuracy | Recall | Prec. | F1 | Kappa | MCC | TT (Sec) |
|---|---|---|---|---|---|---|---|---|
| rf | Random Forest Classifier | 0.9085 | 0.2500 | 0.8254 | 0.8650 | 0.0000 | 0.0000 | 0.0930 |
| et | Extra Trees Classifier | 0.9085 | 0.2500 | 0.8254 | 0.8650 | 0.0000 | 0.0000 | 0.1020 |
| dummy | Dummy Classifier | 0.9085 | 0.2500 | 0.8254 | 0.8650 | 0.0000 | 0.0000 | 0.0130 |
| knn | K Neighbors Classifier | 0.9078 | 0.2498 | 0.8257 | 0.8648 | 0.0018 | 0.0032 | 0.1900 |
| ridge | Ridge Classifier | 0.9068 | 0.2495 | 0.8253 | 0.8641 | -0.0030 | -0.0065 | 0.0160 |
| lightgbm | Light Gradient Boosting Machine | 0.9060 | 0.2493 | 0.8252 | 0.8637 | -0.0042 | -0.0097 | 0.0640 |
| lr | Logistic Regression | 0.9021 | 0.2505 | 0.8329 | 0.8632 | 0.0055 | 0.0168 | 0.7870 |
| svm | SVM - Linear Kernel | 0.8957 | 0.2476 | 0.8270 | 0.8594 | -0.0077 | -0.0106 | 0.0550 |
| gbc | Gradient Boosting Classifier | 0.8950 | 0.2556 | 0.8300 | 0.8595 | -0.0008 | 0.0003 | 0.4820 |
| ada | Ada Boost Classifier | 0.8746 | 0.2418 | 0.8299 | 0.8500 | 0.0071 | 0.0121 | 0.0550 |
| lda | Linear Discriminant Analysis | 0.8592 | 0.2387 | 0.8268 | 0.8422 | -0.0151 | -0.0145 | 0.0470 |
| dt | Decision Tree Classifier | 0.8253 | 0.2429 | 0.8281 | 0.8264 | -0.0157 | -0.0158 | 0.1560 |
| nb | Naive Bayes | 0.2551 | 0.1984 | 0.8356 | 0.3561 | 0.0034 | 0.0069 | 0.1600 |
| qda | Quadratic Discriminant Analysis | 0.0622 | 0.2346 | 0.0054 | 0.0099 | -0.0011 | -0.0031 | 0.0390 |
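
If you want to work with this comparison grid programmatically, PyCaret’s pull() helper returns the most recently displayed scoring table as a pandas DataFrame. A small sketch (the variable name is just an example):

from pycaret.classification import pull

# Capture the compare_models() scoring grid for further analysis
results_grid = pull()
print(results_grid.head())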

We can then evaluate the best model on whichever held-out subsets of the data we choose.

Model trained on synthetic data and evaluated on real test data:

test_predictions = predict_model(best, data=original_test_data)

| Model | Accuracy | Recall | Prec. | F1 | Kappa | MCC |
|---|---|---|---|---|---|---|
| Random Forest Classifier | 0.9495 | 0.3333 | 0.9015 | 0.9249 | 0.0000 | 0.0000 |

Model trained on synthetic data and evaluated on real validation data:

valid_predictions = predict_model(best, data=validation_data)

| Model | Accuracy | Recall | Prec. | F1 | Kappa | MCC |
|---|---|---|---|---|---|---|
| Random Forest Classifier | 0.9320 | 0.3333 | 0.8686 | 0.9249 | 0.0000 | 0.0000 |

Importantly, we can compare these results to those of downstream models trained on the original data and evaluated on the same test and validation splits, as sketched below. 
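
A minimal sketch of how that baseline could be produced with the same PyCaret workflow (the variable names here are illustrative):

# Re-initialize the experiment on the real training data and pick the best model
s_original = setup(original_train_data, target='frozen pizza')
best_original = compare_models()

# Evaluate the baseline on the same held-out real splits
original_test_predictions = predict_model(best_original, data=original_test_data)
original_valid_predictions = predict_model(best_original, data=validation_data)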

Model trained on original data and evaluated on real test data:

| Model | Accuracy | Recall | Prec. | F1 | Kappa | MCC |
|---|---|---|---|---|---|---|
| Random Forest Classifier | 0.9474 | 0.2500 | 0.8975 | 0.9218 | 0.0000 | 0.0000 |

Model trained on original data and evaluated on real validation data:

| Model | Accuracy | Recall | Prec. | F1 | Kappa | MCC |
|---|---|---|---|---|---|---|
| Random Forest Classifier | 0.9560 | 0.3333 | 0.9139 | 0.9345 | 0.0000 | 0.0000 |

In this run, the model trained on synthetic data roughly matches the models trained on the original data on the test split and scores slightly lower on the validation split. These exact numbers will vary between runs, and performance is often quite comparable between downstream classifiers trained on synthetic data and those trained on the original data. 

Conclusion

In this post we saw that we can train a downstream classifier on 100% synthetic data and achieve performance comparable to a downstream classifier trained purely on the original data. This is encouraging as it suggests we can reap the benefits of synthetic data (e.g., privacy, volume, cost) and still achieve acceptable performance for downstream machine learning use cases. 


Check out our CPO Alex Watson working through this notebook and discussing downstream ML!