Downstream ML classification with Gretel ACTGAN and PyCaret
In a recent post, we discussed how to safely avoid linkage attacks by using synthetic data. That synthetic data was used to train a downstream machine learning classifier. Classifiers require accurate, high-quality data before they can be usefully deployed. In this post, we dive into using synthetic data to train a machine learning model that predicts whether a customer will purchase a certain product.
When practitioners use the word “downstream,” they're typically referring to a step in the system that happens after the data has been processed and transformed. For example, if we were to train a machine learning classifier and then use it to make predictions about some customer behavior, those predictions could be used downstream to take further business action on behalf of the customer.
However, we can also refer to the machine learning classifier itself as being downstream if we do a series of transformations to the training data. That's the case here, where we create a synthetic version of the data that we then use to train a downstream classifier which itself may have other downstream effects. We have a notebook where you can follow along.
I’ve generated synthetic data, now what?
The data we use as an example here is our grocery store data.
Every food column after the first few metadata columns records how many units of that item a person bought in a single transaction. The column we’ll focus on predicting in this case is `Frozen Pizza`, but we could choose any column of interest.
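To make the setup concrete, here is a minimal sketch of loading and inspecting the dataset with pandas. The file name below is a placeholder for whatever source the notebook actually uses; only the target column name comes from the example above.

```python
import pandas as pd

# Placeholder path -- substitute the grocery dataset used in the notebook.
df = pd.read_csv("grocery_transactions.csv")

# The column we want the downstream classifier to predict.
target = "Frozen Pizza"

print(df.shape)                    # rows x columns (100+ columns here)
print(df[target].value_counts())   # distribution of the target column
```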
Since we are going to train both a synthetic data generating model and a downstream classification model, we need to hold out a small validation set. This validation set is never seen by the synthetic model or the classification model; its purpose is to test the eventual performance of a classification model trained purely on synthetic data and validated on unseen real data.
This is an additional step on top of the traditional machine learning pipeline, and it ensures that our classification model trained on synthetic data can be applied to real-world data without data leakage.
We can use the remaining 95% of the data to train our synthetic model.
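A minimal sketch of that holdout, assuming a 95/5 split with scikit-learn (the exact proportion and random seed here are placeholders):

```python
from sklearn.model_selection import train_test_split

# Hold out 5% of the real data; neither the synthetic model nor the
# downstream classifier sees this set during training.
train_df, holdout_df = train_test_split(df, test_size=0.05, random_state=42)

# train_df (95%)  -> used to train the synthetic (ACTGAN) model
# holdout_df (5%) -> reserved for final validation on unseen real data
```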
In this instance, because we have over 100 columns, we want to use a model that handles high-dimensional data well. Additionally, we want to make sure that our model can handle any columns that might have a large number of purchases. For that reason, we’ll use our newly released Gretel ACTGAN model. This model is GAN-based, and highly effective for tabular data generation. Its improved memory usage, speed, and accuracy make it an excellent choice for this use case.
Once we’ve trained our synthetic model and verified its quality, we can generate records to train our downstream classifier. We generate the same number of records as we had in our training data, which results in a synthetic dataset with high downstream utility.
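Here is a hedged sketch of that workflow with the Gretel Python client. It assumes the `synthetics/tabular-actgan` config template and a project name of our choosing; check the notebook and the current client documentation for the exact template name and API.

```python
import pandas as pd
from gretel_client import configure_session
from gretel_client.projects import create_or_get_unique_project
from gretel_client.helpers import poll

# Authenticate with the Gretel API (prompts for an API key).
configure_session(api_key="prompt", cache="yes", validate=True)

project = create_or_get_unique_project(name="downstream-ml-actgan")

# "synthetics/tabular-actgan" is assumed to be the ACTGAN config template;
# adjust if your client version uses a different name.
model = project.create_model_obj(
    model_config="synthetics/tabular-actgan",
    data_source=train_df,
)
model.submit_cloud()
poll(model)

# Generate as many synthetic records as there were real training records.
record_handler = model.create_record_handler_obj(
    params={"num_records": len(train_df)}
)
record_handler.submit_cloud()
poll(record_handler)

synthetic_df = pd.read_csv(
    record_handler.get_artifact_link("data"), compression="gzip"
)
```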
We split this synthetically generated data in the same way we would split the original data if we were using that directly for our model training.
We also split the real 95% training data in the same way, so we have an additional real test set for validating our downstream model’s performance.
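A sketch of those splits; the 80/20 proportion and seed are assumptions, not necessarily what the notebook uses:

```python
from sklearn.model_selection import train_test_split

# Split the synthetic data as we would the original data for model training.
synth_train, synth_test = train_test_split(
    synthetic_df, test_size=0.2, random_state=42
)

# Split the real 95% training data the same way, so we can train comparison
# models on original data and evaluate everything on the same real test set.
real_train, real_test = train_test_split(
    train_df, test_size=0.2, random_state=42
)

# holdout_df (the untouched 5%) remains the final real-world validation set.
```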
Train and evaluate downstream ML models with PyCaret
A data practitioner could spend a great deal of time selecting a model and validating its performance. These steps can be accelerated with an AutoML tool like PyCaret, which is what we do here.
We first see a table that outlines information about the data and proposed training run. This table shows the size of our data, number of features, and other useful information. A subset of the information is shown here:
After the 14 candidate models have been trained with 10-fold cross-validation, we see the results. In this case, our best model, as measured by accuracy, is a Random Forest Classifier. The exact numbers will change between runs due to the stochastic nature of synthetic data generation.
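The PyCaret calls behind this are brief. A minimal sketch, assuming the synthetic training split and target column from above (PyCaret defaults to 10-fold cross-validation):

```python
from pycaret.classification import setup, compare_models

# Initialize the PyCaret experiment on the synthetic training data. This
# prints the setup table summarizing rows, features, and preprocessing.
clf_experiment = setup(
    data=synth_train,
    target="Frozen Pizza",
    session_id=42,
)

# Train the candidate classifiers with cross-validation and rank them;
# in our run the best model by accuracy was a Random Forest Classifier.
best_model = compare_models(sort="Accuracy")
```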
We then evaluate the models on various subsets of the data as we see fit; a sketch of this evaluation step follows the comparison below.
Model trained on synthetic data and evaluated on real test data:
Model trained on synthetic data and evaluated on real validation data:
Importantly, we can compare these results to those of downstream models trained on the original data and evaluated on the same test and validation data splits.
Model trained on original data and evaluated on real test data:
Model trained on original data and evaluated on real validation data:
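To reproduce these comparisons, PyCaret’s `predict_model` can score a trained model on any held-out frame that still contains the target column. A sketch, reusing the splits defined earlier:

```python
from pycaret.classification import predict_model

# Evaluate the model trained on synthetic data against real, unseen data.
# When the target column is present, predict_model also reports metrics
# (accuracy, AUC, recall, precision, F1, ...).
preds_real_test = predict_model(best_model, data=real_test)   # real test split
preds_holdout = predict_model(best_model, data=holdout_df)    # untouched 5% validation set

# For the comparison rows above, repeat the same PyCaret workflow with
# setup() pointed at real_train, then evaluate that model on the same
# real_test and holdout_df splits.
```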
We see in this run that models trained on synthetic data achieve slightly lower performance metrics than models trained on the original data. This result doesn’t hold in general; performance is often quite comparable between downstream classifiers trained on synthetic data and those trained on the original data.
Conclusion
In this post we saw that we can train a downstream classifier on 100% synthetic data and achieve performance comparable to a downstream classifier trained purely on the original data. This is encouraging as it suggests we can reap the benefits of synthetic data (e.g., privacy, volume, cost) and still achieve acceptable performance for downstream machine learning use cases.
Check out our CPO Alex Watson working through this notebook and discussing downstream ML!