Reducing AI bias with synthetic data

Generate artificial records to balance biased datasets and improve overall model accuracy.
Source: Kubkoo, via iStockPhoto

In this post, we are going to explore using synthetic data to augment a popular health dataset on Kaggle, and then train AI and ML models that perform and generalize better, while reducing algorithmic bias.

The Heart Disease dataset published by the University of California, Irvine (UCI) is one of the top 5 datasets on the data science competition site Kaggle, with 9 data science tasks listed and 1,014+ notebook kernels created by data scientists. It contains 14 health attributes per patient and is labeled with whether the patient had heart disease, making it a great dataset for prediction.

Overview of the UCI heart disease dataset on Kaggle

A quick look at the dataset shows that male patient records account for 68% of the overall dataset, with female patient records at only 32%. With roughly a 2-to-1 ratio of male to female patients, algorithms trained on the dataset could over-index on male symptoms and produce poorer diagnoses for female patients. There is no substitute for equal representation of groups in training data, especially in healthcare. In the absence of that, how do we reduce biases in our input data as much as possible?
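
The split can be checked in a few lines. The following is a minimal sketch, assuming the Kaggle export is a CSV with a `sex` column in which 1 means male and 0 means female; the column name, the file name `heart.csv`, and the helper names are assumptions for illustration.

```python
import csv
from collections import Counter

def load_rows(path):
    """Read a CSV file into a list of dicts keyed by the header row."""
    with open(path, newline="") as handle:
        return list(csv.DictReader(handle))

def group_shares(rows, column="sex"):
    """Return the fraction of rows for each value found in `column`."""
    counts = Counter(row[column] for row in rows)
    total = sum(counts.values())
    return {value: count / total for value, count in counts.items()}

# Toy rows for illustration only (not taken from the real dataset).
toy_rows = [{"sex": "1"}, {"sex": "1"}, {"sex": "0"}, {"sex": "1"}]
shares = group_shares(toy_rows)
print(shares)
```

On the real file this would be called as `group_shares(load_rows("heart.csv"))`, assuming that is where the export was saved.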

“By augmenting the training set with synthetic records, can we reduce the gender bias and improve ML accuracy?”

To test our thesis, we will use Gretel.ai’s open source synthetic data library to generate additional female patient records to attempt to compensate for the biased training data. Our hope is that this will help the classifiers improve predictions for heart disease for both male and female patients, and generalize better to unknown data. We can then run the augmented dataset through ML algorithms on Kaggle and compare the results against models trained on the original training set.

A top data science notebook on Kaggle (by forks, linked below) runs a series of 6 classification algorithms on the UCI dataset and compares the resulting model accuracies. The notebook splits the original dataset into train (80%) and test (20%) sets, which we save to disk as train.csv and test.csv. We will use train.csv to train our synthetic model, and test.csv to validate the results with the 6 classification algorithms from the notebook.
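
An 80/20 split of this kind can be sketched in plain Python as follows. This is a stand-in for illustration, not the notebook's own code; the helper name `split_rows` and the seed value are assumptions.

```python
import random

def split_rows(rows, test_fraction=0.2, seed=0):
    """Shuffle a copy of `rows` and split it into (train, test) lists."""
    shuffled = list(rows)
    # A seeded generator keeps the split reproducible between runs.
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]

# Toy input: ten row ids standing in for patient records.
train, test = split_rows(range(10))
print(len(train), len(test))
```

Each resulting list could then be written out with `csv.writer` to produce train.csv and test.csv.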

https://www.kaggle.com/cdabakoglu/heart-disease-classifications-machine-learning

Train a synthetic data model

To generate your own synthetic records, launch Gretel-synthetics via Google Colaboratory, or check out the notebook directly on our GitHub. Click “Run all” in Colaboratory to download the training dataset exported from Kaggle, train a model, and generate new patient records to augment the original training data.

Configure the following settings for your synthetic data model. We found a good balance of generalization vs. model accuracy with 15 epochs of training and 256 hidden units. In this configuration, training with differential privacy is not necessary, as the dataset has already been de-identified.

To test the theory about algorithmic bias, we added a very simple custom validator to the notebook that only accepts female records generated by our synthetic model (column 1, gender, is equal to 0).
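
A validator of this kind can be sketched as a function that parses one generated line and raises an exception when the record should be rejected. The sketch below assumes comma-delimited records with 14 numeric fields and gender in column 1 (0 = female). The function names and the raise-to-reject convention are assumptions, so check the notebook for the exact hook it uses.

```python
def validate_female_record(line, delimiter=",", n_fields=14, gender_col=1):
    """Raise ValueError unless `line` is a well-formed female record."""
    fields = [field.strip() for field in line.split(delimiter)]
    if len(fields) != n_fields:
        raise ValueError(f"expected {n_fields} fields, got {len(fields)}")
    # Assumes every attribute is numeric, so parsing doubles as a sanity check.
    values = [float(field) for field in fields]
    if values[gender_col] != 0:
        raise ValueError("record is not female (gender column is not 0)")

def is_valid(line):
    """Convenience wrapper that returns True/False instead of raising."""
    try:
        validate_female_record(line)
    except ValueError:
        return False
    return True

# Made-up lines for illustration only.
female_line = "50,0,1,120,200,0,1,150,0,1.0,2,0,2,1"
male_line = "50,1,1,120,200,0,1,150,0,1.0,2,0,2,1"
print(is_valid(female_line), is_valid(male_line))
```

Rejecting malformed lines in the same function means a truncated or non-numeric record from the generator is dropped along with the male ones.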

We are now ready to train a synthetic data model on our input data, and use it to generate 111 female patient data records to augment our training set.
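
Augmenting the training set amounts to appending the accepted synthetic rows to the original rows. Here is a minimal sketch, assuming rows are lists with gender in column 1; the helper name `augment` and the toy rows are hypothetical.

```python
from collections import Counter

def augment(train_rows, synthetic_rows, gender_col=1):
    """Append synthetic rows to the training rows and count rows per gender."""
    combined = list(train_rows) + list(synthetic_rows)
    counts = Counter(row[gender_col] for row in combined)
    return combined, dict(counts)

# Toy rows for illustration only: [age, gender].
train_rows = [[60, 1], [55, 1], [48, 0]]
synthetic_rows = [[52, 0]]
combined, counts = augment(train_rows, synthetic_rows)
print(counts)
```

Returning the per-gender counts alongside the combined rows makes it easy to confirm that the augmented set is closer to balanced before training on it.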


The synthetic model quickly learned the semantics of the data, and trained to 95%+ accuracy within 10 epochs. Next, download the generated dataset and let’s run it on Kaggle!

Run the experiment

Now, let’s go to Kaggle and run the classification notebook (with minimal edits to allow both the original models and our augmented models to run). To make this easy, you can load the modified notebook here. By default, it will run with our generated dataset. To use your own, upload the generated dataset from the previous step to the Kaggle notebook.

https://www.kaggle.com/redlined/heart-disease-classifications-machine-learning

Results

As we can see below, adding synthetically generated patient records to the training set increased accuracy for 5 of the 6 classification algorithms compared with the models trained on the original dataset alone: KNN reached 96.7% overall accuracy (up from 88.5%), Random Forest reached 93%, and the Decision Tree classifier gained 13%. Naive Bayes accuracy dropped, possibly because the algorithm makes a strong assumption that all features are independent (hence “naive”), and the synthetic data model likely learned and replayed correlations in the training data.

Synthetics average accuracy: 90.16%, original: 85.79%. Average improvement: 4.37 percentage points
Overall model performance (baseline vs. baseline + synthetics)

Finally, looking at the model accuracy results by gender, the original models’ average accuracy for female patients was 84.57%. After augmenting the female patient records and training the same models, accuracy improved to 90.74%. Interestingly, prediction accuracy for male patients improved as well, from 86.61% to 90.71%.
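
Per-gender accuracy can be computed by grouping test predictions on the gender column. The sketch below shows one way to do that for a single classifier, assuming parallel lists of true labels, predictions, and gender values. The helper name and toy values are hypothetical, and the notebook may compute this differently.

```python
from collections import defaultdict

def accuracy_by_group(y_true, y_pred, groups):
    """Return {group: accuracy}, the fraction of correct predictions per group."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for truth, pred, group in zip(y_true, y_pred, groups):
        total[group] += 1
        correct[group] += int(truth == pred)
    return {group: correct[group] / total[group] for group in total}

# Toy labels for illustration only (0 = female, 1 = male in `sex`).
y_true = [1, 0, 1, 1, 0, 1]
y_pred = [1, 0, 0, 1, 0, 1]
sex = [0, 0, 0, 1, 1, 1]
result = accuracy_by_group(y_true, y_pred, sex)
print(result)
```

Averaging these per-group values across the 6 classifiers would give figures comparable to the ones quoted above.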

That is a 6.17 percentage point improvement in heart disease prediction accuracy for female patients!

Conclusion

At Gretel.ai we are super excited about the possibility of using synthetic data to augment training sets to create ML and AI models that generalize better against unknown data and with reduced algorithmic biases. We’d love to hear about your use cases. Feel free to reach out to us for a more in-depth discussion in the comments, on Twitter, or at hi@gretel.ai. Follow us to keep up on the latest trends with synthetic data!

Interested in training on your own dataset? Gretel-synthetics is free and open source, and you can start experimenting in seconds via Colaboratory.