NVIDIA GTC Spring 2022 Workshop: Outperforming Real-World Data With Synthetic Data
Alex Watson, Co-founder and Chief Product Officer at Gretel.ai hosts a workshop on how to generate synthetic data that is even better than the real thing.
Hi, my name is Alex Watson. I'm a co-founder at Gretel.ai.
Today, we're going to start with an overview of synthetic data: what it is, some of the popular use cases that we're seeing around synthetic data, and also from our research, how synthetic data is capable of outperforming real world data for some machine learning tasks.
For a little background on Gretel, we are a startup that helps developers and data scientists build with synthetic data.
Starting at the very top: what is synthetic data? I like this definition from NVIDIA: "Synthetic data is annotated information that computer simulations or algorithms generate as an alternative to real world data."
The concept of synthetic data is nothing new. It's been around for tens of years, really, since the first time that someone had to create an artificial data set, or one that didn't exist, to demonstrate a solution to a problem.
What is fascinating about synthetic data is the recent advancements we've had in machine learning. Whether we're dealing with language models, generative adversarial networks, or other types of models, we've created the ability for machine-learning models to learn the insights and distributions in real world data sets and to recreate them.
Talking about some of the use cases that we see for synthetic data: with synthetic data, you have a unique ability to influence the output of the model. One of the things that we can enforce while a model is being trained is privacy. We can make sure the model doesn't memorize rare secrets that might exist in the data, or that the model outputs data that could never be linked back to the original data that it was generated from.
This helps address some of the biggest challenges that we have around data access and sharing today, and it highlights the first use case that we see here.
Number one: making private data accessible. One of the patterns we've seen, talking with tens or hundreds of customers at this point, is increasingly decentralized development teams that build in their own environments and are incredibly empowered to build what they need to build. However, this naturally creates data silos and makes data sharing difficult, even within the same team. For example, if you have a production application that has sensitive customer information, you might want to build a dev, test, or staging environment that you can use to test your service and simulate it at scale. However, one of the big concerns you have is actually copying that sensitive production data into that pre-production environment.
So one of the cases we can dive into here: is it possible to train a machine learning model on our production data and create a new, anonymized version of that data that could never be linked back to the individuals it was based on, but that would carry the same type of dynamism and insights as the original production data it was trained on?
Number two: generating samples from limited data sets. In the age of big data, it's surprising that no one seems to have enough of the right type of data. So here we talk about this pattern of starting from just a few examples. For example, say I'm a machine-learning researcher training a chatbot on utterances or commands that might come from our users. Anytime there's a new type of utterance or command, perhaps the model hasn't seen it before and is having a hard time with it. Starting from just a couple of examples and generating a multitude of examples from them is a real challenge.
An evolution of that thinking, too: when you look across the typical data science project and how long it takes to get access to the data you need, sometimes that data doesn't exist at all.
So another neat use case for synthetic data is models that are trained on massive amounts of public information. This could be public location data. It could be shopping cart behavior. It could even be open source heart rate or EKG-type monitor data. These models allow you to compose your own data sets, so you can quickly test your ideas on realistic synthetic data.
Number three, this ability to influence a model and the type of output that it has has allowed us to address some of the bigger challenges with machine learning today. So essentially, what we've done is created a toolkit around synthetic data that you can use to influence the output of a model. And when you can do that, this is something we'll dive into in our examples today, you have the ability to correct bias. You have the ability to influence the distribution of data, which has implications anywhere from AI ethics to fairness.
So how big is synthetic data, and what's new with it? It's rare that I put a Gartner slide inside my presentations, but I think this is one that highlights some of the potential for synthetic data very well. Here on the diagram in front of us, we see that Gartner predicts that by the year 2030, synthetic data will completely overshadow the use of real data in AI models. How is this possible? I think we all know that advanced machine learning models require a lot of data to work well. And when we get into some of the more advanced applications of machine learning today, whether it's speech recognition, self-driving cars, or even face recognition, there are infinite examples that exist in the real world.
And the training set that can be used to train a machine learning model is based on real world examples that have been annotated. This process can get very expensive, especially as you're working on the long tail, trying to improve your machine learning algorithm. If you're creating something to recognize people's faces, how do you account for different shades, backgrounds, things like that? If you're working with voice, how do you account for background noises? How do you account for different tonal variations in speech?
The promise in synthetic data is the ability to work from just a few examples or an idea, and to create a multitude, almost unlimited amount of examples of different permutations of that data that can be used to train machine learning algorithms. So when it comes to the generalization problem, how do algorithms learn to recognize data they've never seen before, synthetic data is a very promising approach.
A second trend that we're seeing is happening in devices. More work is being done on the computer, on your iPhone, for example, or on your Alexa, than ever before. What that means is less data is being sent to the cloud, so the traditional approach of annotating data in the cloud and using that to train your models isn't always possible. There are some really positive privacy benefits here, and really good things for consumers. But it's created a challenge for people building applications that, for example, need to understand your voice and work from there. Synthetic data allows you to take advantage of the small, and increasingly small, number of examples that you might have, but still create algorithms that will generalize to new inputs extremely well.
So here, we're going to pivot a little bit and start talking about the APIs that Gretel has built for synthetic data. And we're going to use these APIs in the next couple slides to work on popular data sets. And we're going to take one particular example and address it both from an accuracy and a fairness perspective.
At Gretel, we have three core APIs. On the far left here you see the area where we spend the most time and research effort: synthetic data. We talk about a couple of the advantages here, but what Gretel is really trying to do is make synthetic data APIs not complicated, not scary, and available to any developer, any person who wants to sign in and use them.
So here on the left, you see synthetics. This can be language models. This can be GANs. We really believe there is no silver bullet approach to creating synthetic data.
However, there is a lot that you can do to make it accessible. For anyone working with a data set, how do we make this part of the pipeline they have for training a machine learning model? If they are trying to enable access to a data warehouse or to a data set for another team, how do we allow them to really quickly de-identify, that is, remove the known variables, names, addresses, things like that, and then synthesize: create a model that can be used to create unlimited amounts of data, either the same size and shape as the original data or, for example, 10 times as many records, if you wanted to have additional variations with different levels of privacy? And how do we help you find the right balance between privacy and accuracy for your downstream use cases?
The two APIs here on the right, transforms and data classification (labeling), are a pre-processing step that we use with synthetic data. As I mentioned a second ago, it's really important to identify the knowns inside of a data set. For example, if you are training a chatbot on customer service or customer reviews, you want that chatbot to learn from the semantics of the data. You don't want to simply redact names, though, because that would send the chatbot in the wrong direction. So one of the options here is using the data classification APIs to identify PII (names, addresses, credit cards, things like that) that should never find its way into a machine learning model. Transforms allow you to create a simple policy to replace it with a fake version, to encrypt it in place, to redact it, or to drop the record. Really, you have a complete set of building blocks here.
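As a rough illustration of the kind of policy described above (replace with a fake value, redact, or drop the record), here is a minimal sketch using only the Python standard library. It is not Gretel's transforms API: the regular expressions, the `apply_policy` function, and the placeholder values are assumptions made up for this example, and encrypt-in-place is left out.

```python
import re

# Hypothetical detectors for illustration only; real PII classification
# covers far more entity types than these two patterns.
EMAIL = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+")
CARD = re.compile(r"\b\d(?:[ -]?\d){12,15}\b")  # 13 to 16 digits, optional separators


def apply_policy(text, action="fake"):
    """Apply one action ('fake', 'redact', or 'drop') to detected PII."""
    if action == "drop":
        # Drop the whole record if any PII is found.
        return None if (EMAIL.search(text) or CARD.search(text)) else text
    if action == "redact":
        fake_email = fake_card = "[REDACTED]"
    else:
        # 'fake': swap in placeholder values so sentence structure survives.
        fake_email, fake_card = "user@example.com", "0000 0000 0000 0000"
    return CARD.sub(fake_card, EMAIL.sub(fake_email, text))


review = "Contact jane@corp.example.org about card 4111 1111 1111 1111 today"
faked = apply_policy(review, "fake")
redacted = apply_policy(review, "redact")
dropped = apply_policy(review, "drop")
```

Replacing with a fake value keeps the sentence shape intact for a model that learns from the semantics, which is the reason given above for not simply redacting.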
And then finally, synthetics, which really gives you fine-grained control over what you want to do with that data. One of the top questions that we get around synthetic data is: how accurate is it? The use case being, I'm trying to enable data access inside of my organization. I want teams to be able to share data; I want teams to be able to access our awesome data warehouse. How do I make that possible? That's question number one. Synthetic data is a promising approach for that: you can create a synthetic twin of your original data set that has increased privacy guarantees. As a developer inside your business, you might otherwise wait two weeks, four weeks, even six months in some cases, as we've seen in the genomics world, to get access to data sets. What if I could get access to a data set right now that had 97 or 98% of the accuracy of the real world data? As a developer, as a data scientist, do I always need to see real names? Do I always need to see real addresses? The answer is no. Often we don't.
Traditional de-identification, which would be replacing names and addresses but keeping the rest of the record the same, has been proven over and over again in the privacy space to be inadequate, really, to protect the privacy of the users that the data is based on. The simplest example I could give would be the Netflix challenge. A couple of years ago, Netflix listed a competition on the data science platform Kaggle, and they de-identified a hundred million different movie reviews, and they did an excellent job de-identifying this data set. It had only a movie ID, a user ID, a date, and the number of stars that the user gave the movie. They ended up having to pull down the competition because some of the competition teams realized that just the combination of the date a review was done, a movie ID, and a user ID was identifying enough that they could unmask the users by joining it with movie reviews that they had seen, for example, on IMDb.
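The join described above can be sketched in a few lines of plain Python. Every record below is invented for illustration, and the field names (`user_id`, `movie_id`, `date`, `name`) are assumptions, not the layout of any real data set.

```python
# Toy linkage attack: join a "de-identified" table with public reviews
# on (movie_id, date). All records here are made up.
deidentified = [
    {"user_id": 101, "movie_id": "m1", "date": "2005-03-01", "stars": 5},
    {"user_id": 101, "movie_id": "m7", "date": "2005-03-04", "stars": 2},
    {"user_id": 202, "movie_id": "m1", "date": "2005-06-10", "stars": 3},
]
public_reviews = [
    {"name": "alice", "movie_id": "m1", "date": "2005-03-01"},
    {"name": "alice", "movie_id": "m7", "date": "2005-03-04"},
    {"name": "bob", "movie_id": "m3", "date": "2005-01-15"},
]


def link(private_rows, public_rows):
    """Count, per anonymous user_id, how many rows match each public name."""
    public_index = {}
    for row in public_rows:
        key = (row["movie_id"], row["date"])
        public_index.setdefault(key, set()).add(row["name"])

    matches = {}
    for row in private_rows:
        key = (row["movie_id"], row["date"])
        for name in public_index.get(key, ()):
            per_user = matches.setdefault(row["user_id"], {})
            per_user[name] = per_user.get(name, 0) + 1
    return matches


matches = link(deidentified, public_reviews)
```

In this toy data, anonymous user 101 matches the public reviewer "alice" on two separate (movie, date) pairs, which is the kind of overlap that makes a de-identified ID guessable.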
So this is a real challenge. This is called a data linkage or re-identification attack. The way synthetic data helps get around this is that it trains on the overall corpus of data and creates a new corpus of data where none of the individual records are linked to a real record that might exist in a database somewhere else. So it's one of the real promises of synthetic data. The question is, how accurate is this new corpus of data that I created? We wanted to go with some examples we thought would be right in line with what data scientists and developers are working with. We took the eight most popular data sets on Kaggle, just looking at the hotness or relevance metric on Kaggle, used completely default parameters for Gretel synthetics, and created another data set of the exact same size and shape as the original real world data it was trained on.
Then, as you can see here, and we can dive into the details on how this worked, we ran each one of these on a downstream use case or task that was associated with it. We took the top notebooks on Kaggle that were running on real world data, ran them on our synthetic version as well, and compared the results. Really exciting results here. This is complete replacement of the data, not augmentation. For stroke prediction, which we will dive into in a later example, we actually saw an improvement using the synthetic data. That means the model must have keyed in on or learned something that helps with the downstream analysis. In some cases you see a slight degradation in performance, as we can see on the data science job candidates use case and things like that. But the really important thing is that overall, just using our default parameters, we had only a 2.58% decrease in accuracy between the real world data and the synthetic data that was created. So pretty exciting initial results here.
So if we can get within 2.58% using our standard synthetic data libraries, can we make it better? Can we improve on real world data? And what types of problems are we seeing in this space right now anyway? I think many people in this audience understand just how powerful machine learning models are in our lives, and how much they increasingly influence things, day to day, that are really important to us. These models are based on data sets, and often those data sets, as I think we all know, are limited. They may not have the right distribution of the people they're based on. They might be out of date. All sorts of different things can happen that, from a fairness perspective or an accuracy perspective, impact us in a very real way.
A couple of the examples to call out here, you know, whether you're getting hired for a job or not, right? So organizations increasingly creating bots that will scan resumes and help with that really manually intensive process of scanning resumes and selecting candidates that would be selected for a following phone screen. We've seen through examples, there was a famous kind of Amazon example where those models were largely focusing in on the data that they were trained on, which was largely male candidates. And they were actually disqualifying different terms that female candidates used inside of their resumes. So really dangerous example there. Medical use cases, diagnosing heart disease, things like that. Even the ability for us with a slightly different voice pattern, inflections or slang to talk to the devices around us, how do we make sure we have a good representation across all different possibilities of people and demographics and things like that to ensure that these algorithms are going to give us a fair response?
Well, there is no silver bullet, but we have tools. And that's what we're going to dive into here, and talk about how we might be able to make this better. One of the times this was described to me really well was by one of our customers, who is running a major data science team for a really large gaming company. They said, "If there are biases in the real world, in the virtual world, they're often magnified." What we want to do is provide a set of tools that help developers and data scientists influence this, to minimize new biases being introduced and to create essentially the most fair experience possible.
So I think it's time now to jump into a real use case. We will provide some code, and you're welcome to follow along and run this yourself. Starting with one of those data sets we saw earlier: one of the top five most popular data sets on the data science platform Kaggle is this heart failure prediction data set. It was published by the folks at UCI. So a really incredible data set, and it's been one of the canonical examples for data classification techniques on this data science platform for quite some time. The question is, what is the distribution of the data that this was trained on? Where did this come from? Are there any gaps in this data that we might want to address? And if we can address those gaps, can we make a better data set? And better, as we define it here, could be either more accurate, so better overall performance for the data set, or more fair. When you look at some of the different categories inside of this data set, for example age distribution, or sex (here we use the term sex for male or female), or the location the user came from, things like that, how do we make sure that those demographics are evenly balanced inside of the data set? If we do that, what impact does that have on the overall accuracy of our data?
So what we're going to do is we are going to train a synthetic model on this data set, and we're going to create essentially another version of this, where we will balance out one of the attributes that stood out really quickly here. And here we see 32% female records over 68% male records. So a theory here is that an algorithm trained on this data set would possibly over index on being really great at male heart disease detection, and pay less attention to minority classes, for example, different age groups, different sexes, genders, things like that that might exist in the data. And if we balance this out, so essentially help the algorithm to create a more fair response by boosting the representation in this case, just a single attribute, we're going to try boosting the female representation. How does that affect overall performance?
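The arithmetic behind balancing a single attribute is simple enough to sketch. The 68/32 split comes from the talk; the list of 100 labels and the `records_to_balance` helper are hypothetical stand-ins for the real column.

```python
from collections import Counter


def records_to_balance(labels):
    """Return how many extra records each class needs to match the largest class."""
    counts = Counter(labels)
    target = max(counts.values())
    return {label: target - n for label, n in counts.items()}


# Hypothetical column with the 68% male / 32% female skew described above.
labels = ["male"] * 68 + ["female"] * 32
needed = records_to_balance(labels)
```

For every 100 records at this skew, the synthetic model would need to contribute 36 additional female records to even out the two classes.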
So diving right in, we'll provide links to run all this code yourself. One of the things I like to do first is just run a parameter sweep. What we did is we took this data set, loaded it into Gretel through a Python notebook, which we can walk through here, and tried a set of different parameters. Really, our goal here is maximum accuracy and maximum fairness. So how do we get there? You can use the default settings, as we've shown earlier, which work quite well. Or we can try a bunch of different parameters against our downstream use case. So I ran a standard random forest classifier on the results and tried to predict heart disease with the purely synthetic data set. And here we can get a good feeling for which sets of APIs and configurations work best.
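A parameter sweep of this kind boils down to scoring every combination in a grid and keeping the best one. The sketch below shows only that loop. The `evaluate` function is a toy stand-in (an assumption) for "train a synthetic model with this config, then score a downstream classifier on its output", and the grid values are invented, not the ones used in the talk.

```python
import itertools


def sweep(grid, evaluate):
    """Score every combination in the grid and return (best_score, best_config)."""
    keys = sorted(grid)
    results = []
    for values in itertools.product(*(grid[k] for k in keys)):
        config = dict(zip(keys, values))
        results.append((evaluate(config), config))
    return max(results, key=lambda item: item[0])


def evaluate(config):
    # Toy scoring function for illustration only. A real sweep would train
    # a model with this config and measure downstream accuracy instead.
    return -abs(config["learning_rate"] - 0.001) - abs(config["epochs"] - 50) / 1000


# Hypothetical grid; parameter names mirror the ones mentioned in the talk.
grid = {"epochs": [25, 50, 100], "learning_rate": [0.01, 0.001]}
best_score, best_config = sweep(grid, evaluate)
```

A full grid grows multiplicatively with every added parameter, so in practice the grid is usually kept small or sampled at random.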
So we'll jump right over to the actual parameter sweep here and take a look. Here's the one with the best classification accuracy for the pure synthetic version. We can see the batch size, a relatively small number of epochs, and a pretty conservative learning rate. Since this is a small data set, this is kind of interesting: a smaller number of RNN units for this synthetic data task, and a vocabulary size. Under the hood we use something called SentencePiece, which is built by Google, to find tokens within the input data. Using that versus pure character-based tokenization seemed to work pretty well. So this is the configuration we're going to start with, essentially using this to train the neural network. We will have the neural network output additional records so we can balance out the male and female classes. And then we'll compare the pure real world data to the real world data augmented with synthetic records so that it has a balanced number of female and male records. Let's jump right over to that use case.
Let's jump right into our example. We'll go to docs.gretel.ai. From there, click on SDK notebooks and go down to the bottom, and you can see an example called "Improve accuracy, heart disease". We're going to go ahead and click that. This is going to open up a Colab notebook that will walk us through our entire experiment. Here we see it loading up in the free Jupyter notebook experience with Colab. We'll go ahead and clear out the outputs that exist and click run all. This will run the entire experiment using the parameters that we selected in the previous parameter sweep. While it's installing dependencies, we're going to go ahead and log into the Gretel console service, since we need to access the API key so the notebook knows how to talk to our cloud service. I'll go ahead and sign in with my Gmail ID.
It's just about finished installing dependencies. Next, it's going to go ahead and ask for my API key, which we can go ahead and copy and enter in. As we can see here, the next step is loading up the train and test sets. Pulling down the example heart disease data set from Kaggle, we did a 70/30 split and created two different sets. What we're going to be doing here is comparing the performance of the real world training set versus a training set that's been augmented with synthetic data. That's what we're going to do for the rest of this notebook. So we'll go ahead and let it run.
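For readers following along without the notebook, a 70/30 split can be sketched with the standard library alone. The `train_test_split` helper here is written for this example (it is not the scikit-learn function of the same name), and the 1,000 placeholder rows are hypothetical.

```python
import random


def train_test_split(rows, test_fraction=0.3, seed=42):
    """Shuffle a copy of the rows and hold out a fraction as the test set."""
    rng = random.Random(seed)  # fixed seed so the split is repeatable
    shuffled = rows[:]
    rng.shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]


# Placeholder rows standing in for the patient records.
rows = [{"id": i} for i in range(1000)]
train, test = train_test_split(rows)
```

The important design point for this experiment is that the test set is set aside before any synthetic model is trained, so the synthetic records used for augmentation never see it.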
Similar to what we saw in the graph on the previous slide, we see a big skew in distribution between male and female records. What we're going to do here is create a synthetic model trained on the real world data set and tell it to boost the representation, in this case, of female records. And we're going to see how that data set performs against the original real world data. Next, we have a really important part. What's happening here is we are training a synthetic model on the real world data set. Here you can see we're bringing in the Gretel client. We pull back the default configuration parameters, and we'll make a few updates to match the parameters that we had from our parameter sweep.
And here we're going to use a conditional data generation task. This is a really neat task with more advanced machine learning models, where I'm defining a single field, in this case, sex. You could define multiple fields. You could build something for a particular age range plus sex, or a different heart rate, really any of the different features you wish. In this case, as we said, we'll just be balancing out this sex attribute. We're going to define this as a seed task. This is the same thing as conditional model generation. Essentially, we are telling the model what type of record we want it to produce. And as we indicated earlier, we'll be telling it to produce additional female records to balance out the quantity of male versus female records in the training set. Why would we do this? I think that's an important question. Why balance this out? Why not just gather more data? Often the steps necessary to reproduce an experiment and get the same results become prohibitively expensive, if not impossible.
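To make the idea of seeding concrete, here is a deliberately simple stand-in for a conditional generator. It is not the neural model or the Gretel API: it fixes the seed field and samples every other field from training rows that share the seed value. The rows, field names, and both helper functions are assumptions for illustration.

```python
import random


def fit_conditional(rows, seed_field):
    """Group training rows by the value of the seed field."""
    by_seed = {}
    for row in rows:
        by_seed.setdefault(row[seed_field], []).append(row)
    return by_seed


def generate(by_seed, seed_field, seeds, rng):
    """Produce one record per seed, sampling other fields from matching rows."""
    records = []
    for seed in seeds:
        pool = by_seed[seed[seed_field]]
        record = {seed_field: seed[seed_field]}
        for field in pool[0]:
            if field != seed_field:
                # Each field is sampled independently, which ignores
                # cross-field correlations a real model would learn.
                record[field] = rng.choice(pool)[field]
        records.append(record)
    return records


# Hypothetical training rows; not taken from the real data set.
training = [
    {"sex": "M", "age": 63, "chest_pain": "ASY"},
    {"sex": "M", "age": 54, "chest_pain": "NAP"},
    {"sex": "F", "age": 49, "chest_pain": "NAP"},
    {"sex": "F", "age": 58, "chest_pain": "ATA"},
]
model = fit_conditional(training, "sex")
seeds = [{"sex": "F"}] * 5  # ask for five female records
generated = generate(model, "sex", seeds, random.Random(0))
```

The seed list is the part that carries over to a real conditional model: one entry per requested record, each stating the value the output must have.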
So often we are limited to the data that we have. In this case, there was no way to go back and recreate additional heart patient experiments with the same methods and procedures that were used before. So this leaves us with the option of boosting the representation across that data set. In this initial run, we're going to tell it to generate 500 records. We can use those to look at the quality of the synthetic model: essentially generate data, throw the kitchen sink at it from a statistical perspective, and compare the real world data against the synthetic data. You can see here also that we are turning off the default privacy filters. We're working with a really small data set here, and we want to make sure we capture every insight. Privacy filters often cause very little actual real world accuracy hit, but we're going for maximum performance on a really small data set like this, so in this case we've made the decision to remove them, and likewise the outlier filtering.
This will create a new project in Gretel called UCI heart disease. And we'll take a look at that in just a minute. And we're going to go ahead and kick off training. So here we can see, training is starting. What's happening is behind the scenes here, the cloud is firing up a container with access to a GPU. It's going to start processing this data set and it's going to iterate over the data set until the model has effectively learned the parameters of the underlying data set. Let's go ahead and let that start.
Here we can see it's loaded up. It has started creating the data set. It's going to start creating the validators here in a minute. Essentially, what validators do is ensure that the output of the neural network matches the same types of distributions as the original real world data. One of the downsides of neural networks is that they can output anything. When you're looking for high quality synthetic data, one of the things that we can do to make it work better is to enforce that the somewhat random output of the neural network makes sense for your particular use case, based on what we saw in the training data. Here, we can see it trained for about 44 epochs. I sped this up a few minutes just so you didn't have to wait through it. At 44 epochs we see very good accuracy, so we're at 90% prediction accuracy, and loss is very small. You want to see accuracy going up and loss going down, and we're looking good on that. We asked it to generate 500 records, so it's just starting that record generation process now. This invalid count, which you're seeing here, means some of these records didn't pass validation. Something about the record, whether a floating point number was too high or a new category was invented that didn't exist, something like that occasionally happens, and we'll see that record being dropped.
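The two failure cases mentioned above (a number outside the range seen in training, or a category that never existed) suggest a very small validator. This is a sketch under that assumption only. The real validators are not described in detail in the talk, and the field names and records are hypothetical.

```python
def fit_validators(rows):
    """Learn a numeric range or an allowed category set for each field."""
    rules = {}
    for field in rows[0]:
        values = [row[field] for row in rows]
        if all(isinstance(v, (int, float)) for v in values):
            rules[field] = ("range", min(values), max(values))
        else:
            rules[field] = ("categories", set(values))
    return rules


def is_valid(record, rules):
    """Reject records with out-of-range numbers or unseen categories."""
    for field, rule in rules.items():
        value = record.get(field)
        if rule[0] == "range":
            if not isinstance(value, (int, float)) or not rule[1] <= value <= rule[2]:
                return False
        elif value not in rule[1]:
            return False
    return True


# Hypothetical training rows and model output.
training = [{"age": 40, "sex": "M"}, {"age": 77, "sex": "F"}]
generated = [
    {"age": 52, "sex": "F"},   # passes
    {"age": 140, "sex": "F"},  # number too high
    {"age": 60, "sex": "X"},   # invented category
]
rules = fit_validators(training)
valid = [r for r in generated if is_valid(r, rules)]
invalid_count = len(generated) - len(valid)
```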
As that data set's being generated here, we're going to go back to the Gretel console, where we can see this model, and take a look at its performance and how it's doing. The model trained pretty quickly. We can see the records being generated, just like what you're seeing in the Python notebook over there.
And record generation is complete. At the end, we generate a synthetic quality score, which you can see here is very good. So we're happy with our ability to reproduce the insights and distributions of the original data. The privacy protection level is much lower, since we turned off the default privacy mechanisms. Looking through here, we see a quick summary of the privacy statistics, and then we dive right into the accuracy, which is so important to us. This is one of my favorite graphs to look at. What we see here are the correlation graphs: the correlations that existed in the original training set and in the synthetic data set. What we're looking for is how closely it's able to match them. You don't want to recreate the data exactly; you want to match the correlations. So here, where we see stronger correlations, for example, between these two different sensor readings, we want to see them be as close as possible. We see 0.41 in the training data; it looks like the synthetic correlation may have been slightly stronger in this case, but still very close at 0.45.
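The comparison behind a correlation report can be sketched as follows: compute the Pearson correlation for the same pair of fields in both data sets and look at the gap. The numbers below are invented and do not reproduce the 0.41 and 0.45 figures quoted above; the field names `a` and `b` are placeholders.

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    spread_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    spread_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (spread_x * spread_y)


def correlation_gap(train, synth, field_a, field_b):
    """Return (training r, synthetic r, absolute difference) for one field pair."""
    r_train = pearson([r[field_a] for r in train], [r[field_b] for r in train])
    r_synth = pearson([r[field_a] for r in synth], [r[field_b] for r in synth])
    return r_train, r_synth, abs(r_train - r_synth)


# Made-up readings for two hypothetical fields.
train = [{"a": a, "b": b} for a, b in zip([1, 2, 3, 4, 5], [2, 1, 4, 3, 5])]
synth = [{"a": a, "b": b} for a, b in zip([1, 2, 3, 4, 5], [1, 2, 3, 5, 4])]
r_train, r_synth, gap = correlation_gap(train, synth, "a", "b")
```

A small gap for every field pair is the signal to look for. Identical rows would also give a zero gap, which is why this check is read alongside the privacy statistics rather than on its own.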
This is the second one of my favorite graphs to look at: principal component analysis, a favorite in the data science toolkit. This really helps you understand whether the model overfit on a few features in the data, or whether it was able to recreate the same type of dimensionality as the original data. We're looking for the distributions and shape between these two different features to be very similar. Looking at that, even for such a small data set as the one we're working with here, with only a little over a thousand rows, we see it learned it pretty well. We see the distribution shape being just about the same. So it's another kind of sense of confidence that it learned the distribution fairly well.
Finally, when we're looking at per field or per column distributions, we've got the field distribution plots. These show you, for every value in the synthetic data set we created versus the data it was trained on, how closely we are matching the distributions. You don't want this to match exactly. What you're looking for are patterns that are close, but not the same. You see a nice kind of [inaudible 00:28:50] look on some of these distributions, which is very good. So we see things being close, but not quite the same. The intent of this report is really to give you a sense of confidence that yes, this synthetic model learned the nuances of my data and is able to recreate it. We tasked the model with creating enough records, a total of 2,000 here, to augment our original training set so that we would have an even distribution between male and female records.
So here's our augmented data set, synthetic plus training data. Taking a look at it, it's very close, within one, so that was a math error on my part. The next thing we're going to do is run a set of data classification tasks on it. In this case, we're comparing real world data versus our augmented synthetic data set, and we're comparing the accuracy for a set of downstream classifiers: does a particular patient in the test data set that we isolated at the very beginning have heart disease or not? Really encouraging results, as you'll see as we go through here. We took five of the more popular classifiers that are out there, random forest, decision trees, XGBoost, SVM, and we ran the classification task, as you can see here, for each one of these. On the left, you see real world data. The accuracy for predicting heart disease in this case with random forest was 95% on this data set.
So pretty good accuracy. The really cool and inspiring thing here is that we actually see an increase in accuracy using our synthetic data set on four out of five use cases. With the support vector machine, for some reason, the accuracy from our synthetic model was slightly lower. But in many of these cases, using this boosted data set with an equal number of male and female patient records, we're able to achieve accuracy that is just as good as, if not better than, what we got with the real world data.
Going on a little bit more. Now we're going to take a look at the distributions, not just at accuracy, but let's take a look at fairness. So essentially we're going to break down some of those minority classes. In this case, we're going to take a look at the sex attribute and see for males and for females across synthetic data and real world data, which one performed the best. We'll jump back over to the slide so we can see this a little bit better.
So we'll jump back into the slides here and take a look at the results from our notebook. So pretty exciting results. As we saw here just a second ago, we see improvements in four out of five classifiers working with the augmented data set versus the real world data and pretty close here. So about a 1.3% improvement in accuracy, which is pretty substantial coming from an 88.97% start. So pretty exciting. Looking at the class distribution here. And this is where you start looking at fairness. How well did the algorithm work for different, in this case, we're looking at the sex attribute inside of the data set, which was defined as male and female. How well do we perform here? So in the light colored blue on the left here, we see real world male performance, heart disease using the real world data set versus the augmented data set.
So we see, for example, with random forest, a 1% improvement here, and we see random forest female heart disease detection staying very consistent. However, across the rest of the algorithms, you see sometimes really substantial improvements. SVM is up nearly 15% in heart disease detection for females. With XGBoost, which I think is widely regarded as a really excellent classification algorithm, the accuracy went from 96.63% to 100% for female heart disease detection using this augmented synthetic data set. So really encouraging results across the board for females. As you can see here at the top, we had a 4.49% improvement in heart disease detection for females and a 1.25% improvement for males, resulting in an average 1.3% improvement across the entire data set. So really exciting, and something that many of the customers we're working with right now are working on. How do we build better data sets? How do we automate this process and create the most fair and inclusive responses that we can?
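Breaking accuracy down by group, as in the fairness comparison above, is a small computation once predictions are available. The sketch below assumes predictions already exist. The records, labels, and predictions are made up, and `accuracy_by_group` is a helper written for this example.

```python
def accuracy_by_group(records, group_field, label_field, predictions):
    """Return classification accuracy separately for each value of group_field."""
    totals, correct = {}, {}
    for record, predicted in zip(records, predictions):
        group = record[group_field]
        totals[group] = totals.get(group, 0) + 1
        correct[group] = correct.get(group, 0) + (predicted == record[label_field])
    return {group: correct[group] / totals[group] for group in totals}


# Hypothetical test records (1 = heart disease) and classifier predictions.
test_records = [
    {"sex": "F", "label": 1},
    {"sex": "F", "label": 0},
    {"sex": "M", "label": 1},
    {"sex": "M", "label": 1},
]
predictions = [1, 1, 1, 1]
per_group = accuracy_by_group(test_records, "sex", "label", predictions)
```

Running the same function once for the model trained on real data and once for the model trained on augmented data gives the side-by-side per-group comparison described above.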
If you'd like to run through any of this stuff yourself, feel free. We've got tons of open source examples. It's free to use Gretel. We have a developer tier, so you've got several hours of compute that you can use for free, or you can deploy Gretel and run it inside your own environment as well. Here's a link to our YouTube channel. I will drop links to our notebooks and our docs below, but docs.gretel.ai works as well. If you have any questions and would like to follow up on your use case or have a discussion, either join us on our community Slack, where many of us spend a lot of our time, which you can get to at Gretel.ai/slackinvite, or reach out to us at firstname.lastname@example.org. Thank you very much for your time. Bye.