Introduction
This article is Part 2 of a series. If you haven't read Part 1 yet, please check it out first:
Part 1: Creating an Image Gallery using Streamlit in Snowflake
In Part 1, we created an image gallery app using Streamlit in Snowflake that displayed images stored in the app's default internal stage. In Part 2, we'll build upon that image gallery app to generate captions for each image, making it easier to utilize unstructured data.
Note: This article represents my personal views and not those of Snowflake.
Feature Overview
Goals
- (Done) Display image data using Streamlit in Snowflake
- *Add captions to images using Streamlit in Snowflake
- Generate vector data based on image captions
- Implement image search using Streamlit in Snowflake
*: Scope for Part 2
Features to be Implemented in Part 2
- Manually create and edit image captions
- Automatically generate captions for individual images
- Bulk generate captions for images without captions
Final Image of Part 2
Prerequisites
- Snowflake
  - A Snowflake account
  - Packages to install in the Streamlit in Snowflake app
    - boto3 1.28.64
- AWS
  - An AWS account with access to Amazon Bedrock (we'll be using Claude 3.5 Sonnet in the us-east-1 region for this tutorial)
Steps
(Omitted) Create a Streamlit in Snowflake App and Upload Images
If you haven't done this yet, please follow the steps in Part 1.
(Omitted) Enable access to Amazon Bedrock from the Streamlit in Snowflake app
To automatically add captions to images, there are several options to consider:
- Implement processing using Python image processing libraries
- Create an ML model for image recognition to generate captions for images
- Use existing AI models for images like BLIP-2 to generate captions
- Pass images to a multimodal GenAI to generate captions
- Use a SaaS service for caption generation
Since we've already covered how to connect to Amazon Bedrock in a previous article, we'll use Amazon Bedrock's anthropic.claude-3-5-sonnet as a multimodal GenAI (option 4) to generate the image captions.
For instructions on setting up access to Amazon Bedrock, please refer to the article Calling Amazon Bedrock directly from Streamlit in Snowflake (SiS).
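Before wiring this into the gallery app, it may help to see the core Bedrock call in isolation. The following is a minimal standalone sketch (not part of the app below) that captions a single local image with Claude 3.5 Sonnet via invoke_model; it assumes your AWS credentials are already available to boto3 and uses an example file name, whereas the actual app retrieves credentials from a Snowflake secret:

# Minimal standalone sketch: caption one local image with Claude 3.5 Sonnet on Amazon Bedrock
import base64
import json
import boto3

bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")

# "sample.jpg" is just an example file name
with open("sample.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

request_body = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 300,
    "messages": [{
        "role": "user",
        "content": [
            {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": image_b64}},
            {"type": "text", "text": "Describe this image in one sentence."},
        ],
    }],
}

response = bedrock.invoke_model(
    body=json.dumps(request_body),
    modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
    accept="application/json",
    contentType="application/json",
)
print(json.loads(response["body"].read())["content"][0]["text"])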
Run the Streamlit in Snowflake App
Copy and paste the following code into the Streamlit in Snowflake app editor:
import streamlit as st
import pandas as pd
import os
import base64
import boto3
import json
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col, when_matched, when_not_matched
import _snowflake
from PIL import Image
import io
# Streamlit page configuration
st.set_page_config(layout="wide", page_title="Image Gallery")
# Image folder path
IMAGE_FOLDER = "image"
# Get Snowflake session
session = get_active_session()
# Create table if it doesn't exist (only on first run)
@st.cache_resource
def create_table_if_not_exists():
session.sql("""
CREATE TABLE IF NOT EXISTS IMAGE_METADATA (
FILE_NAME STRING,
DESCRIPTION STRING,
VECTOR VECTOR(FLOAT, 1024)
)
""").collect()
create_table_if_not_exists()
# Function to get AWS credentials
def get_aws_credentials():
    aws_key_object = _snowflake.get_username_password('bedrock_key')
    region = 'us-east-1'
    return {
        'aws_access_key_id': aws_key_object.username,
        'aws_secret_access_key': aws_key_object.password,
        'region_name': region
    }, region
# Set up Bedrock client
boto3_session_args, region = get_aws_credentials()
boto3_session = boto3.Session(**boto3_session_args)
bedrock = boto3_session.client('bedrock-runtime', region_name=region)
# Get image data
@st.cache_data
def get_image_data():
    image_files = [f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif'))]
    return [{"FILE_NAME": f, "IMG_PATH": os.path.join(IMAGE_FOLDER, f)} for f in image_files]
# Get metadata
@st.cache_data
def get_metadata():
return session.table("IMAGE_METADATA").select("FILE_NAME", "DESCRIPTION").to_pandas()
# Convert image to thumbnail and encode as base64
@st.cache_data
def get_thumbnail_base64(img_path, max_size=(300, 300)):
    with Image.open(img_path) as img:
        img.thumbnail(max_size)
        # Convert to RGB so PNG/GIF images (palette or alpha channel) can be saved as JPEG
        img = img.convert("RGB")
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
# Initialize image data and metadata
if 'img_df' not in st.session_state:
    st.session_state.img_df = get_image_data()
if 'metadata_df' not in st.session_state:
    st.session_state.metadata_df = get_metadata()
# Display image gallery
def show_image_gallery():
st.title("Image Gallery")
num_columns = st.slider("Width:", min_value=1, max_value=5, value=4)
cols = st.columns(num_columns)
for i, img in enumerate(st.session_state.img_df):
with cols[i % num_columns]:
st.image(img["IMG_PATH"], caption=None, use_column_width=True)
# Edit image descriptions
def edit_image_descriptions():
st.title("Edit Image Descriptions")
st.session_state.metadata_df = get_metadata()
# Add new images to metadata
for img in st.session_state.img_df:
if img["FILE_NAME"] not in st.session_state.metadata_df["FILE_NAME"].values:
new_row = pd.DataFrame({"FILE_NAME": [img["FILE_NAME"]], "DESCRIPTION": [""]})
st.session_state.metadata_df = pd.concat([st.session_state.metadata_df, new_row], ignore_index=True)
merged_df = pd.merge(st.session_state.metadata_df, pd.DataFrame(st.session_state.img_df), on="FILE_NAME", how="left")
with st.form("edit_descriptions"):
for _, row in merged_df.iterrows():
col1, col2 = st.columns([1, 3])
with col1:
st.image(row["IMG_PATH"], width=100)
with col2:
new_description = st.text_input(f"Description for {row['FILE_NAME']}", value=row["DESCRIPTION"], key=row['FILE_NAME'])
merged_df.loc[merged_df["FILE_NAME"] == row["FILE_NAME"], "DESCRIPTION"] = new_description
submit_button = st.form_submit_button("Save Changes")
if submit_button:
update_snowflake_table(merged_df[['FILE_NAME', 'DESCRIPTION']])
st.success("Changes saved successfully!")
st.cache_data.clear()
st.session_state.metadata_df = get_metadata()
# Function to generate image description
def generate_description(image_path):
    image_base64 = get_thumbnail_base64(image_path)
    prompt = """
    Describe this image in English within 400 characters, in a single line.
    Only output the image description without any additional response.
    """

    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,  # a 400-character caption needs far fewer output tokens than this
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": image_base64
                        }
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                ]
            }
        ]
    }

    response = bedrock.invoke_model(
        body=json.dumps(request_body),
        modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
        accept='application/json',
        contentType='application/json'
    )
    response_body = json.loads(response.get('body').read())
    return response_body["content"][0]["text"]
# Function to update Snowflake table
def update_snowflake_table(update_df):
    snow_df = session.create_dataframe(update_df)
    target = session.table("IMAGE_METADATA")
    target.merge(
        snow_df,
        target["FILE_NAME"] == snow_df["FILE_NAME"],
        [
            when_matched().update({
                "DESCRIPTION": snow_df["DESCRIPTION"]
            }),
            when_not_matched().insert({
                "FILE_NAME": snow_df["FILE_NAME"],
                "DESCRIPTION": snow_df["DESCRIPTION"]
            })
        ]
    )
# Generate image descriptions
def generate_image_descriptions():
st.title("Generate Image Descriptions")
if 'generated_description' not in st.session_state:
st.session_state.generated_description = None
if 'selected_image' not in st.session_state:
st.session_state.selected_image = None
# Generate description for individual image
with st.form("generate_description"):
selected_image = st.selectbox("Select an image", options=[img["FILE_NAME"] for img in st.session_state.img_df])
generate_button = st.form_submit_button("Generate Description")
if generate_button:
image_info = next(img for img in st.session_state.img_df if img['FILE_NAME'] == selected_image)
generated_description = generate_description(image_info['IMG_PATH'])
st.session_state.generated_description = generated_description
st.session_state.selected_image = selected_image
st.image(image_info['IMG_PATH'], width=300)
st.write("Generated Description:")
st.write(generated_description)
if st.session_state.generated_description is not None:
if st.button("Save Description"):
update_snowflake_table(pd.DataFrame({'FILE_NAME': [st.session_state.selected_image], 'DESCRIPTION': [st.session_state.generated_description]}))
st.success("Description saved successfully!")
st.cache_data.clear()
st.session_state.metadata_df = get_metadata()
st.session_state.generated_description = None
st.session_state.selected_image = None
# Bulk process images without descriptions
st.subheader("Bulk Process Images Without Descriptions")
images_without_description = [
img for img in st.session_state.img_df
if img["FILE_NAME"] not in st.session_state.metadata_df[
st.session_state.metadata_df["DESCRIPTION"].notna() &
(st.session_state.metadata_df["DESCRIPTION"] != "")
]["FILE_NAME"].values
]
if images_without_description:
st.write(f"{len(images_without_description)} images do not have descriptions.")
if st.button("Generate Descriptions in Bulk"):
progress_bar = st.progress(0)
for i, img in enumerate(images_without_description):
generated_description = generate_description(img['IMG_PATH'])
update_snowflake_table(pd.DataFrame({'FILE_NAME': [img['FILE_NAME']], 'DESCRIPTION': [generated_description]}))
progress_bar.progress((i + 1) / len(images_without_description))
st.success("Descriptions generated and saved for all images!")
st.cache_data.clear()
st.session_state.metadata_df = get_metadata()
else:
st.write("All images have descriptions.")
# Display debug information
st.subheader("Metadata Information")
st.write(st.session_state.metadata_df)
# Main application execution
if __name__ == "__main__":
    page = st.sidebar.selectbox(
        "Choose a page",
        ["Image Gallery", "Edit Descriptions", "Generate Descriptions"]
    )

    if page == "Image Gallery":
        show_image_gallery()
    elif page == "Edit Descriptions":
        edit_image_descriptions()
    elif page == "Generate Descriptions":
        generate_image_descriptions()
Code Explanation
The following section creates a table to store image metadata. The third column, VECTOR, is intended to store vector data in Part 3:
# Create table if it doesn't exist (only on first run)
@st.cache_resource
def create_table_if_not_exists():
session.sql("""
CREATE TABLE IF NOT EXISTS IMAGE_METADATA (
FILE_NAME STRING,
DESCRIPTION STRING,
VECTOR VECTOR(FLOAT, 1024)
)
""").collect()
create_table_if_not_exists()
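As a hedged preview of Part 3 (none of this runs in Part 2): the column is declared as VECTOR(FLOAT, 1024) so it can hold a 1024-dimensional embedding of each caption, for example from Snowflake Cortex's EMBED_TEXT_1024 function with the multilingual-e5-large model. Populating it could look roughly like this:

# Hedged preview for Part 3: fill the VECTOR column from existing captions
# using a 1024-dimensional Cortex embedding model (the model choice is an example)
session.sql("""
    UPDATE IMAGE_METADATA
    SET VECTOR = SNOWFLAKE.CORTEX.EMBED_TEXT_1024('multilingual-e5-large', DESCRIPTION)
    WHERE DESCRIPTION IS NOT NULL AND DESCRIPTION != ''
""").collect()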
This function resizes the image to a thumbnail and encodes it as Base64 before it is sent to Amazon Bedrock. This keeps the request payload small, which helps caption generation speed, and it is necessary because the Anthropic Messages API on Bedrock expects image data as a Base64-encoded string inside the JSON request body rather than as raw binary:
# Convert image to thumbnail and encode as base64
@st.cache_data
def get_thumbnail_base64(img_path, max_size=(300, 300)):
    with Image.open(img_path) as img:
        img.thumbnail(max_size)
        # Convert to RGB so PNG/GIF images (palette or alpha channel) can be saved as JPEG
        img = img.convert("RGB")
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
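As a quick illustration of the size savings (not part of the app), you can compare the Base64 payload of the original file with the thumbnail produced by get_thumbnail_base64; the path below is a hypothetical example:

# Illustrative check: Base64 payload size before and after thumbnailing
import base64

img_path = "image/sample.jpg"  # hypothetical example file
with open(img_path, "rb") as f:
    original_kb = len(base64.b64encode(f.read())) / 1024
thumbnail_kb = len(get_thumbnail_base64(img_path)) / 1024
print(f"original: {original_kb:.0f} KB, thumbnail: {thumbnail_kb:.0f} KB")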
This section passes the caption generation prompt to Amazon Bedrock. The prompt may need to be adjusted in Part 3 to improve search accuracy:
# Function to generate image description
def generate_description(image_path):
    image_base64 = get_thumbnail_base64(image_path)
    prompt = """
    Describe this image in English within 400 characters, in a single line.
    Only output the image description without any additional response.
    """
Conclusion
With these improvements, we can now automatically generate captions for images in our gallery. While our ultimate goal is to implement image search, having captions associated with images opens up many possibilities for utilizing image data.
In Part 3, we'll implement image search by vectorizing the image captions and performing vector searches. Stay tuned!
Next Article
Promotion
Follow Snowflake What's New on Twitter
For updates on Snowflake's What's New, follow these Twitter accounts:
English Version
Snowflake What's New Bot (English Version)
Japanese Version
Snowflake's What's New Bot (Japanese Version)
Change Log
(20240923) Initial post