Comprehensive Data Cleaning for AI and ML

Learn to prepare tabular data for AI and ML with an end-to-end data cleaning workflow.

Data cleaning is a fundamental first step in any ML or AI project. You’ll often hear the phrase “garbage in, garbage out”: if the data you feed in is messy, your project’s results are likely to be messy too. In this blog, we step you through some approaches to data cleaning with a particular eye towards AI/ML. Please feel free to follow along with our Jupyter notebook here.

The issues we address are as follows; the order is chosen so that earlier steps benefit the analysis in later steps.

  • Standardize empty values
  • Remove duplicate records
  • Drop columns with mostly missing data
  • Impute missing values
  • Remove redundant fields
  • Cap high float precision
  • Remove constant fields
  • Address field-level outliers
  • Handle record-level outliers

Our dataset

The dataset we use is the popular Adult Census Income dataset, modified to include an example of each of the problems listed above. We can read in the dataset as follows:

import pandas as pd
data = pd.read_csv("https://gretel-blueprints-pub.s3.us-west-2.amazonaws.com/data_cleaning/messy_data.csv")
data.head()
age workclass fnlwgt education education.num marital.status occupation relationship race sex capital.gain capital.loss hours.per.week
0 90.0 ? 77053.0 HS-grad 9.0 Widowed ? Not-in-family White Female 0.0 4356.0 40.0
1 82.0 Private 132870.0 HS-grad 9.0 Widowed Exec-managerial Not-in-family White Female 0.0 4356.0 18.0
2 66.0 ? 186061.0 Some-college 10.0 Widowed ? Unmarried Black Female 0.0 4356.0 40.0
3 54.0 Private 140359.0 7th-8th 4.0 Divorced Machine-op-inspct Unmarried White Female 0.0 3900.0 40.0
4 41.0 Private 264663.0 Some-college 10.0 Separated Prof-specialty Own-child White Female 0.0 3900.0 40.0

Standardize empty values

Missing values can be represented by a variety of placeholder values such as "?", "Missing", "N/A", or "Not applicable". Standardize all of these to `np.nan`, which simplifies missing value imputation later. You can see in our dataset that "?" is often used to represent missing data.

import numpy as np
data['workclass'] = np.where(data['workclass']=='?',np.nan, data['workclass']) 
data['occupation'] = np.where(data['occupation']=='?',np.nan, data['occupation']) 
data.head()
age workclass fnlwgt education education.num marital.status occupation relationship race sex capital.gain capital.loss hours.per.week
0 90.0 NaN 77053.0 HS-grad 9.0 Widowed NaN Not-in-family White Female 0.0 4356.0 40.0
1 82.0 Private 132870.0 HS-grad 9.0 Widowed Exec-managerial Not-in-family White Female 0.0 4356.0 18.0
2 66.0 NaN 186061.0 Some-college 10.0 Widowed NaN Unmarried Black Female 0.0 4356.0 40.0
3 54.0 Private 140359.0 7th-8th 4.0 Divorced Machine-op-inspct Unmarried White Female 0.0 3900.0 40.0
4 41.0 Private 264663.0 Some-college 10.0 Separated Prof-specialty Own-child White Female 0.0 3900.0 40.0
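If a dataset uses several different placeholder strings across many columns, you can standardize them all in one pass. Here's a minimal sketch (the placeholder list is illustrative; adjust it to match your data):

# Replace every placeholder string with np.nan across the whole dataframe
placeholders = ["?", "Missing", "N/A", "Not applicable"]
data = data.replace(placeholders, np.nan)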

Remove duplicate records

First, let's look at the duplicated rows. By default, `duplicated()` flags every occurrence after the first, so these are the extra copies. For brevity's sake, we’ll just show the first five.

duplicate = data[data.duplicated()]
print("Duplicate Rows:")
duplicate.head()
age workclass fnlwgt education education.num marital.status occupation relationship race sex capital.gain capital.loss hours.per.week
8453 25.0 Private 308144.0 Bachelors 13.0 Never-married Craft-repair Not-in-family White Male 0.0 0.0 40.0
8645 90.0 Private 52386.0 Some-college 10.0 Never-married Other-service Not-in-family Asian-Pac-Islander Male 0.0 0.0 35.0
12202 21.0 Private 250051.0 Some-college 10.0 Never-married Prof-specialty Own-child White Female 0.0 0.0 10.0
14346 20.0 Private 107658.0 Some-college 10.0 Never-married Tech-support Not-in-family White Female 0.0 0.0 10.0
15603 25.0 Private 195994.0 1st-4th 2.0 Never-married Priv-house-serv Not-in-family White Female 0.0 0.0 40.0
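If you'd rather inspect every copy of each duplicated row (the listing above only shows the later occurrences), `duplicated(keep=False)` flags them all. A quick sketch:

# Flag every occurrence of a duplicated row and sort so the copies sit together
all_dupes = data[data.duplicated(keep=False)].sort_values(list(data.columns))
all_dupes.head()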

Your job now is to study the duplicated records and decide whether they are errors. If they are, you can remove them with the following command:

data = data.drop_duplicates().reset_index(drop=True)

Drop columns with mostly missing data

In our experience, you're better off dropping columns where 60% or more of the data is missing. When that much data is missing, the field carries more noise than signal. Below, we show the fraction of missing values for each field.

df_len = len(data)
data.isnull().sum()/df_len
age 0.000399
workclass 0.056823
fnlwgt 0.000399
education 0.000399
education.num 0.003471
marital.status 0.000399
occupation 0.057069
relationship 0.000399
race 0.000399
sex 0.000399
capital.gain 0.000399
capital.loss 0.000399
hours.per.week 0.000399
native.country 0.000399
income 0.000399
dummy_col1 0.999601
dummy_col2 0.000399
dummy_col3 0.000399

Note that only `dummy_col1` has more than 60% of its data missing, so we drop that column.

column = "dummy_col1"
data = data.drop([column], axis=1)
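More generally, you can drop every column that crosses the threshold without naming columns by hand. A small sketch:

# Drop all columns missing more than 60% of their values
threshold = 0.6
mostly_missing = data.columns[data.isnull().mean() > threshold]
data = data.drop(columns=mostly_missing)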

Impute missing values

There are many simple ways to fill in missing data. For example, if a field is numeric, you could fill in missing values with the mean or the median; if it's categorical, you could use the most frequent category. By far the most effective way to fill in missing data, though, is with a machine learning-based imputation approach. KNN imputation is one such algorithm that has seen success, but it requires tuning the parameter k and inherits many of KNN's weaknesses, such as sensitivity to outliers and noise.
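For reference, here's a minimal sketch of those simple strategies using scikit-learn's SimpleImputer, applied to a copy so it doesn't affect the workflow below (the column choices are just examples from our dataset):

from sklearn.impute import SimpleImputer

simple = data.copy()
# Median for a numeric field, most frequent value for a categorical field
simple[["age"]] = SimpleImputer(strategy="median").fit_transform(simple[["age"]])
simple[["workclass"]] = SimpleImputer(strategy="most_frequent").fit_transform(simple[["workclass"]])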

In this blog, we’ll be using MissForest, which is another machine learning-based data imputation algorithm. It uses a random forest trained on the observed values of a data matrix to predict the missing values. There are many benefits of using MissForest. For one, it can be applied to mixed data types, numerical and categorical.

We’ll start by listing the count of missing records per field:

data.isnull().sum()
age 13
workclass 1850
fnlwgt 13
education 13
education.num 113
marital.status 13
occupation 1858
relationship 13
race 13
sex 13
capital.gain 13
capital.loss 13
hours.per.week 13
native.country 13
income 13
dummy_col2 13
dummy_col3 13

Now let’s massage the data into the numeric format MissForest needs:

# Get the list of nominal (categorical) columns
nominal_columns = list(data.select_dtypes(include=["object", "category"]).columns)

# Get the column indices of the nominal columns (MissForest needs these)
cat_ind = [data.columns.get_loc(c) for c in nominal_columns]

# Translate categorical fields to numeric with a label encoder per column
from sklearn import preprocessing
col_le = {}
data_trans = data.copy()
for col in nominal_columns:
    le = preprocessing.LabelEncoder()
    # Encode as strings so NaN (a float) doesn't trip up the encoder
    le.fit(data[col].astype(str))
    data_trans[col] = le.transform(data[col].astype(str))
    col_le[col] = le

# Add back in the NaNs (the encoder mapped them to a real code above)
for col in nominal_columns:
    data_trans[col] = data_trans[col].mask(data[col].isna())

Let’s train and run the MissForest algorithm:

# missingpy expects the old sklearn.neighbors.base module path, so alias it
import sys
import sklearn.neighbors._base
sys.modules['sklearn.neighbors.base'] = sklearn.neighbors._base
from missingpy import MissForest

imputer = MissForest()

cols = data.columns
data = pd.DataFrame(imputer.fit_transform(data_trans, cat_vars=cat_ind))
data.columns = cols

Now we translate the categorical values (which are currently numeric) back into strings:

for col in col_le:
    le = col_le[col]
    data[col] = le.inverse_transform(data[col].astype(int))

Finally, let’s take a look at our new data. Note that there are no longer any missing values.

data.head()
age workclass fnlwgt education education.num marital.status occupation relationship race sex capital.gain capital.loss hours.per.week
0 90.0 Private 77053.0 HS-grad 9.0 Widowed Sales Not-in-family White Female 0.0 4356.0 40.0
1 82.0 Private 132870.0 HS-grad 9.0 Widowed Exec-managerial Not-in-family White Female 0.0 4356.0 18.0
2 66.0 Private 186061.0 Some-college 10.0 Widowed Exec-managerial Unmarried Black Female 0.0 4356.0 40.0
3 54.0 Private 140359.0 7th-8th 4.0 Divorced Machine-op-inspct Unmarried White Female 0.0 3900.0 40.0
4 41.0 Private 264663.0 Some-college 10.0 Separated Prof-specialty Own-child White Female 0.0 3900.0 40.0

Remove redundant fields

When two fields are completely correlated, they're redundant and one of them can be removed. First, here's some code to compute field-by-field correlations.

from gretel_synthetics.utils import stats
import matplotlib.pyplot as plt 
import seaborn as sns 

# Compute field correlations
cm = stats.calculate_correlation(data, nominal_columns=nominal_columns)
cm = cm[cm.columns].astype(float)

# Graph the correlations 
plt.figure(figsize=(10, 6))
sns.heatmap(cm, cmap='RdYlGn')
plt.show()

And here’s the output from running the code:

Figure 1: Field-by-field correlation heatmap

As you can see, `education` and `education.num` are 100% correlated. The best thing to do is remove one of them and then proceed with your AI/ML pipeline. If you really need both fields in the data afterwards (as can be the case when creating synthetic data), create a map between the two fields before you start, and use it at the end to bring the removed field back. Here’s example code removing `education`:

data = data.drop(['education'], axis=1)
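If you do need `education` back later, one approach, sketched here, is to save the `education.num` to `education` mapping before the drop above and merge it back in when you're done:

# Capture the mapping before dropping the column (run this ahead of the drop above)
edu_map = data[["education.num", "education"]].drop_duplicates()

# ... drop education and run the AI/ML pipeline ...

# Afterwards, restore education by joining the saved mapping back in
data = data.merge(edu_map, on="education.num", how="left")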

Cap high float precision

Sometimes float values carry excessively long precision, particularly when they are the result of a mathematical computation. Some models, like Gretel’s LSTM, are sensitive to values with excessive precision and do much better when the floats are rounded. How much precision to keep depends on how much precision you need in the output (as is the case when creating synthetic data). If you can reduce the precision, rounding to two spots to the right of the decimal generally works well. In our dataset, only one column has long floating-point precision: `dummy_col3`.

data["dummy_col3"].head()

Now we drop some of the precision by rounding every float to have two spots to the right of the decimal point, and look at the data again.

data = data.round(decimals = 2)
data["dummy_col3"].head()

Remove constant fields

When a field only ever takes a single value, it has no predictive power and you're better off removing it. This next code snippet will detect any constant fields:

for col in data.columns:
    vc = data[col].value_counts()
    if len(vc) == 1:
        print("Column " + col + " is a constant column of " + str(vc.index[0]))

The above returns the following line:

`Column dummy_col2 is a constant column of 0.0`

Now you can remove that column with the below code snippet:

col_to_drop = "dummy_col2"
data = data.drop([col_to_drop], axis=1)
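A more compact alternative is to detect and drop every constant column in one pass; here's a sketch using `nunique`:

# Find and drop every column that only ever takes a single value
constant_cols = [c for c in data.columns if data[c].nunique(dropna=False) <= 1]
data = data.drop(columns=constant_cols)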

Address field-level outliers

Records that contain outlier values can be very disruptive to an AI/ML pipeline. It's important to look at each outlier found and determine whether it's an outlier because of an error in the data. You can then either fix the error or remove the whole record. Remember: garbage in, garbage out; if there are errors in the training data, there will be errors in the AI/ML analysis. In this section we look for outliers at the field level, and in the next section we'll look for outliers at the record level.

We’ll start by displaying the numeric columns:

for col in data.columns:
    if col not in nominal_columns:
        print(col)

Now we’ll pick the column `age` and graph a box plot. The nice thing about a box plot is that you can see how outliers fit in the context of the minimum, first quartile, median, third quartile, and maximum.

import matplotlib.pyplot as plt
fig1, ax1 = plt.subplots()
outlier_column = "age"
ax1.set_title('Outlier Box Plot for ' + outlier_column)
ax1.boxplot(data[outlier_column])

Figure 2: Outlier box plot for age

You can see that ages above 80 are outliers. In this case, we decide that these are valid records (not errors) and leave the records in. If you had wanted to remove the outliers, the following code would do the trick:

data = data[data["age"] <= 80]

Handle record-level outliers

A record-level outlier is one where all the fields in the record, taken together, form an outlier. A nice routine for measuring record-level outliers is IsolationForest, an unsupervised, decision-tree-based algorithm originally developed for outlier detection in tabular data. You can read more about it here.

We start by defining a set of functions we’ll be using:

import category_encoders as ce
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import IsolationForest

def normalize_dataset(df: pd.DataFrame) -> pd.DataFrame:
    # Divide the dataframe into numeric and categorical
    nominal_columns = list(df.select_dtypes(include=["object", "category"]).columns)
    numeric_columns = []
    for c in df.columns:
        if c not in nominal_columns:
            numeric_columns.append(c)
    df_cat = df.reindex(columns=nominal_columns)
    df_num = df.reindex(columns=numeric_columns)
    # Fill missing values and encode categorical columns by the frequency of each value
    if len(numeric_columns) > 0:
        df_num = df_num.fillna(df_num.median())
    if len(nominal_columns) > 0:
        df_cat = df_cat.fillna("Missing")
        encoder = ce.count.CountEncoder()
        df_cat_labels = pd.DataFrame(encoder.fit_transform(df_cat))
    # Merge numeric and categorical back into one dataframe
    new_df = pd.DataFrame()
    if len(nominal_columns) == 0:
        new_df = df_num
    elif len(numeric_columns) == 0:
        new_df = df_cat_labels
    else:
        new_df = pd.concat([df_num, df_cat_labels], axis=1, sort=False)
    # Finally, standardize all values (column order matches the concat above)
    all_columns = numeric_columns + nominal_columns
    new_df = pd.DataFrame(StandardScaler().fit_transform(new_df), columns=all_columns)
    return new_df

def train_outlier_model(train_norm: pd.DataFrame):
    outlier_model = IsolationForest()
    outlier_model.fit(train_norm)
    return outlier_model

def scale(scores):
    # Flip and shift IsolationForest's decision_function so higher means more anomalous
    new_scores = scores * -1 + .5
    return new_scores

def score_outliers_df(df_norm: pd.DataFrame, outlier_model):
    scores = outlier_model.decision_function(df_norm)  # outlier scores
    scaled_scores = scale(scores)
    # Save the outlier scores in the passed dataframe
    df_norm["outlier_score"] = scaled_scores
    return df_norm

Now we normalize the data, train the outlier model, and score the records:

data_normalize = normalize_dataset(data)
outlier_model = train_outlier_model(data_normalize)
data_norm = score_outliers_df(data_normalize, outlier_model)

Our scoring mechanism has been normalized so that an outlier score of 0.7 or above means the record is almost certainly an outlier. A score between 0.6 and 0.7 means the record is possibly an outlier and should be looked at manually.
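If you want to pull the borderline records out for a manual review, the same kind of filter works; a quick sketch:

# Records scoring between 0.6 and 0.7: possible outliers worth a manual look
maybe_outliers = data.loc[(data_norm['outlier_score'] >= .6) & (data_norm['outlier_score'] < .7)]
maybe_outliers.head()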

Let’s start by looking at the clear outliers (score of 0.7 or above):

outliers = data.loc[data_norm['outlier_score'] >= .7]
outliers
age workclass fnlwgt education education.num marital.status occupation relationship race sex capital.gain capital.loss hours.per.week
32554 6.0 non-worklass 171351.0 elementary 1.0 Widowed student Other-relative Asian-Pac-Islander female 0.0 0.0 0.0

There is only one such record: a dummy error record we inserted, in which the person is a 6-year-old elementary school student earning more than 50K. The following code shows how to remove this outlier:

data = data.loc[data_norm['outlier_score'] < .7]

Conclusion

Data cleaning can often make or break the success of an AI/ML project. Gretel encourages all its users to fully clean their data before initiating a Gretel project. Most of Gretel’s models for creating synthetic data use AI/ML, and cleaning your data increases the odds of creating high-quality synthetic data, which in turn increases the odds of a successful downstream AI/ML project. Please feel free to contact us with any questions or comments: hi@gretel.ai.