Prompting Llama-2 at Scale with Gretel

Discover how to efficiently use Gretel's platform for prompting Llama-2 on large datasets, whether you're completing answers, generating synthetic text, or labeling.

In this blog, we will walk through prompting the 7-billion-parameter Meta Llama-2 chat model with a series of 250 questions from GSM8K, a popular dataset used to assess LLMs' ability to solve multi-step problems.

Follow along in a Colab notebook.

What is Batch Inference?

When you're working with hundreds to hundreds of thousands of records, prompting an LLM one query at a time via a synchronous API can be inefficient. In this post, we'll explore how to leverage Gretel's platform to prompt Llama-2 efficiently at scale, for example to complete answers to a large set of prompts, to create additional synthetic text examples, or to label a large corpus of data.

Why Llama 2?

Meta's Llama-2 is one of the most advanced open-source large language models (LLMs) available today. With models available in 7B, 13B, and 70B parameter sizes, it can generate remarkably human-like text across many domains. In this example, we will be using the 7B-parameter Llama-2-chat model, which has been fine-tuned and optimized for dialogue applications using Reinforcement Learning from Human Feedback (RLHF), allowing it to achieve performance comparable to ChatGPT across many evaluations. You can read more in the paper here.

How is this useful?

Batch prompting unlocks new capabilities for organizations using large language models like Llama-2. By getting insights across hundreds or thousands of data points at once, teams can use LLMs more efficiently and effectively:

  • Enables fast insights across entire datasets instead of slow per-query API calls.
  • Unlocks high-scale content generation, corpus labeling, and other new use cases.
  • Allows LLMs like Llama-2 to be leveraged more efficiently across more data.

How to prompt Llama-2 with Gretel

1. Setup Development Environment

First, we'll install the Gretel client and Hugging Face Datasets libraries:

# Install dependencies
!pip install -Uqq gretel-client datasets

Import the libraries we need:

# Import required libraries
import textwrap
import pandas as pd
from datasets import load_dataset

from gretel_client import configure_session
from gretel_client.helpers import poll
from gretel_client.projects import create_or_get_unique_project
from gretel_client.projects.models import read_model_config

To access the Gretel GPT API for LLM inference, we need to log in to our Gretel account. You can do this via the CLI or with the Python code below in a notebook.

configure_session(
    api_key="prompt",
    endpoint="https://api.gretel.cloud",
    validate=True,
    clear=True
)
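
If you're running this outside an interactive notebook, you can skip the prompt and read the key from an environment variable instead. Here's a small variation on the call above, assuming the key is stored in a variable named GRETEL_API_KEY:

# Variation: read the API key from an environment variable instead of prompting
import os

configure_session(
    api_key=os.environ["GRETEL_API_KEY"],  # assumes you exported GRETEL_API_KEY beforehand
    endpoint="https://api.gretel.cloud",
    validate=True,
    clear=True
)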

2. Load and prepare the dataset

We'll use the GSM8K (Grade School Math) dataset of 8K real questions to query the LLM. These are problems that take between 2 and 8 steps to solve, primarily by performing a sequence of elementary arithmetic calculations to reach the final answer, making it a good benchmark for assessing LLMs' ability to solve everyday, real-world multi-step problems.

First we'll load the dataset and extract the first 250 questions:

# Load gsm8k dataset and get the first 250 questions
dataset = load_dataset('gsm8k', 'main')
questions = dataset['train']['question'][:250]
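
If you'd like to see what these multi-step problems look like before prompting, it's easy to peek at one record from the dataset we just loaded; the answer field holds the reference step-by-step solution:

# Optional: inspect one GSM8K record to see a question and its reference solution
sample = dataset['train'][0]
print(f"Question:\n{sample['question']}\n")
print(f"Reference answer:\n{sample['answer']}")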

Next, we'll format the questions into instructions, adding the chat formatting that the Llama-2 chat model expects:

# Function to format prompt according to Llama2 expected chat instruction format
def format_prompt(prompt: str) -> str:
    llama_template = textwrap.dedent(f"""\
    <s>[INST] <<SYS>>You provide just the answer you are asked for with no preamble. Do not repeat the question. Be succinct.<</SYS>>

    {prompt} [/INST]
    """)

    return llama_template

# Add the Llama2 instruction format to each prompt
formatted_prompts = [format_prompt(q) for q in questions]

# Convert the instructions to a DataFrame format
instructions = pd.DataFrame(data={'text': formatted_prompts})

Let’s verify our formatting function by printing out a random sample.

# Print a random sample question and formatted instruction
random_idx = instructions.sample(n=1).index[0]
print(f"Random question:\n```{questions[random_idx]}```\n\n")
print(f"Instruction:\n```{instructions.loc[random_idx]['text']}```")

3. Initialize the LLM

Next, we'll initialize a Gretel synthetic model and configure it to load the Llama 2 chat model:

# Initialize the synthetic model and its configuration
MODEL = 'meta-llama/Llama-2-7b-chat-hf'

# Create (or retrieve) a project to organize the model and its results
project = create_or_get_unique_project(name="batch-prompting-llama2")  # any project name works

config = read_model_config("synthetics/natural-language")
config['models'][0]['gpt_x']['pretrained_model'] = MODEL
# There is no need to define a data source for fine-tuning, as we are prompting a base model
config['models'][0]['gpt_x']['data_source'] = None

model = project.create_model_obj(model_config=config)
model.name = MODEL
model.submit_cloud()

# Wait for the model to be ready
poll(model, verbose=False)

4. Batch Prompt the LLM

With our dataset prepared and the Llama-2 model initialized on Gretel, we can now send our batch of questions to the model. On initialization, Gretel's batch APIs load the model into a private container, which can take about 3-5 minutes to download and initialize the ~13 GB of model weights. Because of this startup cost, the batch APIs are most efficient when you are prompting or labeling hundreds or thousands of examples. With Gretel operating in either cloud or hybrid mode, you can also submit multiple jobs to process datasets in parallel (see the sketch at the end of this section).

# Pass instruction to the model for zero-shot reasoning
record_handler = model.create_record_handler_obj(
    params={"maximum_text_length": 250},
    data_source=instructions
)
record_handler.submit_cloud()

# Wait for the results
poll(record_handler, verbose=False)

# Retrieve results
results = pd.read_csv(record_handler.get_artifact_link("data"), compression='gzip')
results

Finally, we can inspect Llama-2's responses to our batch of questions by sampling a random row from the results:

# Generate a random index
random_idx = results.sample(n=1).index[0]

print(f"Random question:\n```{questions[random_idx]}```\n\n")
print(f"LLM response:\n```{results.loc[random_idx]['text']}```")

Conclusion

By leveraging Gretel's platform and Python client library, we can easily prompt large language models like Meta's Llama 2 Chat with batches of data for faster, more efficient results than prompting query by query. Gretel handles scaling out requests across multiple GPUs, monitoring progress, automatically assessing synthetic text quality, and fine-tuning models like Llama 2 on your own data.
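
As an example of that last point, switching from zero-shot prompting to fine-tuning is mostly a matter of pointing the same configuration at your own data. Here is a minimal sketch, assuming a hypothetical CSV of training text called my_training_texts.csv:

# Sketch: fine-tune Llama 2 on your own text instead of prompting the base model
# "my_training_texts.csv" is a hypothetical file with one training example per row
config = read_model_config("synthetics/natural-language")
config['models'][0]['gpt_x']['pretrained_model'] = MODEL
config['models'][0]['gpt_x']['data_source'] = "my_training_texts.csv"

model = project.create_model_obj(model_config=config)
model.submit_cloud()

# Wait for fine-tuning to complete
poll(model, verbose=False)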

The full code for this tutorial is available on Colab and GitHub. Try it out yourself to experience prompting LLMs at scale!