Building a RAG Pipeline using gemma-2B

Sept 10, 2024

Introduction

Recently, I got my hands dirty building a RAG (Retrieval-Augmented Generation) pipeline from scratch, and I have to say, it's an amazing way to handle large text documents. I worked with an open-source Human Nutrition textbook and turned it into a system where I can query the document, retrieve relevant information, and generate answers grounded in that information using an LLM (Large Language Model).

Here's a detailed breakdown of the entire process, along with some sample outputs!

Preparing and Reading the PDF

The first step was extracting text from the PDF. I used PyMuPDF (imported in Python as fitz) to read the PDF because it's fast and handles large documents like a champ. The document used was the Human Nutrition 2020 Edition.

Here's the code to get the text from each page and store it in a structured way:

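import fitz  # PyMuPDF

def open_pdf(pdf_path: str) -> list[dict]:
    """Read every page of a PDF and return a list of {"page_number", "text"} dicts."""
    pages_and_texts = []

    # The with-block closes the document once all pages are read
    with fitz.open(pdf_path) as doc:
        for page_number, page in enumerate(doc):
            text = page.get_text()  # plain-text extraction of this page
            pages_and_texts.append({"page_number": page_number, "text": text.strip()})

    return pages_and_texts

pages_and_texts = open_pdf("human-nutrition-text.pdf")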
        

Here's a sample output from one of the extracted pages:

Page 84:

"The cardiovascular system is one of the eleven organ systems of the human body. Its main function is to transport nutrients to cells and wastes from cells. This system consists of the heart, blood, and blood vessels. The heart pumps the blood, and the blood is the transportation fluid."
        
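Raw PDF text usually carries hard line breaks and doubled spaces left over from the page layout. Below is a minimal sketch of tidying a page dict and attaching rough size statistics to it. The clean_page_text and add_page_stats helpers are hypothetical, and the 4-characters-per-token figure is only a common rule of thumb for English text, not what any particular tokenizer would report.

```python
import re

def clean_page_text(text: str) -> str:
    """Collapse line breaks and runs of whitespace left over from the PDF layout."""
    return re.sub(r"\s+", " ", text).strip()

def add_page_stats(page: dict) -> dict:
    """Return a copy of a page dict with cleaned text and rough size statistics.

    approx_token_count assumes ~4 characters per token (a rule of thumb,
    not the output of a real tokenizer).
    """
    text = clean_page_text(page["text"])
    return {
        **page,
        "text": text,
        "char_count": len(text),
        "word_count": len(text.split()),
        "approx_token_count": len(text) / 4,
    }

# A hand-written page dict shaped like the output of open_pdf()
page = {"page_number": 84, "text": "The heart pumps the blood,\nand the blood is\n\nthe transportation fluid."}
stats = add_page_stats(page)
print(stats["text"])
print(stats["word_count"], stats["approx_token_count"])
```

Rough counts like these are handy for spotting near-empty pages or pages that will need several chunks.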

Chunking the Text into Pieces

Embedding models like Sentence-BERT can only take in a limited number of tokens at once (the exact limit depends on the model; 384 tokens is a common one), and anything longer gets truncated. So the next step was splitting the PDF text into smaller, manageable chunks.

For chunking, I split each page's text into sentences using spaCy and grouped every 10 sentences into one chunk. Keeping chunks short reduces how much text the embedding model has to truncate.

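import spacy

# Small English pipeline (install with: python -m spacy download en_core_web_sm)
nlp = spacy.load("en_core_web_sm")

def chunk_text(text: str, chunk_size: int = 10) -> list[str]:
    # Split into sentences, then join every chunk_size sentences into one string
    doc = nlp(text)
    sentences = [sent.text.strip() for sent in doc.sents]
    return [" ".join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)]

# Store the chunks on each page so the embedding step can read page["chunks"]
for page in pages_and_texts:
    page["chunks"] = chunk_text(page["text"])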
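spaCy's sentence segmentation does the heavy lifting in chunk_text. To see the grouping logic on its own, here is a minimal sketch that swaps in a naive regex splitter instead. The split_sentences and group_sentences names are hypothetical, and the regex only breaks after ".", "!" or "?" followed by whitespace, so it will stumble on abbreviations like "e.g." that spaCy handles much better.

```python
import re

def split_sentences(text: str) -> list[str]:
    """Naively split after ., ! or ? followed by whitespace (a rough stand-in for spaCy)."""
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

def group_sentences(sentences: list[str], chunk_size: int = 10) -> list[str]:
    """Join every chunk_size consecutive sentences into one chunk string."""
    return [" ".join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)]

text = ("The heart pumps blood. Blood carries nutrients. Vessels move blood around. "
        "Waste leaves the cells. The system has three parts.")
sentences = split_sentences(text)
chunks = group_sentences(sentences, chunk_size=2)
for chunk in chunks:
    print(chunk)
```

With 5 sentences and chunk_size=2, the last chunk holds a single sentence; a shorter final chunk is expected.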
        

Embedding the Chunks

Each chunk of text is then converted into numerical representations (embeddings) using Sentence-BERT. These embeddings allow for efficient similarity searches based on the meaning of the text.

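from sentence_transformers import SentenceTransformer

# Load a pre-trained Sentence-BERT model
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Flatten every page's chunks into one list, remembering which page each came from
pages_and_chunks = [
    {"page_number": page["page_number"], "sentence_chunk": chunk}
    for page in pages_and_texts
    for chunk in page["chunks"]
]

# Encode all chunks in one batched call: a NumPy array with one row per chunk
embeddings = model.encode([item["sentence_chunk"] for item in pages_and_chunks])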
        

This step produces one vector per chunk, and retrieval works by comparing these vectors. Here's a sample embedding (this was only a portion of the output, lol; paraphrase-MiniLM-L6-v2 produces 384 numbers per chunk):

Sample Embedding Vector (truncated):
[0.123, -0.245, 0.003, 0.157, ..., 0.056]
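Cosine similarity only depends on the direction of a vector, not its length. A common trick is to scale every embedding to unit length once; after that, a plain dot product gives the cosine score (SentenceTransformer.encode can do this for you with normalize_embeddings=True). A plain-Python sketch with made-up 3-dimensional toy vectors:

```python
import math

def l2_normalize(vector: list[float]) -> list[float]:
    """Scale a vector to unit length (L2 norm of 1)."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vector]

def dot(a: list[float], b: list[float]) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

# Toy "embeddings" (real ones from the model are 384-dimensional)
a = l2_normalize([3.0, 4.0, 0.0])
b = l2_normalize([6.0, 8.0, 0.0])  # same direction as a, twice the length
c = l2_normalize([0.0, 0.0, 5.0])  # orthogonal to a

print(dot(a, b))  # same direction: 1.0 up to float rounding
print(dot(a, c))  # orthogonal: 0.0
```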
        

Retrieving Relevant Chunks

To retrieve relevant chunks, I compare the query's embedding with every chunk embedding using cosine similarity and return the indices of the five most similar chunks.
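For one pair of vectors, cosine similarity is the dot product divided by the product of their lengths: cos(a, b) = (a · b) / (||a|| ||b||). It ranges from -1 (opposite directions) to 1 (same direction). scikit-learn's cosine_similarity computes exactly this for every (query, chunk) pair; here is a from-scratch sketch that scores one toy query vector against three toy chunk vectors (cosine_sim is a hypothetical name, chosen so it doesn't shadow the sklearn function):

```python
import math

def cosine_sim(a: list[float], b: list[float]) -> float:
    """cos(a, b) = (a . b) / (||a|| * ||b||)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

query_vec = [1.0, 0.0, 1.0]
doc_vecs = [
    [1.0, 0.0, 1.0],    # same direction as the query
    [0.0, 1.0, 0.0],    # orthogonal to the query
    [-1.0, 0.0, -1.0],  # opposite direction
]
scores = [cosine_sim(query_vec, d) for d in doc_vecs]
print(scores)
```

The scores come out at (approximately) 1, 0 and -1, which is why a higher score means "closer in meaning" during retrieval.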

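from sklearn.metrics.pairwise import cosine_similarity

def search(query, embeddings, top_k=5):
    # Embed the query with the same Sentence-BERT model used for the chunks
    query_embedding = model.encode(query)
    # One similarity score per chunk embedding
    similarities = cosine_similarity([query_embedding], embeddings)[0]
    # Indices of the top_k highest scores, best first
    top_n_indices = similarities.argsort()[-top_k:][::-1]
    return top_n_indices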
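The argsort()[-5:][::-1] idiom sorts every score just to keep the best few. An equivalent sketch with the standard library's heapq.nlargest returns the same indices (up to the ordering of tied scores) without sorting the whole list; top_k_indices is a hypothetical name:

```python
import heapq

def top_k_indices(similarities: list[float], k: int = 5) -> list[int]:
    """Return the indices of the k highest scores, best first."""
    return heapq.nlargest(k, range(len(similarities)), key=lambda i: similarities[i])

scores = [0.12, 0.87, 0.45, 0.91, 0.03, 0.66, 0.50]
print(top_k_indices(scores, k=3))  # [3, 1, 5]
```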
        

For example, when querying "What is the cardiovascular system?":

query = "What is the cardiovascular system?"
results = search(query, embeddings)

# Each index points into the flat list of chunks, not into the list of pages
for idx in results:
    print(pages_and_chunks[idx]["sentence_chunk"])
        
Retrieved Chunk 1:
"The cardiovascular system is one of the eleven organ systems of the human body. Its main function is to transport nutrients to cells and wastes from cells."

Retrieved Chunk 2:
"This system consists of the heart, blood, and blood vessels. The heart pumps the blood, and the blood is the transportation fluid."
        

Setting Up Generation with an LLM

Now that the retrieval pipeline is ready, the next step is to add generation functionality using a Large Language Model (LLM). The goal here is to generate context-aware answers based on the information retrieved from the document.

To do this, we'll use Google's Gemma, an open-weight model small enough to run on a local GPU. The model takes the query plus the retrieved context as input and generates an answer grounded in that context. The first step is checking how much GPU memory is available and picking a Gemma variant that fits.

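import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Total GPU memory in GB (left at 0 on CPU-only machines, where get_device_properties would fail)
gpu_memory_gb = 0
if device == "cuda":
    gpu_memory_bytes = torch.cuda.get_device_properties(0).total_memory
    gpu_memory_gb = round(gpu_memory_bytes / (2**30))

# Pick a Gemma variant that fits in memory. The instruction-tuned "-it" checkpoints
# are used because the prompt formatting below relies on a chat template.
if gpu_memory_gb < 8.1:
    model_id = "google/gemma-2b-it"
    use_quantization_config = device == "cuda"  # 4-bit bitsandbytes loading generally expects a CUDA GPU
elif gpu_memory_gb < 19.0:
    model_id = "google/gemma-2b-it"
    use_quantization_config = False
else:
    model_id = "google/gemma-7b-it"
    use_quantization_config = False

print(f"Model selected: {model_id}")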
        
Model selected: google/gemma-2b-it
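The memory thresholds follow from a back-of-the-envelope rule: the weights alone take roughly (number of parameters) × (bits per parameter) / 8 bytes, which is why a 4-bit model needs about a quarter of the memory of a float16 one. A minimal sketch of that estimate; the parameter count is an input you supply (check the model card), and activations, the KV cache and CUDA overhead all need extra headroom on top of it:

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for model weights only, in GB of 2**30 bytes (same unit as the GPU check)."""
    bytes_total = num_params * bits_per_param / 8
    return bytes_total / 2**30

# Illustrative only: a hypothetical 2-billion-parameter model at different precisions
for bits in (32, 16, 4):
    print(f"{bits}-bit: {weight_memory_gb(2e9, bits):.2f} GB")
```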
        

Loading the Model and Tokenizer

Once the right model is selected, it can be loaded with or without quantization, depending on the available GPU memory.

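from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Gemma checkpoints are gated on Hugging Face: accept the license on the model page and log in first.
# 4-bit quantization config, only used when GPU memory is tight
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) if use_quantization_config else None
tokenizer = AutoTokenizer.from_pretrained(model_id)

llm_model = AutoModelForCausalLM.from_pretrained(model_id,
                                                 torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                                                 quantization_config=quantization_config)

# bitsandbytes already places a quantized model on the GPU (and doesn't support .to());
# a full-precision model has to be moved explicitly
if not use_quantization_config:
    llm_model.to(device)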
        

Generating Responses Based on Query

After retrieving relevant text from the document, the next step is to generate responses. The instruction-tuned Gemma model expects its input prompt in a specific conversational (chat) format, so the prompt has to be built accordingly.

Formatting the Prompt with Context

The prompt includes both the user's query and relevant context retrieved from the text. Here's how to format it:

def prompt_formatter(query: str, context_items: list[dict]) -> str:
    # Combine the retrieved chunks into a bulleted list
    context = "- " + "\n- ".join(item["sentence_chunk"] for item in context_items)

    # Base prompt: instruction, retrieved context, then the user's query
    base_prompt = f"""Based on the following context items, please answer the query.
Relevant passages:
{context}
User query: {query}
Answer:"""

    # Wrap the prompt as a single user turn for the instruction-tuned model
    dialogue_template = [{"role": "user", "content": base_prompt}]

    # Apply the model's chat template and return a string (tokenize=False), ready for the model's reply
    return tokenizer.apply_chat_template(conversation=dialogue_template, tokenize=False, add_generation_prompt=True)
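apply_chat_template reads the template from the tokenizer's configuration, so the exact string depends on the model. For Gemma's instruction-tuned checkpoints, a single user turn with add_generation_prompt=True comes out roughly like the output of this hand-written sketch. It's an approximation for illustration (gemma_style_prompt is a hypothetical helper); in practice, let the tokenizer build the string instead of hard-coding the markers.

```python
def gemma_style_prompt(user_message: str) -> str:
    """Approximate Gemma's chat format: one user turn, then the opening of the model's turn."""
    return (
        "<bos>"
        "<start_of_turn>user\n"
        f"{user_message}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

print(gemma_style_prompt("What role does fiber play in digestion?"))
```

The trailing "<start_of_turn>model" line is what add_generation_prompt=True contributes: it tells the model that the next tokens should be its answer.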
        

Example: Augmenting Query with Context

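Here's an example query, formatted with context (for readability, the printed output below leaves out the turn markers, such as <start_of_turn>, that apply_chat_template adds for Gemma):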

query = "What role does fiber play in digestion?"
context_items = [  # Retrieved relevant chunks
    {"sentence_chunk": "Fiber helps digestion by regulating bowel movements and maintaining gut health."},
    {"sentence_chunk": "It adds bulk to stool and prevents constipation."}
]

# Format the prompt
formatted_prompt = prompt_formatter(query, context_items)
print(formatted_prompt)
        
Based on the following context items, please answer the query.
Relevant passages: 
- Fiber helps digestion by regulating bowel movements and maintaining gut health.
- It adds bulk to stool and prevents constipation.
User query: What role does fiber play in digestion?
Answer:
        

Generating the Response

With the formatted prompt ready, the Gemma model generates an answer using the relevant context:


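# Tokenize the prompt and move the tensors to the model's device
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(llm_model.device)

# Sample up to 256 new tokens; temperature only takes effect when do_sample=True
outputs = llm_model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)

# Decode only the newly generated tokens, skipping the echoed prompt
prompt_length = inputs["input_ids"].shape[1]
response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
print(response)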
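When sampling is enabled, temperature rescales the model's next-token scores (logits) before they become probabilities: p_i = exp(z_i / T) / sum_j exp(z_j / T). Values of T below 1 sharpen the distribution toward the top token; values above 1 flatten it. A small standard-library sketch with made-up logits for three candidate tokens:

```python
import math

def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Convert logits to probabilities after dividing them by the temperature."""
    scaled = [z / temperature for z in logits]
    max_scaled = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(z - max_scaled) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5]  # made-up scores for three candidate tokens
for t in (0.7, 1.0, 1.5):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

The ranking of tokens never changes; only how strongly the top token is favored does, which is why 0.7 gives focused but not fully deterministic answers.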
        

Sample output for the query "What role does fiber play in digestion?" might look like:

"Fiber plays a crucial role in digestion by promoting regular bowel movements, preventing constipation, and maintaining gut health. It adds bulk to stool and facilitates smoother passage through the digestive tract."
        

Putting it All Together

At this point, the RAG pipeline is complete. The system can:

  1. Retrieve relevant information from a large text document (like a textbook).
  2. Generate human-like answers to queries by incorporating context from the retrieved information.

Here's the whole pipeline wrapped in a single ask() function:

def ask(query: str, max_new_tokens: int = 512) -> str:
    # Retrieve the indices of the most relevant chunks
    indices = search(query, embeddings)
    context_items = [pages_and_chunks[i] for i in indices]

    # Format the prompt with the retrieved context
    prompt = prompt_formatter(query=query, context_items=context_items)

    # Generate an answer on whatever device the model lives on
    inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    outputs = llm_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Return only the newly generated text, not the echoed prompt
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

# Test the complete pipeline
query = "What are the benefits of vitamin C?"
print(ask(query))
        

Sample generated answer for the query "What are the benefits of vitamin C?":

"Vitamin C is essential for the growth and repair of tissues in the body. It plays a role in collagen production, wound healing, and maintaining healthy skin and blood vessels. Vitamin C is also a powerful antioxidant, helping to protect cells from oxidative stress and boosting the immune system."
        

Conclusion

Woah! The complete RAG pipeline (retrieval, augmentation, and generation) all running on your own GPU!

This setup opens the door to a variety of use cases, such as:

  1. Customer support: Answering questions based on a knowledge base.
  2. Research assistance: Extracting and generating summaries from scientific papers.
  3. Interactive textbooks: Turning large textbooks into a Q&A system.

References & Guide

Daniel Bourke - YouTube video
LangChain documentation - GitHub repository