DEV Community

Oleh Halytskyi


Retrieving Original Documents via Summaries with Weaviate and LangChain

In large language model (LLM) applications built on retrieval-augmented generation (RAG), the retrieval step largely determines both efficiency and answer quality. In a previous blog post, I discussed how to optimize RAG context for technical documents by chunking and summarizing them with Chroma VectorDB. Specifically, that post demonstrated how to query the summaries while retrieving the original documents.

In this post, we'll explore how to achieve a similar result with Weaviate, integrated with LangChain. We'll use Weaviate's cross-references, which link one data object to another, to connect each summary to its original document, so that querying the summaries lets us efficiently retrieve the originals.

This tutorial walks through the process step by step. If you prefer to explore the code and outputs directly, you can access the Jupyter notebook here: Jupyter Notebook.

Preparing the Environment

First, set up a Conda environment to manage dependencies and keep the project isolated.

# Create a new environment called 'rag-env'
conda create -n rag-env python=3.12

# Activate the environment
conda activate rag-env

# Install necessary packages
pip install weaviate-client==4.9.0 \
    langchain==0.3.5 \
    langchain-core==0.3.13 \
    langchain-ollama==0.2.0

Note: You will also need a running Weaviate instance and an Ollama server with the mxbai-embed-large and llama3.1 models available.

Initializing the Weaviate Client

Let's begin by initializing the Weaviate client with authentication.

import getpass
import weaviate
from weaviate.classes.init import Auth

# Prompt for the Weaviate API key
WEAVIATE_API_KEY = getpass.getpass()

# Initialize the Weaviate client with authentication
weaviate_client = weaviate.connect_to_local(
    auth_credentials=Auth.api_key(WEAVIATE_API_KEY)
)

# Check if the client is ready
print("Client is Ready?", weaviate_client.is_ready())

Importing Data with Cross-References

Next, we'll load the original and summarized documents, create Weaviate collections, and insert the data with cross-references.

Note: The files chunked_docs.json and summarized_docs.json are taken from my previous blog post. They were created by adding the following code to the Summarization Based on Headers and Chunk Text section:

# Save the chunked and summarized documents to JSON files
import json
chunked_docs_json = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in chunked_docs]
with open('files/generated/chunked_docs.json', 'w') as f:
    json.dump(chunked_docs_json, f, indent=4)

summarized_docs_json = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in summarized_docs]
with open('files/generated/summarized_docs.json', 'w') as f:
    json.dump(summarized_docs_json, f, indent=4)

In the summarized_docs.json file, metadata.id was changed to metadata.doc_id to avoid conflicts with Weaviate's id field.
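The rename described above can also be done in code instead of editing the JSON by hand. A minimal sketch, assuming each entry has the `{'page_content': ..., 'metadata': {...}}` shape produced by the export snippet; the helper name `rename_metadata_key` and the sample records are hypothetical:

```python
import copy


def rename_metadata_key(docs, old="id", new="doc_id"):
    """Return copies of docs with metadata[old] renamed to metadata[new].

    Hypothetical helper: avoids clashing with Weaviate's own `id` field.
    """
    renamed = []
    for doc in docs:
        doc = copy.deepcopy(doc)  # leave the caller's data untouched
        metadata = doc.get("metadata", {})
        if old in metadata:
            metadata[new] = metadata.pop(old)
        renamed.append(doc)
    return renamed


# Toy records standing in for summarized_docs.json (not real data)
summaries = [
    {"page_content": "Summary of a section", "metadata": {"id": "abc-123"}},
    {"page_content": "Another summary", "metadata": {"source": "x.md"}},
]
fixed = rename_metadata_key(summaries)
print(fixed[0]["metadata"])  # {'doc_id': 'abc-123'}
```

Applied to `summarized_docs_json` before `json.dump`, this would make the manual edit unnecessary; the values themselves are unchanged, so the `summary_id` links in the chunked documents still match.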

Now, let's proceed to import the data:

import json
from langchain_ollama import OllamaEmbeddings

# Load the chunked and summarized documents (the export snippet above wrote them to files/generated/; adjust the paths to where you keep them)
with open("files/chunked_docs.json", "r") as f:
    chunked_docs = json.load(f)

with open("files/summarized_docs.json", "r") as f:
    summarized_docs = json.load(f)

# Define the collection names
collections = {
    "original": "OriginalDocuments",
    "summary": "SummarizedDocuments",
}

# Delete collections if they already exist
for collection in collections.values():
    if collection in weaviate_client.collections.list_all(simple=True):
        weaviate_client.collections.delete(collection)

# Create the collections
original_collection_db = weaviate_client.collections.create(collections["original"])
summary_collection_db = weaviate_client.collections.create(collections["summary"])

# Initialize the Ollama embedding model
ollama_emb = OllamaEmbeddings(model="mxbai-embed-large")

# Insert the documents into the collections
for summarized_doc in summarized_docs:
    summarized_doc_id = summarized_doc["metadata"]["doc_id"]
    original_doc = next((doc for doc in chunked_docs if doc.get("metadata", {}).get("summary_id") == summarized_doc_id), None)

    if original_doc:
        original_uuid = original_collection_db.data.insert(
            {
                "page_content": original_doc["page_content"],
            },
            vector=ollama_emb.embed_query(original_doc["page_content"]),
        )
        summary_collection_db.data.insert(
            {
                "page_content": summarized_doc["page_content"],
            },
            references={"originalDocument": original_uuid},
            vector=ollama_emb.embed_query(summarized_doc["page_content"]),
        )

# Verify the number of documents in the collections
original_count = len(original_collection_db)
summary_count = len(summary_collection_db)
print(f"Number of documents in the original collection: {original_count}")
print(f"Number of documents in the summary collection: {summary_count}")

Output:

Number of documents in the original collection: 34
Number of documents in the summary collection: 34

Explanation:

  • Loading Data: The chunked and summarized documents are loaded from JSON files.
  • Collections: Two collections, OriginalDocuments and SummarizedDocuments, are defined and created in Weaviate.
  • Embeddings: The embedding model is initialized using Ollama's mxbai-embed-large model.
  • Data Insertion: Each summarized document is matched with its corresponding original document. The original documents are inserted into the OriginalDocuments collection, and the summarized documents are inserted into the SummarizedDocuments collection, including a cross-reference that links back to their respective original documents.
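The matching step above calls `next(...)` inside the loop, which rescans `chunked_docs` for every summary. Below is a sketch of the same pairing with a dictionary index built once, plus in-memory stand-ins for the two collections so the cross-reference idea is visible without a server. The UUID-keyed dicts only imitate what Weaviate stores; they are not its API, and the helper names and sample data are hypothetical:

```python
import uuid


def pair_summaries(summarized_docs, chunked_docs):
    """Pair each summary with the chunk whose metadata.summary_id matches its doc_id."""
    by_summary_id = {}
    for doc in chunked_docs:
        sid = doc.get("metadata", {}).get("summary_id")
        if sid is not None:
            by_summary_id.setdefault(sid, doc)  # keep the first match, like next()
    pairs, unmatched = [], []
    for summary in summarized_docs:
        original = by_summary_id.get(summary["metadata"]["doc_id"])
        if original is None:
            unmatched.append(summary)  # the post's loop skips these silently
        else:
            pairs.append((summary, original))
    return pairs, unmatched


# In-memory stand-ins for the two collections: uuid -> object
originals, summaries = {}, {}


def insert_pair(summary, original):
    original_uuid = uuid.uuid4()
    originals[original_uuid] = {"page_content": original["page_content"]}
    summaries[uuid.uuid4()] = {
        "page_content": summary["page_content"],
        "originalDocument": original_uuid,  # the cross-reference is just a UUID
    }


chunks = [
    {"page_content": "Full text A", "metadata": {"summary_id": "s1"}},
    {"page_content": "Full text B", "metadata": {"summary_id": "s2"}},
]
sums = [
    {"page_content": "Summary A", "metadata": {"doc_id": "s1"}},
    {"page_content": "Summary B", "metadata": {"doc_id": "s2"}},
    {"page_content": "Orphan", "metadata": {"doc_id": "s9"}},
]
pairs, unmatched = pair_summaries(sums, chunks)
for s, o in pairs:
    insert_pair(s, o)

# Following a reference: summary -> UUID -> original
for obj in summaries.values():
    print(obj["page_content"], "->", originals[obj["originalDocument"]]["page_content"])
```

Returning the unmatched summaries makes it easy to check that the two collection counts (34 and 34 above) line up for the right reason.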

Why Create Vectors Manually?

You might be curious why the vectors are created manually with ollama_emb.embed_query instead of letting Weaviate generate them through its Ollama integration (the text2vec-ollama vectorizer module). The reason is that, in this lab setup, Weaviate runs on a Kubernetes cluster while Ollama runs in a separate, isolated environment that Weaviate cannot reach. This approach therefore shows how to generate embeddings outside Weaviate and supply them directly with each inserted object.
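Because Weaviate only receives finished vectors in this setup, the client is responsible for keeping them consistent. A minimal sketch of preparing objects outside Weaviate, assuming `embed` is any callable that maps text to a list of floats; the toy embedder and the helper name `prepare_objects` are hypothetical:

```python
def prepare_objects(texts, embed):
    """Embed each text client-side and check all vectors share one dimension."""
    objects = []
    expected_dim = None
    for text in texts:
        vector = embed(text)
        if expected_dim is None:
            expected_dim = len(vector)
        elif len(vector) != expected_dim:
            # Vectors in one collection should all have the same length
            raise ValueError(f"dimension {len(vector)} != {expected_dim}")
        objects.append({"properties": {"page_content": text}, "vector": vector})
    return objects


def toy_embed(text):
    """Stand-in for a real embedding model: three crude character statistics."""
    return [float(len(text)), float(text.count(" ")), float(sum(map(str.isdigit, text)))]


objs = prepare_objects(["range(5)", "for i in range(1, 31)"], toy_embed)
print(objs[1]["vector"])
```

In the post, `embed` would be `ollama_emb.embed_query`, and each prepared object maps onto `collection.data.insert(properties, vector=...)`. The dimension check catches the case where a different embedding model is accidentally used for some documents.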

Querying with Cross-References

This example queries Weaviate directly to show how cross-references work at query time. We'll define a function that searches the summaries and then follows the cross-references to fetch the original documents, so we can inspect the outputs for both.

from weaviate.classes.query import QueryReference, MetadataQuery

# Define a function to retrieve documents
def retrieve_documents(query, vector, limit=2, score_threshold=0.8):
    response = summary_collection_db.query.hybrid(
        query,
        vector=vector,
        limit=limit,
        return_references=QueryReference(link_on="originalDocument"),
        return_metadata=MetadataQuery(score=True),
    )

    summary_docs = []
    original_docs = []
    for o in response.objects:
        if o.metadata.score is not None and o.metadata.score >= score_threshold:
            summary_doc = {"page_content": o.properties["page_content"]}
            summary_docs.append(summary_doc)
            for ref_obj in o.references["originalDocument"].objects:
                original_doc = {"page_content": ref_obj.properties["page_content"]}
                original_docs.append(original_doc)

    return summary_docs, original_docs

# Define a query
query = "I want to write a Python script that prints numbers from 1 to 30."
vector = ollama_emb.embed_query(query)
summary_docs, original_docs = retrieve_documents(query, vector)

# Print the summarized and original documents
print("Summarized Documents:")
for i, doc in enumerate(summary_docs, start=1):
    print(f"Summarized Document #{i}")
    print("--------------------")
    print(doc["page_content"])
    print("--------------------")
    print()

print("Original Documents:")
for i, doc in enumerate(original_docs, start=1):
    print(f"Original Document #{i}")
    print("--------------------")
    print(doc["page_content"])
    print("--------------------")
    print()

Output:

Summarized Documents:
Summarized Document #1
--------------------
The `break` statement exits the innermost enclosing for or while loop, stopping execution of the loop and continuing with the next statement. This is demonstrated by a nested for loop that prints factors of numbers from 2 to 9, where the break statement stops the loop when a factor is found. The `continue` statement skips the rest of the current iteration in a loop and moves on to the next one, as shown by a for loop that iterates over numbers from 2 to 9, printing even numbers and skipping odd ones.
--------------------

Summarized Document #2
--------------------
The built-in `range()` function generates arithmetic progressions that can be used for iteration over a sequence of numbers. It takes three parameters: start point, end point, and step (default is 1), and returns an iterator that produces the specified range of values. The end point is never part of the generated sequence. To iterate over the indices of a sequence, `range()` can be combined with `len()`, but in most cases it's more convenient to use the `enumerate()` function for this purpose.
--------------------

Original Documents:
Original Document #1
--------------------
The [`break`](../reference/simple_stmts.html#break) statement breaks out of the innermost enclosing
[`for`](../reference/compound_stmts.html#for) or [`while`](../reference/compound_stmts.html#while) loop:

```
>>> for n in range(2, 10):
...     for x in range(2, n):
...         if n % x == 0:
...             print(f"{n} equals {x} * {n//x}")
...             break
...
4 equals 2 * 2
6 equals 2 * 3
8 equals 2 * 4
9 equals 3 * 3
```

The [`continue`](../reference/simple_stmts.html#continue) statement continues with the next
iteration of the loop:

```
>>> for num in range(2, 10):
...     if num % 2 == 0:
...         print(f"Found an even number {num}")
...         continue
...     print(f"Found an odd number {num}")
...
Found an even number 2
Found an odd number 3
Found an even number 4
Found an odd number 5
Found an even number 6
Found an odd number 7
Found an even number 8
Found an odd number 9
```
--------------------

Original Document #2
--------------------
If you do need to iterate over a sequence of numbers, the built\-in function
[`range()`](../library/stdtypes.html#range "range") comes in handy. It generates arithmetic progressions:

```
>>> for i in range(5):
...     print(i)
...
0
1
2
3
4
```

The given end point is never part of the generated sequence; `range(10)` generates
10 values, the legal indices for items of a sequence of length 10\. It
is possible to let the range start at another number, or to specify a different
increment (even negative; sometimes this is called the ‘step’):

```
>>> list(range(5, 10))
[5, 6, 7, 8, 9]

>>> list(range(0, 10, 3))
[0, 3, 6, 9]

>>> list(range(-10, -100, -30))
[-10, -40, -70]
```

To iterate over the indices of a sequence, you can combine [`range()`](../library/stdtypes.html#range "range") and
[`len()`](../library/functions.html#len "len") as follows:

```
>>> a = ['Mary', 'had', 'a', 'little', 'lamb']
>>> for i in range(len(a)):
...     print(i, a[i])
...
0 Mary
1 had
2 a
3 little
4 lamb
```

In most such cases, however, it is convenient to use the [`enumerate()`](../library/functions.html#enumerate "enumerate")
function, see [Looping Techniques](datastructures.html#tut-loopidioms).
--------------------

Explanation:

  • Hybrid Query: The retrieve_documents function runs Weaviate's hybrid search on the SummarizedDocuments collection. Hybrid search runs a vector similarity search and a keyword (BM25F) search and fuses the two result sets into one ranking, so results reflect both semantic similarity (from the embeddings) and keyword relevance, which can improve retrieval accuracy.
  • Cross-References: It retrieves the summarized documents and accesses the original documents via the originalDocument cross-reference.
  • Filtering: Only summaries whose hybrid score is at or above score_threshold (0.8 by default) are kept, together with the original documents they reference.
  • Result: The retrieved summarized and original documents are printed to display the outputs.
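To make the `score_threshold=0.8` cut-off easier to reason about, here is a simplified sketch of relative score fusion, which, as I understand it, is the default way recent Weaviate versions merge the two result sets: each list's scores are min-max normalized to the range 0 to 1, then combined as `alpha * vector + (1 - alpha) * keyword`. The exact server behavior (edge cases, distance-to-score conversion) may differ; this is an illustration, not Weaviate's code, and `alpha=0.75` is chosen only for the example:

```python
def normalize(scores):
    """Min-max normalize a {doc_id: score} dict to [0, 1]."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {doc: 1.0 for doc in scores}  # sketch choice for a flat list
    return {doc: (s - lo) / (hi - lo) for doc, s in scores.items()}


def relative_score_fusion(vector_scores, keyword_scores, alpha=0.75):
    """Fuse two result sets; alpha=1 is pure vector, alpha=0 pure keyword."""
    v, k = normalize(vector_scores), normalize(keyword_scores)
    fused = {
        doc: alpha * v.get(doc, 0.0) + (1 - alpha) * k.get(doc, 0.0)
        for doc in v.keys() | k.keys()
    }
    return sorted(fused.items(), key=lambda item: item[1], reverse=True)


vector_hits = {"a": 0.9, "b": 0.8, "c": 0.5}    # similarity, higher is better
keyword_hits = {"b": 12.0, "c": 4.0, "d": 2.0}  # BM25-style scores
ranked = relative_score_fusion(vector_hits, keyword_hits)
print([(doc, round(score, 4)) for doc, score in ranked])
```

In this toy case only `b`, which ranks well in both searches, clears 0.8, while `a` (top vector hit, no keyword match) lands at 0.75. That is one reason a high threshold can return fewer than `limit` results; `alpha` is also a parameter of `hybrid()` if you want to shift the balance.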

Creating a Custom Retriever

To integrate this retrieval mechanism with LangChain, we'll implement a custom retriever that leverages Weaviate's cross-references.

from typing import List, Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from weaviate.classes.query import QueryReference, MetadataQuery

class VectorDBRetrieverCrossReferences(BaseRetriever):
    """A custom retriever that retrieves documents from a Weaviate vector database."""
    summary_collection_db: Any
    ollama_emb: Any
    k: int = 2
    score_threshold: float = 0.8
    return_source_documents: bool = False

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        """Sync implementation for retriever."""
        vector = self.ollama_emb.embed_query(query)
        response = self.summary_collection_db.query.hybrid(
            query,
            vector=vector,
            limit=self.k,
            return_references=QueryReference(link_on="originalDocument"),
            return_metadata=MetadataQuery(score=True),
        )

        original_docs = []
        for o in response.objects:
            if o.metadata.score is not None and o.metadata.score >= self.score_threshold:
                for ref_obj in o.references["originalDocument"].objects:
                    doc_content = ref_obj.properties["page_content"]
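                    # Only page_content was stored in OriginalDocuments above, so "source" is None unless you also insert it
                    metadata = {"source": ref_obj.properties.get("source")} if self.return_source_documents else {}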
                    original_docs.append(Document(page_content=doc_content, metadata=metadata))

        return original_docs

# Initialize the custom retriever
retriever = VectorDBRetrieverCrossReferences(
    summary_collection_db=summary_collection_db,
    ollama_emb=ollama_emb
)

# Retrieve documents using the custom retriever
query = "I want to write a Python script that prints numbers from 1 to 30."
documents = retriever.invoke(query)

# Print the retrieved documents
for i, doc in enumerate(documents, start=1):
    print(f"Document #{i}")
    print("--------------------")
    print(doc.page_content)
    print("--------------------")
    print()

Output:

Document #1
--------------------
The [`break`](../reference/simple_stmts.html#break) statement breaks out of the innermost enclosing
[`for`](../reference/compound_stmts.html#for) or [`while`](../reference/compound_stmts.html#while) loop:

```
>>> for n in range(2, 10):
...     for x in range(2, n):
...         if n % x == 0:
...             print(f"{n} equals {x} * {n//x}")
...             break
...
4 equals 2 * 2
6 equals 2 * 3
8 equals 2 * 4
9 equals 3 * 3
```

The [`continue`](../reference/simple_stmts.html#continue) statement continues with the next
iteration of the loop:

```
>>> for num in range(2, 10):
...     if num % 2 == 0:
...         print(f"Found an even number {num}")
...         continue
...     print(f"Found an odd number {num}")
...
Found an even number 2
Found an odd number 3
Found an even number 4
Found an odd number 5
Found an even number 6
Found an odd number 7
Found an even number 8
Found an odd number 9
```
--------------------

Document #2
--------------------
If you do need to iterate over a sequence of numbers, the built\-in function
[`range()`](../library/stdtypes.html#range "range") comes in handy. It generates arithmetic progressions:

```
>>> for i in range(5):
...     print(i)
...
0
1
2
3
4
```

The given end point is never part of the generated sequence; `range(10)` generates
10 values, the legal indices for items of a sequence of length 10\. It
is possible to let the range start at another number, or to specify a different
increment (even negative; sometimes this is called the ‘step’):

```
>>> list(range(5, 10))
[5, 6, 7, 8, 9]

>>> list(range(0, 10, 3))
[0, 3, 6, 9]

>>> list(range(-10, -100, -30))
[-10, -40, -70]
```

To iterate over the indices of a sequence, you can combine [`range()`](../library/stdtypes.html#range "range") and
[`len()`](../library/functions.html#len "len") as follows:

```
>>> a = ['Mary', 'had', 'a', 'little', 'lamb']
>>> for i in range(len(a)):
...     print(i, a[i])
...
0 Mary
1 had
2 a
3 little
4 lamb
```

In most such cases, however, it is convenient to use the [`enumerate()`](../library/functions.html#enumerate "enumerate")
function, see [Looping Techniques](datastructures.html#tut-loopidioms).
--------------------

Explanation:

  • Custom Retriever: The VectorDBRetrieverCrossReferences class extends LangChain's BaseRetriever.
  • Method Override: The _get_relevant_documents method performs the hybrid query and retrieves the original documents via cross-references.
  • Integration: The retriever is initialized with the summary_collection_db and the embedding model, making it compatible with LangChain.
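Both `retrieve_documents` and `_get_relevant_documents` share the same filter-and-dereference loop. One option is to pull it into a pure function so it can be unit-tested without a running Weaviate instance. A sketch, assuming only the attribute shape used in the post (`o.metadata.score`, `o.properties`, `o.references["originalDocument"].objects`); the `SimpleNamespace` objects are stand-ins rather than real Weaviate response types, and the function names are hypothetical:

```python
from types import SimpleNamespace


def originals_from_response(objects, score_threshold=0.8, link_on="originalDocument"):
    """Keep hits scoring at or above the threshold and return their referenced originals."""
    contents = []
    for o in objects:
        score = o.metadata.score
        if score is None or score < score_threshold:
            continue
        references = o.references or {}
        if link_on not in references:
            continue  # defensive: skip objects without the expected reference
        for ref_obj in references[link_on].objects:
            contents.append(ref_obj.properties["page_content"])
    return contents


def fake_hit(score, *original_texts):
    """Build a stand-in object mimicking the fields the loop reads."""
    refs = [SimpleNamespace(properties={"page_content": t}) for t in original_texts]
    return SimpleNamespace(
        metadata=SimpleNamespace(score=score),
        properties={"page_content": "summary"},
        references={"originalDocument": SimpleNamespace(objects=refs)},
    )


hits = [fake_hit(0.93, "break/continue section"), fake_hit(0.41, "unrelated"), fake_hit(None, "x")]
print(originals_from_response(hits))  # ['break/continue section']
```

Inside the retriever, `_get_relevant_documents` would then wrap each returned string in `Document(page_content=...)`.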

Example of Simple Retrieval-Augmented Generation (RAG)

Finally, let's build a simple RAG pipeline using the custom retriever and a language model.

from langchain_ollama.chat_models import ChatOllama
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain

# Initialize the ChatOllama model
llm = ChatOllama(model="llama3.1", temperature=0, num_ctx=16384)

# Define the system prompt template
system_prompt = (
    "You are an assistant for answering questions. "
    "Use only the exact information provided in the context, do not include external knowledge or guesses. "
    "If the answer cannot be inferred from the context, reply: 'I don't know based on the provided context.' "
    "Do not provide answers that are not based on the context, including code examples or references to other libraries. "
    "Format your entire response in valid Markdown, including code snippets and links. "
    "Always adhere to these rules strictly.\n\n"
    "Context: \n"
    "{context}"
)

# Define the chat prompt
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

# Create the question-answer chain
question_answer_chain = create_stuff_documents_chain(llm, prompt)

# Create the retrieval-augmented generation (RAG) chain
rag_chain = create_retrieval_chain(retriever, question_answer_chain)

# Query #1
query = "I want to write a Python script that prints numbers from 1 to 30."
response = rag_chain.invoke({"input": query})

# Print the response
print("Example #1")
print("--------------------")
print(f"Query: {query}")
print(f"Answer: {response['answer']}")
print("--------------------")
print("\n")

# Query #2 (check that context only is used)
query = "I want to write a Go script that prints numbers from 1 to 30."
response = rag_chain.invoke({"input": query})

# Print the response
print("Example #2")
print("--------------------")
print(f"Query: {query}")
print(f"Answer: {response['answer']}")
print("--------------------")

Output:

Example #1
--------------------
Query: I want to write a Python script that prints numbers from 1 to 30.
Answer: You can use the `range()` function in Python to generate a sequence of numbers and print them.

Here's how you can do it:

```
for i in range(1, 31):
    print(i)
```

This will print numbers from 1 to 30. The `range()` function generates numbers starting from 0 by default, so we start at 1 and end at 30 (which is exclusive).
--------------------


Example #2
--------------------
Query: I want to write a Go script that prints numbers from 1 to 30.
Answer: I don't know based on the provided context. The given text is about Python programming and does not provide any information about writing a Go script. If you need help with a specific task, I'll be happy to assist you in another way.
--------------------
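One detail in the first answer is worth double-checking: in `range(1, 31)` it is the stop value 31 that is excluded, not 30, so the loop prints 1 through 30 inclusive. A quick check:

```python
numbers = list(range(1, 31))  # start (1) is included, stop (31) is excluded
print(numbers[0], numbers[-1], len(numbers))  # 1 30 30
```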

Explanation:

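  • LLM Initialization: ChatOllama is initialized with the llama3.1 model, a temperature of 0 for more deterministic answers, and num_ctx=16384, which sets the context window size in tokens.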
  • Prompt Setup: A system prompt is defined that instructs the assistant to use only the provided context.
  • Chain Creation: The question-answer chain and the RAG chain are created.
  • Testing: The RAG chain is tested with two queries:
    • The first query is about writing a Python script, which is covered in the context.
    • The second query is about writing a Go script, which is not in the context, so the assistant appropriately replies that it doesn't know based on the provided context.
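`create_stuff_documents_chain` "stuffs" every retrieved document into the `{context}` slot of the prompt. As far as I know, its defaults format each document as its `page_content` and join them with a blank line (`\n\n`). Here is a stdlib-only sketch of that step, using a shortened version of the post's system prompt; the helper name `stuff_context` is hypothetical:

```python
def stuff_context(template, page_contents, separator="\n\n"):
    """Join retrieved texts and substitute them into the {context} placeholder."""
    return template.format(context=separator.join(page_contents))


system_prompt = (
    "You are an assistant for answering questions. "
    "Use only the exact information provided in the context.\n\n"
    "Context: \n"
    "{context}"
)
retrieved = ["The break statement ...", "The built-in range() function ..."]
filled = stuff_context(system_prompt, retrieved)
print(filled)
```

Because everything is concatenated into one system message, the combined length of the original documents has to fit in the model's context window, which is what `num_ctx=16384` sets. This is also where returning full originals instead of summaries costs tokens.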

Conclusion

By leveraging Weaviate's cross-references, we can search compact summaries while returning the full original documents they link to. Wrapping this mechanism in a custom LangChain retriever lets us plug it into RAG pipelines that answer from the complete, context-specific source text.

Feel free to explore and modify the code to suit your own datasets and use cases!
