Augmenting ML Datasets with Gretel and Vertex AI
Here at Gretel, we are thrilled to officially be a technology partner with Google Cloud Platform (GCP). For the last few years, Gretel has been unblocking Machine Learning (ML) operations by enabling the use of synthetic data to augment or replace ML training data. With this partnership, we're even more excited to bring the power of Gretel to Google’s Vertex AI customers to accelerate MLOps.
In this blog post, we'll show you how to use Gretel to create high-quality synthetic tabular data and then use it as training data for a classification model in Vertex AI.
Overview
Vertex AI is an extremely powerful platform for automating MLOps. Having used the product while building this integration, we found it easy to see how Vertex automates away the complexity of model training with AutoML, and of model serving with both real-time predictions and asynchronous batch predictions.
Regardless of the MLOps platform you use, getting started often involves friction because the training data is limited or unbalanced. This is where Gretel plugs in! In this example, we'll use the popular UCI heart disease dataset to build a binary classifier with Vertex AI, and use Gretel both to increase the number of training samples and to balance the number of male and female records, reducing sex-based bias in the dataset.
Vertex AI provides a wide variety of notebook tutorials. For this walkthrough, we've created our own version of the AutoML tabular training and prediction tutorial.
Getting started
Before working with the Notebook, complete the following steps to set up GCP and Gretel:
For GCP:
- Select or create a GCP Project
- Enable billing for your GCP Project
- Enable the Vertex AI and Compute Engine APIs for your GCP Project
For Gretel:
- Create a Gretel account and generate an API key. You will need your Gretel API key when running the Notebook. All Gretel accounts come with free credits, so this tutorial should be runnable within Gretel's free tier limits!
Diving in
With the steps above completed, let's run some code! We suggest using Google Colab or Vertex Workbench, since both make it easy to install the Google SDKs and authenticate with Google's APIs.
Make a copy of the notebook and execute the first few cells to install dependencies and authenticate with Google's APIs.
Next, you'll see a cell that looks like this:
This does two things:
- Saves your GCP Project ID to a variable which we'll use to configure the Google SDK.
- Selects a name for the Gretel Project that will be used. You can leave the default value here, as the Gretel SDK automatically adds a slug to the Gretel Project name to make it unique.
The next couple of cells authenticate with Google and create (or access) the GCS bucket that will eventually hold the synthetic training data we'll be making. You can leave the bucket name as-is; the code generates a unique GCS bucket name for you.
NOTE: Make sure to set the region variable to the region you need to use. A default is already provided.
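The setup cells themselves aren't reproduced in this post. As a rough sketch under our own assumptions, the configuration amounts to something like the following; the placeholder values and the `make_bucket_uri` helper are hypothetical, not the Notebook's exact code. Appending a timestamp is one simple way to get a bucket name that is unlikely to collide with an existing bucket.

```python
import re
import time

# Hypothetical placeholder values; replace them with your own.
PROJECT_ID = "my-gcp-project"            # your GCP Project ID
GRETEL_PROJECT = "vertex-heart-disease"  # Gretel adds a slug to make this unique
REGION = "us-central1"                   # change to the region you need


def make_bucket_uri(project_id: str, timestamp: str) -> str:
    """Build a gs:// URI for a (hypothetical) timestamped bucket name.

    GCS bucket names are limited to lowercase letters, digits, hyphens,
    underscores, and dots, must start and end with a letter or digit, and
    are at most 63 characters long (without dots), so we normalize accordingly.
    """
    name = f"{project_id}-gretel-{timestamp}".lower()
    name = re.sub(r"[^a-z0-9._-]", "-", name)[:63].strip("-._")
    return f"gs://{name}"


TIMESTAMP = time.strftime("%Y%m%d%H%M%S")
BUCKET_URI = make_bucket_uri(PROJECT_ID, TIMESTAMP)
```

For example, `make_bucket_uri("My-Project", "20240101")` yields `gs://my-project-gretel-20240101`.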
Your bucket and region configuration should end up looking like this:
There shouldn't be any files in this newly created bucket, so the last cell shouldn't have any output.
Next, we'll take a peek at the training data:
There are 717 records, with a large imbalance between female and male rows. Vertex AI requires a minimum of 1,000 training samples, and based on previous exploration of this dataset, we also want to balance the number of records for males and females. We'll use Gretel to take care of both constraints!
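To reproduce this kind of quick check yourself, here is a minimal, standard-library sketch that counts rows per "sex" value and reports how far a dataset falls short of the 1,000-row minimum. The function names and the tiny inline CSV in the tests are hypothetical; the Notebook itself works with pandas DataFrames.

```python
import csv
import io
from collections import Counter

VERTEX_MIN_ROWS = 1000  # minimum number of training rows cited in this post


def profile_by_field(csv_text: str, field: str = "sex") -> tuple[int, Counter]:
    """Return the total row count and the per-value counts for one column.

    Values are compared as the strings read from the CSV (the UCI data
    encodes "sex" numerically, so expect keys like "0" and "1").
    """
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    return len(rows), Counter(row[field] for row in rows)


def rows_short_of_minimum(total: int, minimum: int = VERTEX_MIN_ROWS) -> int:
    """Number of additional rows needed to reach the minimum (0 if none)."""
    return max(0, minimum - total)
```

For this dataset, `rows_short_of_minimum(717)` returns 283, so at least 283 more rows are needed even before balancing.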
Now that we've explored the dataset, let's use Gretel to train and generate new data!
We authenticate with Gretel and create a Project:
If you follow the URL to the Gretel Console printed at the bottom of the cell output, you should see an empty Gretel Project. Don't worry, you'll have a synthetic data model there soon.
In the next large code cell, we'll begin training a Gretel Model on the training data. For this dataset, we've chosen Gretel's Long Short-Term Memory (LSTM) model with record conditioning. We use the Gretel SDK to load our default LSTM configuration as a Python dictionary and make a couple of modifications:
- Disable privacy filtering because we are augmenting our original training data with additional synthetic records, and for downstream ML use cases we recommend testing with filtering disabled.
- Configure the LSTM to use field conditioning, which requires specifying the names of the fields to condition on. For this use case, we configure conditioning on the "sex" field.
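Both modifications amount to a couple of dictionary edits. The sketch below applies them to a trimmed-down, hypothetical stand-in for the default LSTM config; the key names (`privacy_filters`, `task`, and the `seed` task type) reflect our understanding of Gretel's config layout and should be checked against the dictionary the SDK actually loads.

```python
import copy

# Hypothetical, trimmed-down stand-in for the default LSTM config dictionary.
# The real config loaded by the Gretel SDK contains many more keys.
BASE_CONFIG = {
    "schema_version": "1.0",
    "models": [
        {
            "synthetics": {
                "data_source": "__tmp__",
                "params": {"epochs": 100},
                "privacy_filters": {"outliers": "medium", "similarity": "medium"},
            }
        }
    ],
}


def prepare_config(base: dict, seed_fields: list[str]) -> dict:
    """Return a copy of the config with privacy filters off and conditioning on."""
    config = copy.deepcopy(base)  # leave the caller's dictionary untouched
    synth = config["models"][0]["synthetics"]
    # 1) Disable privacy filtering for this ML-augmentation use case.
    synth["privacy_filters"] = {"outliers": None, "similarity": None}
    # 2) Condition generation on the given fields (assumed "seed" task layout).
    synth["task"] = {"type": "seed", "attrs": {"fields": list(seed_fields)}}
    return config
```

Calling `prepare_config(BASE_CONFIG, ["sex"])` returns a new config conditioned on "sex" while leaving `BASE_CONFIG` unchanged.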
Run the entire cell; the code submits the training job and waits for the Gretel Model to finish training. If you revisit your Gretel Project page, you should see the Model running:
For this particular training data, the model should take 7-10 minutes to train. Good time for a coffee or tea break!
When the Gretel Model finishes training, the Notebook code block will display the Model ID like this:
If you need to access this model again later (for example, to generate more data), you can load it back in like so:
Now that our Gretel Model is trained, we can check its Synthetic Data Quality Score (SQS). This cell prints the score and downloads the full report:
If you refresh the Files listing in the left pane, you'll see "report.html". Download it and open it in your web browser for full details on the quality and usability of the synthetic data. It will look similar to this:
Next, we'll use our trained Gretel Model to generate records so that the final training dataset, a combination of the original training data and our synthetic records, contains an equal number of male and female records.
We do a basic calculation to determine how many records of each "sex" value to generate, and create a DataFrame with exactly one column, "sex", containing one row per record to generate. When we submit this conditioning data to the Gretel Model, it generates the rest of each record from its "sex" value.
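One calculation consistent with the numbers used in this post is to double the majority class count and top up every "sex" value to that target. With 717 original records and 1,976 evenly split records at the end, the original split works out to 223 female and 494 male records (inferred from the post's totals, not read from the Notebook): the target is 2 × 494 = 988, so we generate 988 - 223 = 765 female and 988 - 494 = 494 male records, and 717 + 765 + 494 = 1,976. The helper names below are hypothetical, and the doubling rule is an assumption rather than the Notebook's confirmed formula.

```python
def records_to_generate(counts: dict[str, int]) -> dict[str, int]:
    """Per-value generation counts so every value ends at twice the majority count.

    Doubling the majority class is an assumed rule that reproduces this
    post's numbers; any rule that equalizes the final counts would work.
    """
    target = 2 * max(counts.values())
    return {value: target - n for value, n in counts.items()}


def build_seed_rows(to_generate: dict[str, int], field: str = "sex") -> list[dict[str, str]]:
    """One single-column conditioning row per record the model should generate."""
    return [{field: value} for value, n in to_generate.items() for _ in range(n)]
```

Passing the output of `build_seed_rows` to `pandas.DataFrame` yields the single-column "sex" DataFrame described above.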
Here we'll prompt our model to generate 765 female records and 494 male records. Running the next cell triggers our Gretel Model to generate the new data. If you click into the model in the Gretel Console, you'll see this data generation job move to an "active" state and finally to "completed":
The data generation job should take only a few minutes. Once it completes, we download our newly generated synthetic data and combine it with the original training data:
This gives us a total of 1,976 records for training, with an even split of male and female records. We also upload this augmented training data to our GCS bucket; the GCS_SYN_SOURCE variable now points directly to it, and that's the data we'll use for Vertex AutoML training.
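Conceptually, combining the data means appending the synthetic rows to the original ones and writing a single CSV. The sketch below does that with the standard library, assuming both sides share the same columns; the function name is hypothetical, and the upload step is left out (with the `google-cloud-storage` client, for example, `Blob.upload_from_string` can write the resulting text to the bucket).

```python
import csv
import io


def combine_and_serialize(original: list[dict], synthetic: list[dict]) -> str:
    """Append synthetic rows to the original rows and return CSV text.

    Assumes both lists share the same columns: the model generates full
    records from each "sex" seed, so its output matches the training schema.
    """
    rows = original + synthetic
    fieldnames = list(rows[0].keys())  # column order taken from the first row
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()
```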
Run the next few cells, which create a Vertex Dataset, set up the Vertex model training job, and start AutoML training. When you execute the training cell, expect training to take about two hours:
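The dataset and training cells boil down to a few calls in the Vertex AI SDK for Python (`google-cloud-aiplatform`). This is a minimal sketch, not the Notebook's exact code: the display names and the `target` label column are hypothetical, and the 80/10/10 split and one-node-hour budget are illustrative choices. The SDK is imported inside the function so the split helper can be used without it installed.

```python
def training_split(train: float = 0.8, val: float = 0.1, test: float = 0.1) -> dict[str, float]:
    """Fraction-split keyword arguments for AutoML; the fractions must sum to 1."""
    if abs(train + val + test - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1.0")
    return {
        "training_fraction_split": train,
        "validation_fraction_split": val,
        "test_fraction_split": test,
    }


def run_automl_training(project_id: str, region: str, bucket_uri: str, gcs_source: str,
                        target_column: str = "target",  # hypothetical label column
                        budget_milli_node_hours: int = 1000):
    # Lazy import: requires the google-cloud-aiplatform package.
    from google.cloud import aiplatform

    aiplatform.init(project=project_id, location=region, staging_bucket=bucket_uri)
    dataset = aiplatform.TabularDataset.create(
        display_name="heart-disease-gretel-augmented",  # hypothetical name
        gcs_source=[gcs_source],
    )
    job = aiplatform.AutoMLTabularTrainingJob(
        display_name="heart-disease-automl",  # hypothetical name
        optimization_prediction_type="classification",
    )
    # Blocks until training finishes; 1000 milli node hours = 1 node hour.
    return job.run(
        dataset=dataset,
        target_column=target_column,
        budget_milli_node_hours=budget_milli_node_hours,
        **training_split(),
    )
```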
If you open the Vertex AI console in your GCP project and click "Training" in the left sidebar, you'll see your Vertex AutoML model being trained:
When this model is complete, you can explore its evaluation metrics by clicking the model name in the table above. Here we show the combined metrics for both target labels:
In the above view, click on the Version Detail tab and copy the Model ID:
Back in our Notebook we'll reload this Model:
Now run the next two cells, which deploy this model to a real-time HTTP endpoint for online predictions. Deploying the endpoint takes a few minutes; once it's deployed, you can make real-time predictions like so:
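In SDK terms, reloading, deploying, and predicting look roughly like the sketch below. The helper names, the `n1-standard-4` machine type, and the string-valued instances (a convention borrowed from Google's AutoML tabular tutorials) are assumptions; the response is assumed to carry parallel `classes` and `scores` lists, as AutoML tabular classification predictions do. Leave the label column out of the record you send.

```python
def to_instance(record: dict) -> dict[str, str]:
    """Convert one feature record into a prediction instance (values as strings)."""
    return {key: str(value) for key, value in record.items()}


def top_class(prediction: dict) -> tuple[str, float]:
    """Return (label, score) for the highest-scoring class in one prediction."""
    score, label = max(zip(prediction["scores"], prediction["classes"]))
    return label, score


def deploy_and_predict(model_id: str, record: dict, machine_type: str = "n1-standard-4"):
    # Lazy import; assumes aiplatform.init(project=..., location=...) ran earlier.
    from google.cloud import aiplatform

    model = aiplatform.Model(model_name=model_id)       # reload by Model ID
    endpoint = model.deploy(machine_type=machine_type)  # takes a few minutes
    response = endpoint.predict(instances=[to_instance(record)])
    return top_class(response.predictions[0])
```

Remember to undeploy the model and delete the endpoint when you're done, since a deployed endpoint keeps incurring charges.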
That's it! You've successfully used Gretel to augment an ML training dataset for use in Vertex AI. At this point, you can run the rest of the Notebook to tear down the resources you created throughout the tutorial, if you no longer need them.
Next steps: Deployment options
This tutorial uses Gretel Cloud because it requires zero infrastructure setup and lets users get familiar with Gretel's capabilities very quickly. Gretel offers several deployment options for the Data Plane: the infrastructure components that consume training data and produce synthetic data models. When using Gretel Cloud, Gretel manages a backend Data Plane that requires no setup from customers.
We also offer the ability for customers to deploy their own Data Plane, which we call a Gretel Hybrid Deployment. If you're an existing GCP customer, you can deploy the Gretel Data Plane with the following GCP services:
- Source Bucket: A GCS bucket that will hold training data for both Gretel and Vertex (which you may have already created in this tutorial). You can easily export data from BigQuery to GCS and then use that exported data to train models and generate synthetic data with Gretel.
- Sink Bucket: A GCS bucket that will store Gretel Model artifacts such as the SQS report and the generated synthetic data. This Sink Bucket holds all of the data usually stored in Gretel Cloud when using our managed Data Plane.
- Google Kubernetes Engine (GKE): GKE is used to launch Gretel worker containers that will consume training data from the Source Bucket and write out all relevant artifacts to the Sink Bucket. You can then harvest the data from the Sink Bucket and use that within Vertex.
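As an example of the Source Bucket workflow, exporting a BigQuery table to CSV files in GCS takes one extract job with the `google-cloud-bigquery` client. The sketch below is a minimal version under our own assumptions: the table ID, bucket, and prefix are placeholders, and the wildcard in the destination lets BigQuery shard exports larger than 1 GB across multiple files.

```python
def export_destination(bucket_uri: str, prefix: str) -> str:
    """Wildcard GCS URI so BigQuery can split a large export into several files."""
    return f"{bucket_uri.rstrip('/')}/{prefix.strip('/')}/part-*.csv"


def export_table_to_gcs(table_id: str, bucket_uri: str, prefix: str = "training-data") -> str:
    # Lazy import: requires the google-cloud-bigquery package.
    from google.cloud import bigquery

    client = bigquery.Client()
    destination = export_destination(bucket_uri, prefix)
    job_config = bigquery.ExtractJobConfig(destination_format="CSV")
    job = client.extract_table(table_id, destination, job_config=job_config)
    job.result()  # wait for the export job to finish
    return destination
```

For example, `export_table_to_gcs("my-project.my_dataset.heart", "gs://my-source-bucket")` (hypothetical IDs) writes CSV shards under `gs://my-source-bucket/training-data/`.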
In this Hybrid Deployment, none of your data is sent to Gretel Cloud. This allows all data processing to happen in your own GCP account by way of GKE.
In a future post, we'll do a deep dive on configuring GKE for use with Vertex. Gretel Hybrid Deployments are currently available as a private beta. If you're interested, please contact us and we'll be happy to work with you!