Fine-Tuning CodeLlama on Gretel's Synthetic Text-to-SQL Dataset using Amazon SageMaker JumpStart

Fine-tuning CodeLlama on Gretel's synthetic Text-to-SQL dataset yields a 36% relative improvement in execution accuracy (EX) and 38% in valid efficiency score (VES) on the BIRD benchmark.

Introduction

In our previous blog post, we introduced Gretel's release of a high-quality, purely synthetic Text-to-SQL dataset, available on Hugging Face as gretelai/synthetic_text_to_sql. The dataset was generated with the Gretel Navigator model and mimics the complexity and variety of real-world SQL queries. As we continue to explore the potential of synthetic data, this follow-up post applies the dataset to fine-tuning Large Language Models (LLMs), specifically the CodeLlama models on Amazon SageMaker JumpStart. Our goal is to show how effective synthetic data can be at improving LLM performance on complex tasks such as Text-to-SQL parsing.

Fine-Tuning CodeLlama Using Amazon SageMaker JumpStart

In this blog post, we use the SageMaker Python SDK to perform instruction fine-tuning of the CodeLlama-7B and CodeLlama-13B models. We start by downloading the dataset from Hugging Face and exporting the training split to JSON Lines (JSONL) format.

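from datasets import load_dataset

# Download the dataset from the Hugging Face Hub
gretel_text_to_sql = load_dataset("gretelai/synthetic_text_to_sql")
# Write the training split as JSON Lines (one record per line, the default for to_json)
gretel_text_to_sql["train"].to_json("train.jsonl")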
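Because the export writes one JSON object per line (the default behavior of Dataset.to_json), the file can be read back as a stream for a quick sanity check. The following is a minimal sketch; the helper names parse_jsonl and iter_jsonl are hypothetical and not part of the datasets library.

```python
import json
from typing import Iterable, Iterator


def parse_jsonl(lines: Iterable[str], source: str = "<jsonl>") -> Iterator[dict]:
    """Yield one dict per non-empty line of JSON Lines text."""
    for line_number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline
        try:
            yield json.loads(line)
        except json.JSONDecodeError as err:
            raise ValueError(f"{source}:{line_number}: invalid JSON") from err


def iter_jsonl(path: str) -> Iterator[dict]:
    """Stream records from a JSON Lines file without loading it all into memory."""
    with open(path, "r", encoding="utf-8") as f:
        yield from parse_jsonl(f, source=path)
```

For example, sum(1 for _ in iter_jsonl("train.jsonl")) counts the exported training records without holding the whole file in memory.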

Each record in the dataset pairs a natural language question with its corresponding SQL query, along with metadata describing the database schema, the domain, the SQL complexity, and the SQL task type.

{
  "id": 2,
  "domain": "aerospace",
  "domain_description": "Aircraft manufacturing data, satellite deployment projects, flight safety records, and space exploration research.",
  "sql_complexity": "basic SQL",
  "sql_complexity_description": "basic SQL with a simple select statement",
  "sql_task_type": "analytics and reporting",
  "sql_task_type_description": "generating reports, dashboards, and analytical insights",
  "sql_prompt": "Count the number of aircraft produced by company XYZ",
  "sql_context": "CREATE TABLE aircraft (id INT, model VARCHAR(50), manufacturer VARCHAR(50), production_year INT);",
  "sql": "SELECT COUNT(*) FROM aircraft WHERE manufacturer = 'XYZ';",
  "sql_explanation": "This query counts the number of records in the aircraft table where the manufacturer is XYZ."
}
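The instruction template used in this post fills its placeholders from these fields by name, so a record missing one of them would yield a malformed training example. A minimal sketch of a per-record check follows; the field list is copied from the template's placeholders, and missing_fields is a hypothetical helper.

```python
# Field names referenced by the instruction template in this post
TEMPLATE_FIELDS = (
    "sql_context",
    "sql_task_type",
    "sql_task_type_description",
    "domain",
    "domain_description",
    "sql_prompt",
    "sql",
    "sql_explanation",
)


def missing_fields(record: dict, required=TEMPLATE_FIELDS) -> list:
    """Return the required field names that are absent or empty in a record."""
    return [name for name in required if not record.get(name)]
```

Applied to the aerospace record above, missing_fields(record) returns an empty list.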

To fine-tune CodeLlama, we created the following instruction prompt template. This template guides the model on what to generate and provides it with contextual knowledge about the database schema and the nature of the SQL task at hand.

import json

# Placeholders in braces are filled from the matching fields of each JSONL record
template = {
    "prompt": """
[INST] Write a SQL query that answers the following question based on the given database schema and any additional information provided. Use SQLite syntax.

[SCHEMA] {sql_context}

[KNOWLEDGE] This is an '{sql_task_type}' task, commonly used for {sql_task_type_description}. In the domain of {domain}, which involves {domain_description}.

[QUESTION] {sql_prompt}

[/INST]""",
    "completion": """
      ```{sql}```
{sql_explanation}""",
}

with open("template.json", "w") as f:
    json.dump(template, f)
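SageMaker JumpStart performs the placeholder substitution itself during training. To preview what a single training example looks like, the sketch below approximates that step with Python's str.format; treating JumpStart's substitution as equivalent to str.format is an assumption, and placeholders and render_example are hypothetical helpers.

```python
import string


def placeholders(text: str) -> set:
    """Collect the {field} names referenced in a format string."""
    return {field for _, field, _, _ in string.Formatter().parse(text) if field}


def render_example(template: dict, record: dict) -> dict:
    """Fill each template entry (e.g. "prompt", "completion") from one record.

    Raises KeyError listing any placeholders the record cannot supply.
    """
    rendered = {}
    for key, text in template.items():
        missing = placeholders(text) - set(record)
        if missing:
            raise KeyError(f"{key!r} needs missing field(s): {sorted(missing)}")
        rendered[key] = text.format(**record)
    return rendered
```

With the template dictionary above and the aerospace record shown earlier, render_example(template, record)["prompt"] returns the full instruction text the model would see for that record.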

With the dataset and template prepared, we establish a SageMaker session and upload them to Amazon S3.

from sagemaker import Session
from sagemaker.s3 import S3Uploader

session = Session()
output_bucket = session.default_bucket()
local_data_file = "train.jsonl"
train_data_location = f"s3://{output_bucket}/gretel_text_to_sql"
S3Uploader.upload(local_data_file, train_data_location)
S3Uploader.upload("template.json", train_data_location)
print(f"Training data: {train_data_location}")
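Assuming, as the SageMaker Python SDK documentation describes, that S3Uploader.upload appends the local file name to the prefix it is given, train.jsonl and template.json end up side by side under the same prefix, which is where JumpStart is expected to find the template next to the training data. A small sketch of that layout, with hypothetical helper names:

```python
from pathlib import PurePath


def split_s3_uri(uri: str):
    """Split "s3://bucket/key" into (bucket, key)."""
    if not uri.startswith("s3://"):
        raise ValueError(f"not an S3 URI: {uri!r}")
    bucket, _, key = uri[len("s3://"):].partition("/")
    return bucket, key


def expected_s3_uri(prefix_uri: str, local_file: str) -> str:
    """URI of a single local file uploaded under prefix_uri (file name appended)."""
    split_s3_uri(prefix_uri)  # validate the scheme before building the URI
    return f"{prefix_uri.rstrip('/')}/{PurePath(local_file).name}"
```

For example, expected_s3_uri(train_data_location, "template.json") gives the URI that S3Uploader.upload should return for the template file.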

We ran fine-tuning with SageMaker JumpStart using the following hyperparameters, training the model for one epoch:

hyperparameters = {
    "instruction_tuned": "True",
    "epoch": "1",
    "learning_rate": "0.0002",
    "lora_r": "8",
    "lora_alpha": "32",
    "lora_dropout": "0.05",
    "int8_quantization": "False",
    "enable_fsdp": "True",
    "per_device_train_batch_size": "8",
    "per_device_eval_batch_size": "2",
    "max_input_length": "512",
    "validation_split_ratio": "0.2",
    "train_data_split_seed": "0",
}
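The hyperparameters are passed as strings, but a few of them determine quantities worth checking before launching a job. The sketch below computes the LoRA scaling factor (lora_alpha / lora_r, as in the standard LoRA formulation), the number of trainable parameters LoRA adds to one weight matrix, and an approximate number of optimizer steps per epoch. It assumes no gradient accumulation and takes the device count as an input, since neither is stated in this post; the function names are hypothetical.

```python
import math


def lora_scale(hp: dict) -> float:
    """LoRA update scaling factor, lora_alpha / lora_r."""
    return float(hp["lora_alpha"]) / int(hp["lora_r"])


def lora_params_per_matrix(d_in: int, d_out: int, hp: dict) -> int:
    """Parameters LoRA adds to one d_out x d_in weight: A is r x d_in, B is d_out x r."""
    return int(hp["lora_r"]) * (d_in + d_out)


def steps_per_epoch(num_records: int, hp: dict, num_devices: int = 1) -> int:
    """Approximate optimizer steps per epoch (assumes no gradient accumulation)."""
    num_validation = round(num_records * float(hp["validation_split_ratio"]))
    num_train = num_records - num_validation
    global_batch = int(hp["per_device_train_batch_size"]) * num_devices
    return math.ceil(num_train / global_batch)
```

With lora_r = 8 and lora_alpha = 32, the scaling factor is 4.0. For a 4096 × 4096 projection (the hidden size of CodeLlama-7B), rank 8 adds 65,536 trainable parameters against roughly 16.8 million frozen ones; which projections JumpStart actually adapts is not specified here.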

Now, we are ready to fine-tune the CodeLlama model:

from sagemaker.jumpstart.estimator import JumpStartEstimator

model_id = "meta-textgeneration-llama-codellama-7b"
estimator = JumpStartEstimator(
    model_id=model_id,
    environment={"accept_eula": "true"},  # Accept EULA for gated models
    disable_output_compression=True,
    hyperparameters=hyperparameters,
    sagemaker_session=session,
)
estimator.fit({"training": train_data_location})

The fine-tuning job then starts and takes about 3.5 hours per epoch. Once training completes, we can deploy the fine-tuned model to an endpoint directly from the estimator:

predictor = estimator.deploy()
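Once the endpoint is up, each request needs a prompt in the same [INST] ... [/INST] format used for training, and the SQL has to be pulled out of the generated completion, which the template wraps in triple backticks. A minimal sketch, assuming the endpoint accepts the inputs/parameters request body used in JumpStart text-generation examples; build_payload and extract_sql are hypothetical helpers, and the generation settings are illustrative.

```python
import re

# Matches the first ```...``` span, optionally tagged "sql"
SQL_FENCE = re.compile(r"```(?:sql\b)?\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)


def build_payload(prompt: str, max_new_tokens: int = 256, temperature: float = 0.1) -> dict:
    """Request body in the assumed inputs/parameters shape."""
    return {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens, "temperature": temperature},
    }


def extract_sql(generated_text: str):
    """Return the SQL between the first pair of triple backticks, or None if absent."""
    match = SQL_FENCE.search(generated_text)
    return match.group(1) if match else None
```

A request might then be sent with predictor.predict(build_payload(prompt)); depending on the model and SDK version, Meta models on JumpStart may also require custom_attributes="accept_eula=true", and the response shape can vary by container version.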

Evaluation on the BIRD Benchmark

We assessed the impact of fine-tuning CodeLlama on the BIRD benchmark, a large-scale, cross-domain Text-to-SQL benchmark. BIRD groups queries into three difficulty levels (simple, moderate, and challenging), so performance can be compared at each level. We ran the experiments on the BIRD development (DEV) set, which contains 1,534 Text-to-SQL pairs, using two evaluation metrics:

  • Execution Accuracy (EX), which measures whether executing the generated SQL query returns the same result as executing the gold (reference) query.
  • Valid Efficiency Score (VES), an integrated metric that accounts for both the correctness of the execution results (as in EX) and the execution efficiency of the generated SQL queries.
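Both metrics can be sketched against a SQLite database, the dialect the training template asks for. The sketch below follows the definitions in [1] in simplified form: EX counts a prediction as correct when it returns the same set of rows as the gold query, and VES weights each correct prediction by the square root of the gold-to-predicted execution-time ratio. The official BIRD evaluation scripts time each query repeatedly and handle outliers, which this sketch does not; the function names are hypothetical.

```python
import math
import sqlite3
import time

_EPS = 1e-9  # guards against a zero reading from a coarse clock


def run_timed(conn: sqlite3.Connection, sql: str):
    """Execute one query and return (set of result rows, elapsed seconds)."""
    start = time.perf_counter()
    rows = conn.execute(sql).fetchall()
    return set(rows), time.perf_counter() - start


def evaluate_pair(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str):
    """Return (EX match, gold seconds, predicted seconds) for one example."""
    gold_rows, gold_time = run_timed(conn, gold_sql)
    try:
        predicted_rows, predicted_time = run_timed(conn, predicted_sql)
    except sqlite3.Error:
        # A query that fails to execute counts as incorrect
        return False, gold_time, math.inf
    return predicted_rows == gold_rows, gold_time, predicted_time


def execution_accuracy(matches) -> float:
    """EX: fraction of examples whose predicted result set equals the gold one."""
    matches = list(matches)
    return sum(matches) / len(matches)


def valid_efficiency_score(matches, gold_times, predicted_times) -> float:
    """VES: mean of 1[match] * sqrt(gold_time / predicted_time) over all examples."""
    rewards = [
        math.sqrt(max(g, _EPS) / max(p, _EPS)) if m else 0.0
        for m, g, p in zip(matches, gold_times, predicted_times)
    ]
    return sum(rewards) / len(rewards)
```

Both scores are usually reported as percentages, so the returned fractions are multiplied by 100 for comparison with published numbers.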

Here’s how the base and fine-tuned models compared on the DEV set when knowledge evidence is included (see [1] for details):

Figure 1. Comparison of base and fine-tuned models on the DEV set.

These results show that fine-tuning substantially improved both CodeLlama models on both the EX and VES metrics. This underscores the potential of synthetic datasets for fine-tuning LLMs on specialized tasks such as Text-to-SQL.
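The headline figures for this post (a 36% improvement in EX and 38% in VES) are relative improvements: the gain is measured against the base model's score rather than in absolute percentage points. A one-function sketch of that arithmetic, with a hypothetical function name:

```python
def relative_improvement(base_score: float, tuned_score: float) -> float:
    """(tuned - base) / base; 0.36 corresponds to a 36% relative gain."""
    if base_score <= 0:
        raise ValueError("base score must be positive")
    return (tuned_score - base_score) / base_score
```

For hypothetical scores of 50.0 (base) and 60.0 (fine-tuned), relative_improvement(50.0, 60.0) is 0.2: a 10-point absolute gain but a 20% relative one.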

[1] Li, J., Hui, B., Qu, G., Yang, J., Li, B., Li, B., ... & Li, Y. (2024). Can LLM already serve as a database interface? A big bench for large-scale database grounded text-to-SQLs. Advances in Neural Information Processing Systems, 36.

Ready to try it yourself?

The SageMaker notebook for this blog post is available here.

You can find Gretel on the AWS Marketplace. If you have any questions about Gretel’s platform or would like to learn more about how synthetic data can help your business, send us a note.

About the authors

Gretel:

  • Maarten Van Segbroeck, Yev Meyer, Marjan Emadi, Dhruv Nathawani, Lipika Ramaswamy, Kendrick Boyd, Matthew Grossman, Piotr Mlocek, Drew Newberry

AWS:

  • Qiong (Jo) Zhang, PhD, is a Senior Partner SA at AWS, specializing in AI/ML. Her current areas of interest include federated learning, distributed training, and generative AI. She holds 30+ patents and has co-authored 100+ journal/conference papers. She is also the recipient of the Best Paper Award at IEEE NetSoft 2016, IEEE ICC 2011, ONDM 2010, and IEEE GLOBECOM 2005.
  • Shashi Raina is a Senior Partner Solutions Architect at Amazon Web Services (AWS), where he specializes in supporting generative AI (GenAI) startups. With close to 6 years of experience at AWS, Shashi has developed deep expertise across a range of domains, including DevOps, analytics, and generative AI.