Transforms and Multi-Table Relational Databases

How to de-identify a relational database for demo or pre-production testing environments while keeping the referential integrity of primary and foreign keys intact.
Copyright © 2022 Gretel.ai
Gretel Relational is no longer supported as of October 7, 2024

Intro

The ability to share private, de-identified data is a rapidly growing need. Oftentimes, the original non-private data resides in a multi-table relational database. This blog will walk you through how to de-identify a relational database for demo or pre-production testing environments while keeping the referential integrity of primary and foreign keys intact.

You can follow along with our Gretel Transform notebook:

Our Database

The relational database we'll be using is the mock ecommerce database shown below. The lines between tables represent primary-foreign key relationships. Primary and foreign key relationships are used in relational databases to define many-to-one relationships between tables. For referential integrity to hold, a table with a foreign key that references a primary key in another table must be able to be joined with that other table. Below, we demonstrate that referential integrity exists both before and after the data is de-identified.

Example of an ecommerce relational database.

Gathering Data Directly From a Database

After installing the necessary modules and inputting your Gretel API key, we first grab our mock database from S3, and then create an engine using SQLAlchemy:

from sqlalchemy import create_engine

!wget https://gretel-blueprints-pub.s3.amazonaws.com/rdb/ecom.db
    
engine = create_engine("sqlite:///ecom.db")

This notebook can be run against any database SQLAlchemy supports, such as PostgreSQL or MySQL. For example, if you have a PostgreSQL database, simply swap the `sqlite:///` connection string above for a `postgresql://` one in the `create_engine` call, as sketched below.
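
A minimal sketch of what that might look like; the user, password, host, and database name are placeholders:

from sqlalchemy import create_engine

# Hypothetical PostgreSQL connection string; replace the credentials, host,
# and database name with your own. Requires a driver such as psycopg2.
engine = create_engine("postgresql://user:password@localhost:5432/ecommerce")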

Next, using SQLAlchemy's reflection extension, we gather the table data.

# Gather the table data from the database

import pandas as pd
from sqlalchemy import MetaData

# This is the directory where we will temporarily store csv files for the transformer model
base_path = "./"

metadata = MetaData()
metadata.reflect(engine)

rdb_config = {}
rdb_config["table_data"] = {}
rdb_config["table_files"] = {}

for name, table in metadata.tables.items():
    df = pd.read_sql_table(name, engine)
    rdb_config["table_data"][name] = df
    filename = name + ".csv"
    df.to_csv(filename, index=False, header=True)
    rdb_config["table_files"][name] = filename

We then crawl the schema and produce a list of relationships by table primary key.

# Extract primary/foreign key relationships

from collections import defaultdict

rels_by_pkey = defaultdict(list)

for name, table in metadata.tables.items():
    for col in table.columns:
        for f_key in col.foreign_keys:
            rels_by_pkey[(f_key.column.table.name, f_key.column.name)].append((name, col.name))

list_of_rels_by_pkey = []

for p_key, f_keys in rels_by_pkey.items():
    list_of_rels_by_pkey.append([p_key] + f_keys)

rdb_config["relationships"] = list_of_rels_by_pkey

Take a Look at the Data

Now let's join the order_items table with the users table using the user_id.

table_to_view1 = "order_items"
table_to_view2 = "users"
df1 = rdb_config["table_data"][table_to_view1]
df2 = rdb_config["table_data"][table_to_view2]

joined_data = df1.join(df2.set_index('id'), how='inner', on='user_id', lsuffix='_order_items', rsuffix='_users')
print("Number of records in order_items table is " + str(len(df1)))
print("Number of records in user table is " + str(len(df2)))
print("Number of records in joined data is " + str(len(joined_data)))

show_fields = ['id', 'user_id', 'inventory_item_id', 'sale_price', 'shipped_at', 'delivered_at', 'first_name', 'last_name', 'age', 'latitude', 'longitude']
joined_data.filter(show_fields).head()

Below is the output. Note how every record in the order_items table matches a distinct record in the users table. A primary goal of this notebook is to show how we can run transforms on the tables in this database and maintain these relationships.

Output of joining order_items and users tables.

Define Our Transform Policies

Now we need to define a transform policy for any table that contains PII or sensitive information that could be used to re-identify a user. We won't include a transform for any of the primary/foreign key combinations, as we'll be handling those separately. Let's take a look at the transform policy for the users table.

schema_version: "1.0"
name: "users_transforms"
models:
  - transforms:
      data_source: "_"
      use_nlp: false
      policies:
        - name: users_transform
          rules:
            - name: fake_names_and_email
              conditions:
                field_label:
                  - person_name
                  - email_address
              transforms:
                - type: fake
            - name: date_shift
              conditions:
                field_name:
                  - created_at
              transforms:
                - type: dateshift
                  attrs:
                    min: -400 
                    max: 65 
                    formats: '%Y-%m-%d %H:%M:%S UTC' 
            - name: numeric_shifts
              conditions:
                field_name:
                  - age
                  - latitude
                  - longitude
              transforms:
                - type: numbershift
                  attrs:
                    min: 10
                    max: 10

Within the "rules" section, we define each type of transformation we want, each one beginning with "- name". We start by replacing any field classified as a person’s name or email address with a fake version. Note, we chose to leave several of the location fields as is, such as "state" and "country,'' since it's public knowledge that this database is about user ecommerce transactions in Arizona. We then transform the "created_at" timestamp using a random date shift. And finally, we transform the numeric fields of age, latitude and longitude with a random numeric shift. Note, we did not transform "id" because it is a primary key that matches a foreign key. We will have special processing for primary and foreign keys later that ensures referential integrity is maintained. 

Each policy should reside in its own yaml file and the locations for each are made known to the notebook as follows:

policy_dir = "https://gretel-blueprints-pub.s3.amazonaws.com/rdb/"

transform_policies = {}
transform_policies["users"] = "users_policy.yaml"
transform_policies["order_items"] = "order_items_policy.yaml"
transform_policies["events"] = "events_policy.yaml"
transform_policies["inventory_items"] =  None  
transform_policies["products"] = None
transform_policies["distribution_center"] = None

Model Training and Initial Data Generation

We first define some handy functions for training models and generating data using the policies we defined above.

import yaml
import numpy as np
from smart_open import open
from sklearn import preprocessing
from gretel_client import create_project
from gretel_client.helpers import poll

def create_model(table:str, project):

    # Read in the transform policy
    policy_file = transform_policies[table]
    policy_file_path = policy_dir + policy_file
    with open(policy_file_path, "r") as yaml_file:
        policy = yaml_file.read()

    # Get the dataset_file_path
    dataset_file = rdb_config["table_files"][table]
    dataset_file_path = base_path + dataset_file

    # Create the transform model
    model = project.create_model_obj(model_config=yaml.safe_load(policy))
    
    # Upload the training data.  Train the model.
    model.data_source = dataset_file_path
    model.submit(upload_data_source=True)
    print("Model training started for " + table)
    
    return model
 
def generate_data(table:str, model):
    
    record_handler = model.create_record_handler_obj()
    
    # Get the dataset_file_path
    dataset_file = rdb_config["table_files"][table]
    dataset_file_path = base_path + dataset_file

    # Submit the generation job
    record_handler.submit(
        action = "transform",
        data_source = dataset_file_path,
        upload_data_source = True
        )
    
    print("Generation started for " + table)
    
    return record_handler    
        

Now that we have these functions defined, we can easily run all the training and generation in parallel in the Gretel Cloud. You can find the details of how to monitor this process in the notebook code here. The key API call for checking a model status is as follows:

model._poll_job_endpoint()
status = model.__dict__['_data']['model']['status']

A model's status begins as "created", then moves to "pending" (meaning it's waiting for a worker to pick it up). Once a worker picks it up, the status becomes "active". When the job completes, the status becomes "completed"; if there was an error at any point along the way, it becomes "error". Similarly, the key API call for checking generation status (which uses the same set of status values) is:

rh._poll_job_endpoint()
status = rh.__dict__['_data']['handler']['status']

Note "model" is returned by the above "create_model" function and "rh" (record handler) is returned by the above "generate_data" function.

Transforming Primary/Foreign Key Relationships

To ensure referential integrity for each primary key/foreign key table set, we first fit a scikit-learn LabelEncoder on the combined set of unique key values across the tables in the set. We then run that encoder on the key field in each table. This de-identifies the keys while preserving referential integrity: a table with a foreign key that references a primary key in another table can still be joined with that table. The key-transformation code below operates on `transformed_tables`, a dictionary mapping each table name to its DataFrame after the transform step.
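
How `transformed_tables` gets populated isn't shown here; a rough sketch, using the `record_handlers` dictionary from the monitoring sketch above (the artifact-retrieval call follows the common Gretel client pattern and may differ from the notebook's exact code):

# Illustrative sketch: assemble transformed_tables from the generation output.
transformed_tables = {}
for table in rdb_config["table_data"]:
    if transform_policies.get(table) is not None:
        rh = record_handlers[table]
        # Fetch the transformed records produced by the record handler (gzipped CSV)
        transformed_tables[table] = pd.read_csv(
            rh.get_artifact_link("data"), compression="gzip"
        )
    else:
        # Tables without a policy are carried over unchanged
        transformed_tables[table] = rdb_config["table_data"][table].copy()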

def transform_keys(key_set):
    
    # Get array of unique values from each table, can use dfs in transformed_tables    
    field_values = set()
    for table_field_pair in key_set:
        table, field = table_field_pair
        field_values = field_values.union(set(transformed_tables[table][field]))
        
    # Train a label encoder
    field_values_list = list(field_values)
    le = preprocessing.LabelEncoder()
    le.fit(field_values_list)
    
    # Run the label encoder on dfs in transformed_tables
    for table_field_pair in key_set:
        table, field = table_field_pair
        transformed_tables[table][field] = le.transform(transformed_tables[table][field]) 

# Run our transform_keys function on each key set
for key_set in rdb_config["relationships"]:
    transform_keys(key_set)

Take a Look at the Final Data

We will now show the same join on the order_items and users table that we did on the original data, but now on the transformed data.
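
The join itself is identical to the earlier one, just pointed at the transformed DataFrames; a sketch:

# Repeat the earlier join, this time on the transformed tables
df1_xf = transformed_tables["order_items"]
df2_xf = transformed_tables["users"]

joined_xf = df1_xf.join(df2_xf.set_index('id'), how='inner', on='user_id',
                        lsuffix='_order_items', rsuffix='_users')
print("Number of records in order_items table is " + str(len(df1_xf)))
print("Number of records in joined data is " + str(len(joined_xf)))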

Joined tables with new transformed data.

Once again, each record in the order_items table matches a distinct record in the users table.

Load Final Data Back into Database

To wrap things up, the last step is to load the final transformed data back into a copy of the database.

!cp ecom.db ecom_xf.db
engine_xf = create_engine("sqlite:///ecom_xf.db")

for table in transformed_tables:
    transformed_tables[table].to_sql(table, con=engine_xf, if_exists='replace', index=False)
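
As a quick check, you can read a table back out of the new database and confirm it contains the transformed values:

# Read one table back from the de-identified copy of the database
pd.read_sql_table("users", engine_xf).head()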

Conclusion

We've shown how easy it is to combine direct access to a relational database with Gretel's Transform API. We’ve also demonstrated how large multi-table databases can be processed in parallel in the Gretel Cloud. And finally, we've demonstrated a technique for ensuring the referential integrity of all primary/foreign key relationships. Coming soon, we'll show you how to accomplish all of this using Gretel Synthetics. 

Thank you for reading! Please reach out to me at amy@gretel.ai if you have any questions.