Ever been reading through a PDF document and thought, "Hmm, if only there were a way I could quickly extract the relevant information"? That would save quite a lot of time. It's been a few years since large language models (LLMs) were introduced, revolutionizing the way we interact with text data.
LLMs are trained on vast amounts of data from the internet and other text sources, making them highly effective at many general-purpose tasks. However, in certain cases, you might want to train or augment an LLM with your own data to make the responses more relevant to your needs. In this article, I’ll show you how to create an app that uses Retrieval-Augmented Generation (RAG) to answer questions specific to particular documents or web pages. But first...
What is RAG?
RAG stands for Retrieval-Augmented Generation. It is a type of natural language processing framework that combines the benefits of retrieval-based and generative models.
These models:
- Retrieve relevant info from a database or knowledge graph
- Use this information to generate a more accurate response
Pretty useful, especially when you're handling niche topics that the LLM may not know about. I took quite a lot of inspiration from Google's NotebookLM, but this is not going to be anywhere close to what NotebookLM does; it's just a gentle introduction to help you get a bit of an understanding of what's happening and maybe appreciate RAG.
With that, let's proceed with the...
Setup
For this application, we're going to be making use of the following packages:
- langchain - to link us to the OpenAI large language models
- chroma - the vector store we're going to use to store our information
- streamlit - for the interface, because of how easy it is to set up
Before we continue, let me give you a sneak peek at the final product:
You can find the code for this application on my GitHub.
Install the packages as follows:
pip install streamlit
pip install langchain langchain-community langchain-openai langchain-chroma
We're also going to need BeautifulSoup
so install it as follows:
pip install bs4
The application has two databases: a vector database to store the documents and a relational database to store the chats. sqlite3 is simple to set up, hence we are going to be using it in this tutorial. Feel free to make use of any other database.
And with that, we're ready to begin.
Database design
Our relational database is going to have 3 tables:
- The chat table - which will store the chat names
- The sources table - which will store the different sources we have loaded into the vector store
- The messages table - which will store the messages between the human and the AI
Create a file named create_relational_db.py
and add the following code:
import sqlite3
# Connect to SQLite database (or create it if it doesn't exist)
conn = sqlite3.connect("doc_sage.sqlite")
cursor = conn.cursor()
# Create 'chat' table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS chat (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
# Create 'sources' table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS sources (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
source_text TEXT,
type TEXT DEFAULT 'document',
chat_id INTEGER,
FOREIGN KEY (chat_id) REFERENCES chat(id)
)
"""
)
# Create 'messages' table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
chat_id INTEGER NOT NULL,
sender TEXT NOT NULL,
content TEXT NOT NULL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(chat_id) REFERENCES chat(id)
);
"""
)
# Commit the transaction
conn.commit()
# Close the connection
conn.close()
print("Tables created successfully.")
The sources table stores the name of the document or the link address. The type field indicates whether the source is a document (file) or a webpage. It can only have two values, document or link, with document as the default.
The messages table stores the messages between the user and the LLM. It is linked to the chat table by a foreign key on chat_id. The sender is either ai or user, which is a way to keep track of who sent which message. The content is the actual content of the message from either the user or the AI.
Create the database and tables by running:
python create_relational_db.py
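If you want to confirm that the tables were created, a quick check against SQLite's sqlite_master table works. This is a throwaway snippet, not part of the app:
import sqlite3
conn = sqlite3.connect("doc_sage.sqlite")
cursor = conn.cursor()
# List all tables in the database
cursor.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
print([row[0] for row in cursor.fetchall()])
conn.close()
You should see chat, sources, and messages in the output (along with sqlite_sequence, which SQLite creates for AUTOINCREMENT columns).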
Now let us move on to the functions that are going to be operating on the database.
Create another file named db.py
and add the following code:
import sqlite3
# Connect to SQLite database
def connect_db():
return sqlite3.connect("doc_sage.sqlite")
# CRUD Operations for 'chat' table
def create_chat(title):
conn = connect_db()
cursor = conn.cursor()
cursor.execute("INSERT INTO chat (title) VALUES (?)", (title,))
chat_id = cursor.lastrowid
conn.commit()
conn.close()
return chat_id
def list_chats():
conn = connect_db()
cursor = conn.cursor()
cursor.execute("SELECT * FROM chat ORDER BY created_at DESC")
chats = cursor.fetchall()
conn.close()
return chats
def read_chat(chat_id):
conn = connect_db()
cursor = conn.cursor()
cursor.execute("SELECT * FROM chat WHERE id = ?", (chat_id,))
result = cursor.fetchone()
conn.close()
return result
def update_chat(chat_id, new_title):
conn = connect_db()
cursor = conn.cursor()
cursor.execute(
"UPDATE chat SET title = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
(new_title, chat_id),
)
conn.commit()
conn.close()
def delete_chat(chat_id):
conn = connect_db()
cursor = conn.cursor()
cursor.execute("DELETE FROM chat WHERE id = ?", (chat_id,))
conn.commit()
conn.close()
def create_source(name, source_text, chat_id, source_type="document"):
conn = connect_db()
cursor = conn.cursor()
cursor.execute(
"INSERT INTO sources (name, source_text, chat_id, type) VALUES (?, ?, ?, ?)",
(name, source_text, chat_id, source_type),
)
conn.commit()
conn.close()
def read_source(source_id):
conn = connect_db()
cursor = conn.cursor()
cursor.execute("SELECT * FROM sources WHERE id = ?", (source_id,))
result = cursor.fetchone()
conn.close()
return result
def update_source(source_id, new_name, new_source_text):
conn = connect_db()
cursor = conn.cursor()
cursor.execute(
"UPDATE sources SET name = ?, source_text = ? WHERE id = ?",
(new_name, new_source_text, source_id),
)
conn.commit()
conn.close()
def list_sources(chat_id, source_type=None):
conn = connect_db()
cursor = conn.cursor()
if source_type:
cursor.execute(
"SELECT * FROM sources WHERE chat_id = ? AND type = ?",
(chat_id, source_type),
)
else:
cursor.execute("SELECT * FROM sources WHERE chat_id = ?", (chat_id,))
sources = cursor.fetchall()
conn.close()
return sources
def delete_source(source_id):
conn = connect_db()
cursor = conn.cursor()
cursor.execute("DELETE FROM sources WHERE id = ?", (source_id,))
conn.commit()
conn.close()
# CRUD Operations for 'messages' table
def create_message(chat_id, sender, content):
conn = connect_db()
cursor = conn.cursor()
cursor.execute(
"INSERT INTO messages (chat_id, sender, content) VALUES (?, ?, ?)",
(chat_id, sender, content),
)
conn.commit()
conn.close()
def get_messages(chat_id):
conn = connect_db()
cursor = conn.cursor()
cursor.execute(
"SELECT sender, content FROM messages WHERE chat_id = ? ORDER BY timestamp ASC",
(chat_id,),
)
messages = cursor.fetchall()
conn.close()
return messages
def delete_messages(chat_id):
conn = connect_db()
cursor = conn.cursor()
cursor.execute("DELETE FROM messages WHERE chat_id = ?", (chat_id,))
conn.commit()
conn.close()
The above code is responsible for all the operations on the sqlite3
database.
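As a quick sanity check, you can exercise these helpers from a Python shell. A minimal example, assuming the database file already exists:
from db import create_chat, create_message, get_messages
chat_id = create_chat("My first chat")
create_message(chat_id, "user", "Hello there!")
create_message(chat_id, "ai", "Hi! How can I help?")
# Prints [('user', 'Hello there!'), ('ai', 'Hi! How can I help?')]
print(get_messages(chat_id))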
RAG Functions
In this section, we are going to create the functions that operate on the vector store, from loading documents to retrieving them and generating the responses. This is where most of the magic happens. Get some coffee because this section is rather long and a little more complicated🙂. But worry not, I will explain as simply as possible.
Create a file named vector_functions.py
and add the following lines:
import os
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_text_splitters import CharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.document_loaders import (
TextLoader,
CSVLoader,
PyPDFLoader,
Docx2txtLoader,
UnstructuredHTMLLoader,
UnstructuredMarkdownLoader,
)
import environ
env = environ.Env()
# reading .env file
environ.Env.read_env()
llm = ChatOpenAI(
model="gpt-4o-mini",
api_key=env("OPENAI_API_KEY"),
)
embeddings = OpenAIEmbeddings(
api_key=env("OPENAI_API_KEY"),
)
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
In this example, I am making use of the python-environ module to handle the API keys, so make sure you have it installed with:
pip install python-environ
Create a .env file and add your OpenAI key, e.g.:
OPENAI_API_KEY=Your-API-Key
Initialize the large language model with the ChatOpenAI class. In this example we are going to be making use of gpt-4o-mini. You can make use of other models, like Anthropic's Claude. Make sure to refer to the documentation for more information.
OpenAIEmbeddings
is a class that generates embeddings — a way to convert text into a numeric format so that AI models can process it more easily.
CharacterTextSplitter
breaks up long text into smaller chunks to make it easier for the AI model to handle. chunk_size=1000
means each text chunk will contain up to 1000 characters. chunk_overlap=0
means there’s no overlap between chunks; each chunk is independent and contains a unique section of the text. This approach is often useful when working with long articles, documents, or books that need to be processed in parts.
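To get a feel for what the splitter does, here is a small standalone example (not part of vector_functions.py) that splits one long Document into chunks:
from langchain_core.documents import Document
from langchain_text_splitters import CharacterTextSplitter
splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
long_doc = Document(page_content="A paragraph of sample text.\n\n" * 200)
chunks = splitter.split_documents([long_doc])
# Several chunks, each containing up to roughly 1000 characters
print(len(chunks), len(chunks[0].page_content))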
Next we create a function that is going to be responsible for loading different document types:
def load_document(file_path: str) -> list[Document]:
"""
Load a document from a file path.
Supports .txt, .pdf, .docx, .csv, .html, and .md files.
Args:
file_path (str): Path to the document file.
Returns:
list[Document]: A list of Document objects.
Raises:
ValueError: If the file type is not supported.
"""
_, file_extension = os.path.splitext(file_path)
if file_extension == ".txt":
loader = TextLoader(file_path)
elif file_extension == ".pdf":
loader = PyPDFLoader(file_path)
elif file_extension == ".docx":
loader = Docx2txtLoader(file_path)
elif file_extension == ".csv":
loader = CSVLoader(file_path)
elif file_extension == ".html":
loader = UnstructuredHTMLLoader(file_path)
elif file_extension == ".md":
loader = UnstructuredMarkdownLoader(file_path)
else:
raise ValueError(f"Unsupported file type: {file_extension}")
return loader.load()
The load_document
function returns a list of Document
objects, which are later split into texts by a text splitter and then saved in the vector store.
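For example, loading a PDF would look like this (assuming a file named report.pdf exists and the loader dependencies below are installed):
docs = load_document("report.pdf")
print(len(docs))                   # PyPDFLoader returns one Document per page
print(docs[0].page_content[:200])  # first 200 characters of the first page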
Install the following packages that are required by some of the document loaders:
pip install pypdf unstructured docx2txt Markdown
Next we add the following code to create and load collections in our vector store:
# vector_functions.py
def create_collection(collection_name, documents):
"""
Create a new Chroma collection from the given documents.
Args:
collection_name (str): The name of the collection to create.
documents (list): A list of documents to add to the collection.
Returns:
None
This function splits the documents into texts, creates a new Chroma collection,
and persists it to disk.
"""
# Split the documents into smaller text chunks
texts = text_splitter.split_documents(documents)
persist_directory = "./persist"
# Create a new Chroma collection from the text chunks
try:
vectordb = Chroma.from_documents(
documents=texts,
embedding=embeddings,
persist_directory=persist_directory,
collection_name=collection_name,
)
except Exception as e:
print(f"Error creating collection: {e}")
return None
return vectordb
def load_collection(collection_name):
"""
Load an existing Chroma collection.
Args:
collection_name (str): The name of the collection to load.
Returns:
Chroma: The loaded Chroma collection.
This function loads a previously created Chroma collection from disk.
"""
persist_directory = "./persist"
# Load the Chroma collection from the specified directory
vectordb = Chroma(
persist_directory=persist_directory,
embedding_function=embeddings,
collection_name=collection_name,
)
return vectordb
def add_documents_to_collection(vectordb, documents):
"""
Add documents to the vector database collection.
Args:
vectordb: The vector database object to add documents to.
documents: A list of documents to be added to the collection.
This function splits the documents into smaller chunks, adds them to the
vector database, and persists the changes.
"""
# Split the documents into smaller text chunks
texts = text_splitter.split_documents(documents)
# Add the text chunks to the vector database
vectordb.add_documents(texts)
return vectordb
We have created three functions: create_collection, load_collection, and add_documents_to_collection.
The create_collection
function receives the collection_name
and documents
as arguments. The documents
argument is the list of Document
objects loaded using the load_document
function. After loading the documents, we split the texts into chunks of up to 1000 characters each (this, of course, can be changed). We initialize a persist_directory
which is where the collections are going to be stored. Finally, we create a vector store from the given documents.
The load_collection
function loads any collection that was created and returns it.
The add_documents_to_collection
lets us add new documents to a collection. As we are using the app, we start with an empty collection. When we upload a document, it is saved to the database using this function. The documents are first loaded using the load_document
function. This function returns the list of Document
objects. The system then loads a collection using the load_collection
function and the documents are finally added to the collection using the add_documents_to_collection
function.
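Putting the three together, a typical lifecycle looks something like this (a sketch with hypothetical file names):
# First upload: create the collection from the initial documents
docs = load_document("intro.pdf")
vectordb = create_collection("chat_1", docs)
# Later uploads: load the existing collection and append to it
more_docs = load_document("appendix.docx")
vectordb = load_collection("chat_1")
vectordb = add_documents_to_collection(vectordb, more_docs)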
Next, a retriever is required: a component that queries the vector store and returns the most relevant results for a given search:
# vector_functions.py
def load_retriever(collection_name, score_threshold: float = 0.6):
"""
Create a retriever from a Chroma collection with a similarity score threshold.
Args:
collection_name (str): The name of the collection to use.
score_threshold (float): The minimum similarity score threshold for retrieving documents.
Documents with scores below this threshold will be filtered out.
Defaults to 0.6.
Returns:
Retriever: A retriever object that can be used to query the collection with similarity
score filtering.
This function loads a Chroma collection and creates a retriever from it that will only
return documents meeting the specified similarity score threshold.
"""
# Load the Chroma collection
vectordb = load_collection(collection_name)
# Create a retriever from the collection with specified search parameters
retriever = vectordb.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"score_threshold": score_threshold},
)
return retriever
The load_retriever
function accepts two arguments: collection_name
and score_threshold
.
Some of the methods we can use to perform searches in the vector store are:
- Similarity Search: finds the items that are "closest" to your search query. This closeness is measured by metrics such as cosine similarity or Euclidean distance. Usually you have to specify that you want the k closest items.
- Similarity Score Threshold: returns results with a similarity score above a given threshold. This is useful when you want the most relevant matches without unrelated results.
- Maximal Marginal Relevance (MMR): focuses on balancing relevance against diversity, so the most relevant items appear at the top while near-duplicate results are pushed lower.
We are making use of the Similarity Score Threshold method; however, you can also try the other methods to see how well they perform, as sketched below.
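For reference, switching strategies is just a matter of changing the as_retriever arguments. A hedged sketch of the alternatives (see the LangChain docs for the full set of options):
vectordb = load_collection("chat_1")
# Plain similarity search: return the k closest chunks
retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": 4})
# Maximal Marginal Relevance: balance relevance against diversity
retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 4})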
And now we add the last function before moving on to the interface:
# vector_functions.py
def generate_answer_from_context(retriever, question: str):
"""
Ask a question and get an answer based on the provided context.
Args:
retriever: A retriever object to fetch relevant context.
question (str): The question to be answered.
Returns:
str: The answer to the question based on the retrieved context.
"""
# Define the message template for the prompt
message = """
Answer this question using the provided context only.
{question}
Context:
{context}
"""
# Create a chat prompt template from the message
prompt = ChatPromptTemplate.from_messages([("human", message)])
# Create a RAG (Retrieval-Augmented Generation) chain
# This chain retrieves context, passes through the question,
# formats the prompt, and generates an answer using the language model
rag_chain = {"context": retriever, "question": RunnablePassthrough()} | prompt | llm
# Invoke the RAG chain with the question and return the generated content
return rag_chain.invoke(question).content
The purpose of this function is to return an answer that the LLM generates using data retrieved by the retriever. It accepts two arguments:
- retriever - to find relevant text
- question - a string containing the user's query
We define a message
template that structures how we want the question and context to look before we send them to the model. {question}
and {context}
are placeholders that are going to be filled later on with actual values.
ChatPromptTemplate.from_messages([("human", message)]) creates a structured prompt using our template. This is useful as it guides the model's response, helping it generate meaningful output. You can read the documentation for more information.
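To see what the template produces, you can fill in the placeholders yourself. A purely illustrative snippet, assuming prompt is the template defined above:
# Substitute the placeholders with concrete values
filled = prompt.invoke({"question": "What is RAG?", "context": "RAG stands for Retrieval-Augmented Generation..."})
print(filled.to_messages())  # a single human message with both values filled in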
Next, we create the RAG chain:
rag_chain = {"context": retriever, "question": RunnablePassthrough()} | prompt | llm
Let's break down what's happening:
- {"context": retriever, "question": RunnablePassthrough()} - This dictionary stores the input components for the pipeline. We set "context" to retriever, which is going to return the relevant text as context. "question" uses RunnablePassthrough(), meaning it passes the question along the chain without modifying it. You can get more information about how it works from the documentation.
- | prompt - This part pipes the output from the previous dictionary into the prompt object, which formats these inputs into a structured prompt.
- | llm - The final stage sends the prompt to the LLM, which then generates the response.
rag_chain.invoke(question)
runs the entire chain, passing in the question and returning the answer based on the given context.
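With all the pieces in place, asking a question against a collection boils down to a few lines (using a hypothetical collection name):
retriever = load_retriever("chat_1")
answer = generate_answer_from_context(retriever, "What does the document say about pricing?")
print(answer)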
All that's left now is to create the interface and connect it with the different functions...
User Interface
The application has 2 pages:
- The chats home page
- The chat page
Lets start by creating the chats home page. Create a file named chats.py
and add the following code:
import streamlit as st
import os, time, math
import requests
from bs4 import BeautifulSoup
from langchain_core.documents import Document
from db import (
read_chat,
create_chat,
list_chats,
delete_chat,
create_message,
get_messages,
create_source,
list_sources,
delete_source,
)
from vector_functions import (
load_document,
create_collection,
load_retriever,
generate_answer_from_context,
add_documents_to_collection,
load_collection,
)
def chats_home():
st.markdown(
"<h1 style='text-align: center;'>DocSage🧙♂️</h1>", unsafe_allow_html=True
)
with st.container(border=True):
col1, col2 = st.columns([0.8, 0.2])
with col1:
chat_title = st.text_input(
"Chat Title", placeholder="Enter Chat Title", key="chat_title"
)
with col2:
st.markdown("<br>", unsafe_allow_html=True) # Add vertical space
if st.button("Create Chat", type="primary"):
if chat_title:
chat_id = create_chat(chat_title)
st.success(f"Created new chat: {chat_title}")
st.query_params.from_dict({"chat_id": chat_id})
st.rerun()
else:
st.warning("Please enter a chat title")
with st.container(border=True):
st.subheader("Previous Chats")
# get previous chats from db
previous_chats = list_chats()
# Pagination settings
chats_per_page = 5
total_pages = math.ceil(len(previous_chats) / chats_per_page)
# Get current page from session state
if "current_page" not in st.session_state:
st.session_state.current_page = 1
# Calculate start and end indices for the current page
start_idx = (st.session_state.current_page - 1) * chats_per_page
end_idx = start_idx + chats_per_page
# Display chats for the current page
for chat in previous_chats[start_idx:end_idx]:
chat_id, chat_title = chat[0], chat[1]
with st.container(border=True):
col1, col2, col3 = st.columns([0.6, 0.2, 0.2])
with col1:
st.markdown(f"**{chat_title}**")
with col2:
if st.button("📂 Open", key=f"open_{chat_id}"):
st.query_params.from_dict({"chat_id": chat_id})
st.rerun()
with col3:
if st.button("🗑️ Delete", key=f"delete_{chat_id}"):
delete_chat(chat_id)
st.success(f"Deleted chat: {chat_title}")
st.rerun()
# Pagination controls
col1, col2, col3 = st.columns([1, 2, 1])
with col1:
if st.button("Previous") and st.session_state.current_page > 1:
st.session_state.current_page -= 1
st.rerun()
with col2:
st.write(f"Page {st.session_state.current_page} of {total_pages}")
with col3:
if st.button("Next") and st.session_state.current_page < total_pages:
st.session_state.current_page += 1
st.rerun()
def main():
chats_home()
if __name__ == "__main__":
main()
The chats home page consists of a text input that allows us to create chats and a paginated list of the chats we have created:
When a user adds a chat title and clicks the create chat button, the following code is executed:
if st.button("Create Chat", type="primary"):
if chat_title:
chat_id = create_chat(chat_title)
st.success(f"Created new chat: {chat_title}")
st.query_params.from_dict({"chat_id": chat_id})
st.rerun()
else:
st.warning("Please enter a chat title")
If the user text input is not empty, the create_chat
function is called with the chat_title as the argument. This function returns the id of the created chat. A success message is shown and we set the query_params
using the from_dict
method to the chat_id. This is going to allow us to navigate to the chat page for the newly created chat. If the chat_title is empty, a warning is simply displayed.
Now to add the chat page, add the following code to chats.py
:
def stream_response(response):
"""
Stream a response word by word with a delay between each word.
Args:
response (str): The text response to stream
Yields:
str: Individual words from the response with a space appended
Note:
Adds a 50ms delay between each word to create a typing effect
"""
# Split response into words and stream each one
for word in response.split():
# Yield the word with a space and pause briefly
yield word + " "
time.sleep(0.05)
The stream_response function is going to simulate a ChatGPT-style streaming response, writing each word after a 50ms delay rather than spilling out all the text in one go.
Next, add the function for the chat page:
def chat_page(chat_id):
"""
Renders the main chats page where users can:
- Create new chats with titles
- View and manage previous chats
- Navigate through paginated chat history
The page displays a header, chat creation form, and list of existing chats
with options to open each chat.
"""
chat = read_chat(chat_id)
if not chat:
st.error("Chat not found")
return
# Retrieve messages from DB
messages = get_messages(chat_id)
# Display messages
if messages:
for sender, content in messages:
if sender == "user":
with st.chat_message("user"):
st.markdown(content)
elif sender == "ai":
with st.chat_message("assistant"):
st.markdown(content)
else:
st.write("No messages yet. Start the conversation!")
The function first retrieves the chat using the read_chat function. If the chat is not found, it displays an error message and returns. If the chat exists, it uses the get_messages function to retrieve the messages exchanged between the user and the LLM, and these messages are added to the page. If no messages are found, a short message urging the user to start a conversation with the bot is displayed.
Now add the following code to the chat_page
function:
def chat_page(chat_id):
# Rest of the code ...
# Add a text input for new messages
prompt = st.chat_input("Type your message here...")
if prompt:
# Save user message
create_message(chat_id, "user", prompt)
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Get AI response
# Load retriever for the chat context
collection_name = f"chat_{chat_id}"
if os.path.exists(f"./persist"):
retriever = load_retriever(collection_name=collection_name)
else:
retriever = None
# Ask question using the retriever
response = (
generate_answer_from_context(retriever, prompt)
if retriever
else "I need some context to answer that question."
)
# Save AI response
create_message(chat_id, "ai", response)
# Display AI response
with st.chat_message("assistant"):
st.write_stream(stream_response(response))
st.rerun()
We use the streamlit chat_input
to get the user's prompt. The prompt is then saved to the messages table using the create_message
function and the message is displayed on the page using st.chat_message
.
The next section gets the AI response as follows:
- It creates the collection name using the chat_id.
- It checks if the persist directory, which is where the collections are stored, exists.
- If the persist directory exists, a retriever object is created using the load_retriever function; if it does not exist, the retriever object is set to None.
- If the retriever object is not None, the generate_answer_from_context function is used to get a response from the model. Remember, the retriever first retrieves relevant texts and feeds them to the LLM, which uses these texts to come up with an appropriate response.
The LLM's response is then saved to the database using the create_message function and written to the page using st.write_stream and the stream_response function. Finally, the page is reloaded with the rerun method to display the text correctly.
Sidebar
In this section, we look at creating the sidebar as well as saving documents and links to the vector store. Add the following code to the chat_page
function:
def chat_page(chat_id):
# rest of the code ...
# Sidebar
with st.sidebar:
# Button to return to the main chats page
if st.button("Back to Chats"):
st.query_params.clear()
st.rerun()
# Chat name
st.subheader(f"{chat[1]}")
# Documents Section
st.subheader("📑 Documents")
# Get all "document" type sources
documents = list_sources(chat_id, source_type="document")
if documents:
# list the documents
for doc in documents:
doc_id = doc[0]
doc_name = doc[1]
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.write(doc_name)
with col2:
if st.button("❌", key=f"delete_doc_{doc_id}"):
delete_source(doc_id)
st.success(f"Deleted document: {doc_name}")
st.rerun()
else:
st.write("No documents uploaded.")
uploaded_file = st.file_uploader("Upload Document", key="file_uploader")
if uploaded_file:
# Save document content to database
with st.spinner("Processing document..."):
temp_dir = "temp_files"
os.makedirs(temp_dir, exist_ok=True)
temp_file_path = os.path.join(temp_dir, uploaded_file.name)
with open(temp_file_path, "wb") as f:
f.write(uploaded_file.getbuffer())
# Load document
document = load_document(temp_file_path)
# Create or update collection for this chat
collection_name = f"chat_{chat_id}"
if not os.path.exists(f"./persist/{collection_name}"):
vectordb = create_collection(collection_name, document)
else:
vectordb = load_collection(collection_name)
vectordb = add_documents_to_collection(vectordb, document)
# Save source to database
create_source(uploaded_file.name, "", chat_id, source_type="document")
# Remove temp file
os.remove(temp_file_path)
del st.session_state["file_uploader"]
st.rerun()
The above code lists any documents that were previously added to the vector store and also allows the user to delete them.
When a user uploads a document, a spinner is shown while the document is added to the vector database. A temporary file directory is created if it does not already exist, and the file is saved to that directory. The document is then loaded using the load_document function, which returns the file as a list of Document objects. A collection is created or loaded, depending on whether it already exists. If it does, the documents are added to that collection using the add_documents_to_collection function. The create_source function is used to save the information in the relational database, and finally the file is deleted from the temporary directory.
Now let us complete the function by adding link processing. In the chat_page
function, add the following code:
def chat_page(chat_id):
# rest of the code
with st.sidebar:
# rest of the code...
# Links Section
st.subheader("🔗 Links")
# Display list of links
links = list_sources(chat_id, source_type="link")
if links:
for link in links:
link_id = link[0]
link_url = link[1]
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.markdown(f"[{link_url}]({link_url})")
with col2:
if st.button("❌ ", key=f"delete_link_{link_id}"):
delete_source(link_id)
st.success(f"Deleted link: {link_url}")
st.rerun()
else:
st.write("No links added.")
# Add new link
new_link = st.text_input("Add a link", key="new_link")
if st.button("Add Link", key="add_link_btn"):
if new_link:
with st.spinner("Processing link..."):
# Fetch content from the link
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.111 Safari/537.36"
}
response = requests.get(new_link, headers=headers)
soup = BeautifulSoup(response.text, "html.parser")
# Check if the content was successfully retrieved
if response.status_code == 200 and soup.text.strip():
link_content = soup.get_text(separator="\n")
else:
st.toast(
"Unable to retrieve content from the link. It may be empty or inaccessible.",
icon="🚨",
)
return
# Save link content to vector store
documents = [
Document(
page_content=link_content, metadata={"source": new_link}
)
]
collection_name = f"chat_{chat_id}"
if not os.path.exists(f"./persist"):
create_collection(collection_name, documents)
else:
vectordb = load_collection(collection_name)
add_documents_to_collection(vectordb, documents)
# Save link to database
create_source(new_link, "", chat_id, source_type="link")
st.success(f"Added link: {new_link}")
del st.session_state["add_link_btn"]
st.rerun()
except Exception as e:
st.toast(
f"Failed to fetch content from the link: {e}", icon="⚠️"
)
else:
st.toast("Please enter a link", icon="❗")
The links are listed just like the documents. When a link is pasted into the input and the user clicks the Add Link button, a spinner indicates that the request is being processed. Using the requests library, the code attempts to fetch the content of the webpage. The User-Agent header is used to mimic a browser request, since most websites will otherwise deny access to their content.
BeautifulSoup
is then used to extract the content from the response. The code then checks if the response was successful (status code 200) and if the content is not empty. If not, a toast message is shown indicating that the link may be empty or inaccessible. If the content is valid, it is then saved as a Document
object in a list with metadata indicating the source URL and just like before, it is added to the collection.
The full chat_page
function looks as follows:
def chat_page(chat_id):
"""
Display the chat page for a specific chat ID.
This function handles displaying and managing an individual chat conversation, including:
- Showing the chat history
- Allowing users to send new messages
- Streaming AI responses
- Managing chat context through a vector store retriever
Args:
chat_id (int): The ID of the chat to display
Returns:
None
"""
chat = read_chat(chat_id)
if not chat:
st.error("Chat not found")
return
# Retrieve messages from DB
messages = get_messages(chat_id)
# Display messages
if messages:
for sender, content in messages:
if sender == "user":
with st.chat_message("user"):
st.markdown(content)
elif sender == "ai":
with st.chat_message("assistant"):
st.markdown(content)
else:
st.write("No messages yet. Start the conversation!")
# Add a text input for new messages
prompt = st.chat_input("Type your message here...")
if prompt:
# Save user message
create_message(chat_id, "user", prompt)
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Get AI response
# Load retriever for the chat context
collection_name = f"chat_{chat_id}"
if os.path.exists(f"./persist"):
retriever = load_retriever(collection_name=collection_name)
else:
retriever = None
# Ask question using the retriever
response = (
generate_answer_from_context(retriever, prompt)
if retriever
else "I need some context to answer that question."
)
# Save AI response
create_message(chat_id, "ai", response)
# Display AI response
with st.chat_message("assistant"):
st.write_stream(stream_response(response))
st.rerun()
# Sidebar for context
with st.sidebar:
# Button to return to the main chats page
if st.button("Back to Chats"):
st.query_params.clear()
st.rerun()
st.subheader(f"{chat[1]}")
# Documents Section
st.subheader("📑 Documents")
# Display list of documents
documents = list_sources(chat_id, source_type="document")
if documents:
for doc in documents:
doc_id = doc[0]
doc_name = doc[1]
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.write(doc_name)
with col2:
if st.button("❌", key=f"delete_doc_{doc_id}"):
delete_source(doc_id)
st.success(f"Deleted document: {doc_name}")
st.rerun()
else:
st.write("No documents uploaded.")
uploaded_file = st.file_uploader("Upload Document", key="file_uploader")
if uploaded_file:
# Save document content to database
with st.spinner("Processing document..."):
temp_dir = "temp_files"
os.makedirs(temp_dir, exist_ok=True)
temp_file_path = os.path.join(temp_dir, uploaded_file.name)
with open(temp_file_path, "wb") as f:
f.write(uploaded_file.getbuffer())
# Load document
document = load_document(temp_file_path)
# Create or update collection for this chat
collection_name = f"chat_{chat_id}"
if not os.path.exists(f"./persist/{collection_name}"):
vectordb = create_collection(collection_name, document)
else:
vectordb = load_collection(collection_name)
vectordb = add_documents_to_collection(vectordb, document)
# Save source to database
create_source(uploaded_file.name, "", chat_id, source_type="document")
# Remove temp file
os.remove(temp_file_path)
del st.session_state["file_uploader"]
st.rerun()
# Links Section
st.subheader("🔗 Links")
# Display list of links
links = list_sources(chat_id, source_type="link")
if links:
for link in links:
link_id = link[0]
link_url = link[1]
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.markdown(f"[{link_url}]({link_url})")
with col2:
if st.button("❌ ", key=f"delete_link_{link_id}"):
delete_source(link_id)
st.success(f"Deleted link: {link_url}")
st.rerun()
else:
st.write("No links added.")
# Add new link
new_link = st.text_input("Add a link", key="new_link")
if st.button("Add Link", key="add_link_btn"):
if new_link:
with st.spinner("Processing link..."):
# Fetch content from the link
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.111 Safari/537.36"
}
response = requests.get(new_link, headers=headers)
soup = BeautifulSoup(response.text, "html.parser")
# Check if the content was successfully retrieved
if response.status_code == 200 and soup.text.strip():
link_content = soup.get_text(separator="\n")
else:
st.toast(
"Unable to retrieve content from the link. It may be empty or inaccessible.",
icon="🚨",
)
return
# Save link content to vector store
documents = [
Document(
page_content=link_content, metadata={"source": new_link}
)
]
collection_name = f"chat_{chat_id}"
if not os.path.exists(f"./persist"):
create_collection(collection_name, documents)
else:
vectordb = load_collection(collection_name)
add_documents_to_collection(vectordb, documents)
# Save link to database
create_source(new_link, "", chat_id, source_type="link")
st.success(f"Added link: {new_link}")
del st.session_state["add_link_btn"]
st.rerun()
except Exception as e:
st.toast(
f"Failed to fetch content from the link: {e}", icon="⚠️"
)
else:
st.toast("Please enter a link", icon="❗")
Now modify the main
function as follows:
def main():
"""
Main entry point for the chat application.
Handles routing between the chats list page and individual chat pages:
- If a chat_id is present in URL parameters, displays that specific chat
- Otherwise shows the main chats listing page
The function uses Streamlit query parameters to maintain state between page loads
and determine which view to display.
"""
query_params = st.query_params
if "chat_id" in query_params:
chat_id = query_params["chat_id"]
chat_page(chat_id)
else:
chats_home()
if __name__ == "__main__":
main()
The main function simply checks whether a chat_id is present in the URL parameters. If it is, it calls chat_page to display that specific chat. If not, it shows the main chat listings page.
With that, the application is complete!!!
Run the app with:
streamlit run chats.py
Conclusion
RAG opens up possibilities for building intelligent, context-aware systems. Whether you're looking to build a knowledge-based chatbot, a document search tool, or any application requiring contextual responses, RAG is an ideal approach.
I hope you have found this article helpful!