In retrieval-augmented generation (RAG) with large language models (LLMs), optimizing the retrieval step is crucial for both efficiency and accuracy. In a previous blog post, I discussed how to optimize RAG context by chunking and summarization for technical documents using Chroma VectorDB. Specifically, the post demonstrated how to query summaries while retrieving the original documents.
In this post, we'll explore how to achieve a similar result using Weaviate and its cross-references feature, integrated with LangChain. We'll leverage Weaviate's ability to create cross-references between data objects to efficiently retrieve original documents by querying their summaries.
This tutorial walks through the process step by step. If you prefer to explore the code and outputs directly, they are also available in the accompanying Jupyter notebook.
Preparing the Environment
First, set up a Conda environment to manage dependencies and keep the project isolated.
# Create a new environment called 'rag-env'
conda create -n rag-env python=3.12
# Activate the environment
conda activate rag-env
# Install necessary packages
pip install weaviate-client==4.9.0 \
langchain==0.3.5 \
langchain-core==0.3.13 \
langchain-ollama==0.2.0
Note: The following are also required:
- Weaviate: Run it locally with Docker, deploy it to a Kubernetes cluster, or access it via Weaviate Cloud.
- Ollama: Ensure you have the mxbai-embed-large model for embeddings and llama3.1 for the RAG example.
Initializing the Weaviate Client
Let's begin by initializing the Weaviate client with authentication.
import getpass
import weaviate
from weaviate.classes.init import Auth
# Prompt for the Weaviate API key
WEAVIATE_API_KEY = getpass.getpass()
# Initialize the Weaviate client with authentication
weaviate_client = weaviate.connect_to_local(
    auth_credentials=Auth.api_key(WEAVIATE_API_KEY)
)
# Check if the client is ready
print("Client is Ready?", weaviate_client.is_ready())
Importing Data with Cross-References
Next, we'll load the original and summarized documents, create Weaviate collections, and insert the data with cross-references.
Note: The files chunked_docs.json and summarized_docs.json are taken from my previous blog post. They were created by adding the following code to the Summarization Based on Headers and Chunk Text section:
# Save the chunked and summarized documents to JSON files
import json
chunked_docs_json = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in chunked_docs]
with open('files/generated/chunked_docs.json', 'w') as f:
    json.dump(chunked_docs_json, f, indent=4)
summarized_docs_json = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in summarized_docs]
with open('files/generated/summarized_docs.json', 'w') as f:
    json.dump(summarized_docs_json, f, indent=4)
In the summarized_docs.json file, metadata.id was changed to metadata.doc_id to avoid conflicts with Weaviate's id field.
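The post does not show how that rename was carried out. One possible way to do it, sketched here on hypothetical sample data rather than the real summarized_docs.json, is a small helper (rename_metadata_key is an illustrative name, not part of any library) that returns copies of the documents with the metadata key renamed:

```python
import json


def rename_metadata_key(docs, old_key="id", new_key="doc_id"):
    """Return copies of the documents with one metadata key renamed."""
    renamed = []
    for doc in docs:
        # Copy the metadata so the input documents are left untouched
        metadata = dict(doc.get("metadata", {}))
        if old_key in metadata:
            metadata[new_key] = metadata.pop(old_key)
        renamed.append({**doc, "metadata": metadata})
    return renamed


# Hypothetical sample standing in for the contents of summarized_docs.json
sample_docs = [
    {"page_content": "Summary of section A", "metadata": {"id": "a1"}},
    {"page_content": "Summary of section B", "metadata": {"id": "b2"}},
]

renamed_docs = rename_metadata_key(sample_docs)
print(json.dumps(renamed_docs, indent=4))
```

The same helper could be applied to the list loaded with json.load before writing it back with json.dump.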
Now, let's proceed to import the data:
import json
from langchain_ollama import OllamaEmbeddings
# Load the chunked and summarized documents
with open("files/chunked_docs.json", "r") as f:
    chunked_docs = json.load(f)
with open("files/summarized_docs.json", "r") as f:
    summarized_docs = json.load(f)
# Define the collection names
collections = {
    "original": "OriginalDocuments",
    "summary": "SummarizedDocuments",
}
# Delete collections if they already exist
for collection in collections.values():
    if collection in weaviate_client.collections.list_all(simple=True):
        weaviate_client.collections.delete(collection)
# Create the collections
original_collection_db = weaviate_client.collections.create(collections["original"])
summary_collection_db = weaviate_client.collections.create(collections["summary"])
# Initialize the Ollama embedding model
ollama_emb = OllamaEmbeddings(model="mxbai-embed-large")
# Insert the documents into the collections
for summarized_doc in summarized_docs:
    summarized_doc_id = summarized_doc["metadata"]["doc_id"]
    original_doc = next((doc for doc in chunked_docs if doc.get("metadata", {}).get("summary_id") == summarized_doc_id), None)
    if original_doc:
        original_uuid = original_collection_db.data.insert(
            {
                "page_content": original_doc["page_content"],
            },
            vector=ollama_emb.embed_query(original_doc["page_content"]),
        )
        summary_collection_db.data.insert(
            {
                "page_content": summarized_doc["page_content"],
            },
            references={"originalDocument": original_uuid},
            vector=ollama_emb.embed_query(summarized_doc["page_content"]),
        )
# Verify the number of documents in the collections
original_count = len(original_collection_db)
summary_count = len(summary_collection_db)
print(f"Number of documents in the original collection: {original_count}")
print(f"Number of documents in the summary collection: {summary_count}")
Output:
Number of documents in the original collection: 34
Number of documents in the summary collection: 34
Explanation:
- Loading Data: The chunked and summarized documents are loaded from JSON files.
- Collections: Two collections, OriginalDocuments and SummarizedDocuments, are defined and created in Weaviate.
- Embeddings: The embedding model is initialized using Ollama's mxbai-embed-large model.
- Data Insertion: Each summarized document is matched with its corresponding original document. The original documents are inserted into the OriginalDocuments collection, and the summarized documents are inserted into the SummarizedDocuments collection, including a cross-reference that links back to their respective original documents.
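The matching step scans the whole list of chunked documents for every summary. For larger datasets, a dictionary keyed by summary_id could do the same lookup in one pass. The sketch below is an optional variation, not the code used in this post; build_summary_index and the sample records are hypothetical and only mimic the shape of the JSON files:

```python
def build_summary_index(chunked_docs):
    """Map each summary_id to the first chunked document that carries it."""
    index = {}
    for doc in chunked_docs:
        summary_id = doc.get("metadata", {}).get("summary_id")
        # Keep the first match only, mirroring the behavior of next(...)
        if summary_id is not None and summary_id not in index:
            index[summary_id] = doc
    return index


# Hypothetical records shaped like chunked_docs.json and summarized_docs.json
chunked_sample = [
    {"page_content": "Full text A", "metadata": {"summary_id": "s1"}},
    {"page_content": "Full text B", "metadata": {"summary_id": "s2"}},
]
summarized_sample = [
    {"page_content": "Summary A", "metadata": {"doc_id": "s1"}},
    {"page_content": "Summary without a match", "metadata": {"doc_id": "s9"}},
]

index = build_summary_index(chunked_sample)
pairs = [
    (summary, index.get(summary["metadata"]["doc_id"]))
    for summary in summarized_sample
]
for summary, original in pairs:
    matched = original["page_content"] if original else None
    print(summary["page_content"], "->", matched)
```

Summaries without a matching original come back paired with None, which corresponds to the if original_doc check in the insertion loop.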
Why Create Vectors Manually?
You might be curious why the vectors are created manually using ollama_emb.embed_query instead of letting Weaviate generate them through its Ollama embeddings integration. The reason is that the Weaviate instance is running on a Kubernetes cluster in a lab environment, while Ollama operates in a different isolated environment to which Weaviate doesn't have connectivity. Since Weaviate cannot reach Ollama, this approach demonstrates how to insert documents into Weaviate by generating embeddings in a separate environment and then supplying them directly to Weaviate.
Querying with Cross-References
The aim of this example is to demonstrate how to perform queries with cross-references by using Weaviate directly. We'll define a function that retrieves documents by querying the summaries and then obtains the original documents via cross-references. This will allow us to see the outputs for both the summaries and the original documents.
from weaviate.classes.query import QueryReference, MetadataQuery
# Define a function to retrieve documents
def retrieve_documents(query, vector, limit=2, score_threshold=0.8):
    response = summary_collection_db.query.hybrid(
        query,
        vector=vector,
        limit=limit,
        return_references=QueryReference(link_on="originalDocument"),
        return_metadata=MetadataQuery(score=True),
    )
    summary_docs = []
    original_docs = []
    for o in response.objects:
        if o.metadata.score is not None and o.metadata.score >= score_threshold:
            summary_doc = {"page_content": o.properties["page_content"]}
            summary_docs.append(summary_doc)
            for ref_obj in o.references["originalDocument"].objects:
                original_doc = {"page_content": ref_obj.properties["page_content"]}
                original_docs.append(original_doc)
    return summary_docs, original_docs
# Define a query
query = "I want to write a Python script that prints numbers from 1 to 30."
vector = ollama_emb.embed_query(query)
summary_docs, original_docs = retrieve_documents(query, vector)
# Print the summarized and original documents
print("Summarized Documents:")
for i, doc in enumerate(summary_docs, start=1):
    print(f"Summarized Document #{i}")
    print("--------------------")
    print(doc["page_content"])
    print("--------------------")
    print()
print("Original Documents:")
for i, doc in enumerate(original_docs, start=1):
    print(f"Original Document #{i}")
    print("--------------------")
    print(doc["page_content"])
    print("--------------------")
    print()
Output:
Summarized Documents:
Summarized Document #1
--------------------
The `break` statement exits the innermost enclosing for or while loop, stopping execution of the loop and continuing with the next statement. This is demonstrated by a nested for loop that prints factors of numbers from 2 to 9, where the break statement stops the loop when a factor is found. The `continue` statement skips the rest of the current iteration in a loop and moves on to the next one, as shown by a for loop that iterates over numbers from 2 to 9, printing even numbers and skipping odd ones.
--------------------
Summarized Document #2
--------------------
The built-in `range()` function generates arithmetic progressions that can be used for iteration over a sequence of numbers. It takes three parameters: start point, end point, and step (default is 1), and returns an iterator that produces the specified range of values. The end point is never part of the generated sequence. To iterate over the indices of a sequence, `range()` can be combined with `len()`, but in most cases it's more convenient to use the `enumerate()` function for this purpose.
--------------------
Original Documents:
Original Document #1
--------------------
The [`break`](../reference/simple_stmts.html#break) statement breaks out of the innermost enclosing
[`for`](../reference/compound_stmts.html#for) or [`while`](../reference/compound_stmts.html#while) loop:
```
>>> for n in range(2, 10):
...     for x in range(2, n):
...         if n % x == 0:
...             print(f"{n} equals {x} * {n//x}")
...             break
...
4 equals 2 * 2
6 equals 2 * 3
8 equals 2 * 4
9 equals 3 * 3
```
The [`continue`](../reference/simple_stmts.html#continue) statement continues with the next
iteration of the loop:
```
>>> for num in range(2, 10):
...     if num % 2 == 0:
...         print(f"Found an even number {num}")
...         continue
...     print(f"Found an odd number {num}")
...
Found an even number 2
Found an odd number 3
Found an even number 4
Found an odd number 5
Found an even number 6
Found an odd number 7
Found an even number 8
Found an odd number 9
```
--------------------
Original Document #2
--------------------
If you do need to iterate over a sequence of numbers, the built\-in function
[`range()`](../library/stdtypes.html#range "range") comes in handy. It generates arithmetic progressions:
```
>>> for i in range(5):
...     print(i)
...
0
1
2
3
4
```
The given end point is never part of the generated sequence; `range(10)` generates
10 values, the legal indices for items of a sequence of length 10\. It
is possible to let the range start at another number, or to specify a different
increment (even negative; sometimes this is called the ‘step’):
```
>>> list(range(5, 10))
[5, 6, 7, 8, 9]
>>> list(range(0, 10, 3))
[0, 3, 6, 9]
>>> list(range(-10, -100, -30))
[-10, -40, -70]
```
To iterate over the indices of a sequence, you can combine [`range()`](../library/stdtypes.html#range "range") and
[`len()`](../library/functions.html#len "len") as follows:
```
>>> a = ['Mary', 'had', 'a', 'little', 'lamb']
>>> for i in range(len(a)):
...     print(i, a[i])
...
0 Mary
1 had
2 a
3 little
4 lamb
```
In most such cases, however, it is convenient to use the [`enumerate()`](../library/functions.html#enumerate "enumerate")
function, see [Looping Techniques](datastructures.html#tut-loopidioms).
--------------------
Explanation:
- Hybrid Query: The retrieve_documents function utilizes Weaviate's hybrid search on the SummarizedDocuments collection. Hybrid search combines the results of a vector similarity search and a keyword (BM25F) search by fusing the two result sets. This approach enhances retrieval accuracy by considering both semantic similarity (from embeddings) and keyword relevance.
- Cross-References: It retrieves the summarized documents and accesses the original documents via the originalDocument cross-reference.
- Filtering: Only documents with a score at or above the threshold (score_threshold, 0.8 by default) are considered.
- Result: The retrieved summarized and original documents are printed to display the outputs.
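To build some intuition for what "fusing the two result sets" means, here is a deliberately simplified sketch of a weighted, min-max normalized fusion. It is not Weaviate's implementation; the function names, the alpha value, and the sample scores are all assumptions made for illustration, and Weaviate's documentation describes the fusion algorithms it actually supports:

```python
def min_max_normalize(scores):
    """Rescale a {doc_id: score} mapping to the 0..1 range."""
    if not scores:
        return {}
    lowest, highest = min(scores.values()), max(scores.values())
    if highest == lowest:
        # All scores are equal, so treat every document as a top hit
        return {doc_id: 1.0 for doc_id in scores}
    return {
        doc_id: (score - lowest) / (highest - lowest)
        for doc_id, score in scores.items()
    }


def fuse_scores(vector_scores, keyword_scores, alpha=0.5):
    """Blend two result sets; alpha weights the vector side."""
    vector_norm = min_max_normalize(vector_scores)
    keyword_norm = min_max_normalize(keyword_scores)
    fused = {
        doc_id: alpha * vector_norm.get(doc_id, 0.0)
        + (1 - alpha) * keyword_norm.get(doc_id, 0.0)
        for doc_id in set(vector_norm) | set(keyword_norm)
    }
    # Highest fused score first
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)


# Hypothetical scores: vector similarity and keyword relevance per document
vector_scores = {"doc_a": 0.9, "doc_b": 0.5, "doc_c": 0.1}
keyword_scores = {"doc_b": 4.0, "doc_c": 2.0}

ranked = fuse_scores(vector_scores, keyword_scores)
for doc_id, score in ranked:
    print(doc_id, round(score, 3))
```

In this toy data, a document that scores moderately on both sides ends up ahead of one that scores highly on the vector side only. A fixed score threshold such as score_threshold interacts with whatever fusion method is in use, so the threshold value is worth tuning against real queries.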
Creating a Custom Retriever
To integrate this retrieval mechanism with LangChain, we'll implement a custom retriever that leverages Weaviate's cross-references.
from typing import List, Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from weaviate.classes.query import QueryReference, MetadataQuery
class VectorDBRetrieverCrossReferences(BaseRetriever):
    """A custom retriever that retrieves documents from a Weaviate vector database."""
    summary_collection_db: Any
    ollama_emb: Any
    k: int = 2
    score_threshold: float = 0.8
    return_source_documents: bool = False
    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        """Sync implementation for retriever."""
        vector = self.ollama_emb.embed_query(query)
        response = self.summary_collection_db.query.hybrid(
            query,
            vector=vector,
            limit=self.k,
            return_references=QueryReference(link_on="originalDocument"),
            return_metadata=MetadataQuery(score=True),
        )
        original_docs = []
        for o in response.objects:
            if o.metadata.score is not None and o.metadata.score >= self.score_threshold:
                for ref_obj in o.references["originalDocument"].objects:
                    doc_content = ref_obj.properties["page_content"]
                    metadata = {"source": ref_obj.properties.get("source")} if self.return_source_documents else {}
                    original_docs.append(Document(page_content=doc_content, metadata=metadata))
        return original_docs
# Initialize the custom retriever
retriever = VectorDBRetrieverCrossReferences(
    summary_collection_db=summary_collection_db,
    ollama_emb=ollama_emb
)
# Retrieve documents using the custom retriever
query = "I want to write a Python script that prints numbers from 1 to 30."
documents = retriever.invoke(query)
# Print the retrieved documents
for i, doc in enumerate(documents, start=1):
    print(f"Document #{i}")
    print("--------------------")
    print(doc.page_content)
    print("--------------------")
    print()
Output:
Document #1
--------------------
The [`break`](../reference/simple_stmts.html#break) statement breaks out of the innermost enclosing
[`for`](../reference/compound_stmts.html#for) or [`while`](../reference/compound_stmts.html#while) loop:
```
>>> for n in range(2, 10):
...     for x in range(2, n):
...         if n % x == 0:
...             print(f"{n} equals {x} * {n//x}")
...             break
...
4 equals 2 * 2
6 equals 2 * 3
8 equals 2 * 4
9 equals 3 * 3
```
The [`continue`](../reference/simple_stmts.html#continue) statement continues with the next
iteration of the loop:
```
>>> for num in range(2, 10):
...     if num % 2 == 0:
...         print(f"Found an even number {num}")
...         continue
...     print(f"Found an odd number {num}")
...
Found an even number 2
Found an odd number 3
Found an even number 4
Found an odd number 5
Found an even number 6
Found an odd number 7
Found an even number 8
Found an odd number 9
```
--------------------
Document #2
--------------------
If you do need to iterate over a sequence of numbers, the built\-in function
[`range()`](../library/stdtypes.html#range "range") comes in handy. It generates arithmetic progressions:
```
>>> for i in range(5):
...     print(i)
...
0
1
2
3
4
```
The given end point is never part of the generated sequence; `range(10)` generates
10 values, the legal indices for items of a sequence of length 10\. It
is possible to let the range start at another number, or to specify a different
increment (even negative; sometimes this is called the ‘step’):
```
>>> list(range(5, 10))
[5, 6, 7, 8, 9]
>>> list(range(0, 10, 3))
[0, 3, 6, 9]
>>> list(range(-10, -100, -30))
[-10, -40, -70]
```
To iterate over the indices of a sequence, you can combine [`range()`](../library/stdtypes.html#range "range") and
[`len()`](../library/functions.html#len "len") as follows:
```
>>> a = ['Mary', 'had', 'a', 'little', 'lamb']
>>> for i in range(len(a)):
...     print(i, a[i])
...
0 Mary
1 had
2 a
3 little
4 lamb
```
In most such cases, however, it is convenient to use the [`enumerate()`](../library/functions.html#enumerate "enumerate")
function, see [Looping Techniques](datastructures.html#tut-loopidioms).
--------------------
Explanation:
- Custom Retriever: The VectorDBRetrieverCrossReferences class extends LangChain's BaseRetriever.
- Method Override: The _get_relevant_documents method performs the hybrid query and retrieves the original documents via cross-references.
- Integration: The retriever is initialized with the summary_collection_db and the embedding model, making it compatible with LangChain.
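The filtering and reference-following logic can also be exercised without a running Weaviate instance. The sketch below is only an illustration: it builds stand-in objects with SimpleNamespace that imitate the attributes the retriever reads (metadata.score, references, properties), and collect_originals and fake_object are hypothetical helper names. The guard for missing references is a defensive addition that the retriever above does not have:

```python
from types import SimpleNamespace


def collect_originals(response, score_threshold=0.8, link_on="originalDocument"):
    """Follow cross-references of results that meet the score threshold."""
    originals = []
    for obj in response.objects:
        score = obj.metadata.score
        if score is None or score < score_threshold:
            continue
        # Defensive addition: skip results that carry no such reference
        refs = (obj.references or {}).get(link_on)
        if refs is None:
            continue
        for ref_obj in refs.objects:
            originals.append(ref_obj.properties["page_content"])
    return originals


def fake_object(score, original_text):
    """Build a stand-in for one query result (not a real Weaviate object)."""
    ref = SimpleNamespace(properties={"page_content": original_text})
    return SimpleNamespace(
        metadata=SimpleNamespace(score=score),
        references={"originalDocument": SimpleNamespace(objects=[ref])},
    )


fake_response = SimpleNamespace(
    objects=[
        fake_object(0.95, "Original A"),
        fake_object(0.40, "Original B"),
        fake_object(None, "Original C"),
    ]
)

originals = collect_originals(fake_response)
print(originals)
```

Only the result scored at or above the threshold contributes its original document; results with a low or missing score are skipped.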
Example of Simple Retrieval-Augmented Generation (RAG)
Finally, let's build a simple RAG pipeline using the custom retriever and a language model.
from langchain_ollama.chat_models import ChatOllama
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
# Initialize the ChatOllama model
llm = ChatOllama(model="llama3.1", temperature=0, num_ctx=16384)
# Define the system prompt template
system_prompt = (
    "You are an assistant for answering questions. "
    "Use only the exact information provided in the context, do not include external knowledge or guesses. "
    "If the answer cannot be inferred from the context, reply: 'I don't know based on the provided context.' "
    "Do not provide answers that are not based on the context, including code examples or references to other libraries. "
    "Format your entire response in valid Markdown, including code snippets and links. "
    "Always adhere to these rules strictly.\n\n"
    "Context: \n"
    "{context}"
)
# Define the chat prompt
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)
# Create the question-answer chain
question_answer_chain = create_stuff_documents_chain(llm, prompt)
# Create the retrieval-augmented generation (RAG) chain
rag_chain = create_retrieval_chain(retriever, question_answer_chain)
# Query #1
query = "I want to write a Python script that prints numbers from 1 to 30."
response = rag_chain.invoke({"input": query})
# Print the response
print("Example #1")
print("--------------------")
print(f"Query: {query}")
print(f"Answer: {response['answer']}")
print("--------------------")
print("\n")
# Query #2 (check that context only is used)
query = "I want to write a Go script that prints numbers from 1 to 30."
response = rag_chain.invoke({"input": query})
# Print the response
print("Example #2")
print("--------------------")
print(f"Query: {query}")
print(f"Answer: {response['answer']}")
print("--------------------")
Output:
Example #1
--------------------
Query: I want to write a Python script that prints numbers from 1 to 30.
Answer: You can use the `range()` function in Python to generate a sequence of numbers and print them.
Here's how you can do it:
```
for i in range(1, 31):
    print(i)
```
This will print numbers from 1 to 30. The `range()` function generates numbers starting from 0 by default, so we start at 1 and end at 30 (which is exclusive).
--------------------
Example #2
--------------------
Query: I want to write a Go script that prints numbers from 1 to 30.
Answer: I don't know based on the provided context. The given text is about Python programming and does not provide any information about writing a Go script. If you need help with a specific task, I'll be happy to assist you in another way.
--------------------
Explanation:
- LLM Initialization: The ChatOllama model is initialized using the llama3.1 model.
- Prompt Setup: A system prompt is defined that instructs the assistant to use only the provided context.
- Chain Creation: The question-answer chain and the RAG chain are created.
- Testing: The RAG chain is tested with two queries:
  - The first query is about writing a Python script, which is covered in the context.
  - The second query is about writing a Go script, which is not in the context, so the assistant appropriately replies that it doesn't know based on the provided context.
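Conceptually, a "stuff" documents chain places the text of all retrieved documents into the {context} slot of the prompt. The sketch below imitates that idea in plain Python so the mechanism is visible. It is a rough approximation rather than LangChain's code; stuff_context, render_system_prompt, the blank-line separator, and the shortened template are assumptions made for illustration:

```python
def stuff_context(page_contents, separator="\n\n"):
    """Join the retrieved document texts into one context string."""
    return separator.join(page_contents)


def render_system_prompt(template, page_contents):
    """Fill the {context} placeholder with the stuffed documents."""
    return template.format(context=stuff_context(page_contents))


# Shortened stand-in for the system prompt used in this post
template = (
    "Use only the information provided in the context.\n\n"
    "Context: \n"
    "{context}"
)

# Hypothetical texts standing in for the retrieved original documents
retrieved_texts = [
    "The break statement exits the innermost loop.",
    "The range() function generates arithmetic progressions.",
]

rendered = render_system_prompt(template, retrieved_texts)
print(rendered)
```

Because everything retrieved is placed into a single prompt, the number of documents (k) and their length need to fit within the model's context window, which is one reason num_ctx is set explicitly when the model is initialized.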
Conclusion
By leveraging Weaviate's cross-references, we can efficiently retrieve original documents by querying their summaries. Integrating this mechanism with LangChain allows us to build powerful RAG pipelines that provide accurate and context-specific responses.
Feel free to explore and modify the code to suit your own datasets and use cases!