DEV Community

Ritabrata Maiti

AnyModal: Train Multimodal LLMs in PyTorch

Today, I want to introduce an open-source framework I’ve been working on: AnyModal.

Introduction

During my work on machine learning projects, I struggled to find flexible tools for training multimodal LLMs. There are plenty of great tools for specific tasks, such as image classification or audio processing, but I found no straightforward way to combine these modalities with large language models (LLMs). The process was often tedious: boilerplate code, custom integration, and a lot of trial and error to make the components work together.

This frustration led me to build AnyModal, a framework designed to reduce the complexity of multimodal AI development. It provides a modular, reusable structure that makes it easier for developers and researchers to combine diverse data types and experiment with new ideas without reinventing the wheel every time.

The Goal

AnyModal is built with the following objectives in mind:

Reduce Boilerplate Code

Combining modalities like images or audio with LLMs typically involves the same repetitive steps: preprocessing, encoding, tokenizing, and integrating. AnyModal provides reusable modules for these steps, so developers can spend their time on the model rather than on glue code.

Enable Seamless Integration

Whether you're working with images using a Vision Transformer (ViT) or audio spectrograms, AnyModal offers plug-and-play components that simplify the integration process. This makes it easy to handle multiple data types within a single framework.

Encourage Experimentation and Customization

AnyModal supports rapid prototyping while offering the flexibility to customize components like feature encoders, projection layers, and tokenizers. It’s versatile enough for both quick experiments and production-level deployments.


Example Usage: Integrating Images with LLMs

Here’s a detailed example of how AnyModal simplifies the integration of image data into LLMs:

1. Install Dependencies

pip install torch transformers datasets torchvision tqdm

2. Initialize Vision Components

from transformers import ViTImageProcessor, ViTForImageClassification

# Load a pre-trained Vision Transformer
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
vision_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Define a Vision Encoder to extract feature embeddings
from vision import VisionEncoder
vision_encoder = VisionEncoder(vision_model)
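
To make the role of the vision encoder concrete, here is a minimal sketch of what such a wrapper might do: run the backbone and return the last layer's per-patch hidden states. This is an assumption for illustration, not AnyModal's actual code. `SketchVisionEncoder` and `TinyBackbone` are hypothetical names, and the tiny backbone stands in for the ViT so the example runs without downloading weights.

```python
import torch
import torch.nn as nn
from types import SimpleNamespace


class SketchVisionEncoder(nn.Module):
    """Hypothetical wrapper: returns per-patch embeddings from a backbone.

    Assumes the backbone follows the Hugging Face convention of accepting
    `output_hidden_states=True` and returning an object with `.hidden_states`.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        outputs = self.backbone(pixel_values, output_hidden_states=True)
        # Last layer's hidden states: (batch, sequence, hidden_size)
        return outputs.hidden_states[-1]


class TinyBackbone(nn.Module):
    """Stand-in for a ViT so the sketch runs offline (illustration only)."""

    def __init__(self, patch: int = 16, hidden_size: int = 32):
        super().__init__()
        # A strided convolution cuts the image into non-overlapping patches.
        self.embed = nn.Conv2d(3, hidden_size, kernel_size=patch, stride=patch)

    def forward(self, pixel_values, output_hidden_states=False):
        x = self.embed(pixel_values)        # (B, hidden, H/patch, W/patch)
        x = x.flatten(2).transpose(1, 2)    # (B, num_patches, hidden)
        return SimpleNamespace(hidden_states=(x,))


encoder = SketchVisionEncoder(TinyBackbone())
features = encoder(torch.randn(2, 3, 32, 32))  # two 32x32 RGB images
print(features.shape)
```

With the real `google/vit-base-patch16-224` checkpoint, the same wrapper would receive `processor(images=image, return_tensors="pt")["pixel_values"]` and return 197 vectors of width 768 per image (196 patches plus one [CLS] token).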

3. Initialize Tokenizer and LLM

from transformers import AutoTokenizer, AutoModelForCausalLM

# Load a pre-trained LLM and its tokenizer
llm_tokenizer = AutoTokenizer.from_pretrained("gpt2")
llm_model = AutoModelForCausalLM.from_pretrained("gpt2")

4. Define a Projection Layer

from vision import Projector

# Create a projection layer that maps vision embeddings into the LLM's embedding space
vision_tokenizer = Projector(
    in_features=vision_model.config.hidden_size,
    out_features=llm_model.config.hidden_size  # 768 for GPT-2
)
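
For readers wondering what the projection step amounts to, the sketch below shows the simplest form it could take: a single linear layer applied to every patch embedding. `SketchProjector` is a hypothetical name, and a single linear layer is an assumption; the framework's `Projector` may be structured differently.

```python
import torch
import torch.nn as nn


class SketchProjector(nn.Module):
    """Hypothetical projector: one linear layer from vision width to LLM width."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last dimension, so each patch is projected
        # independently: (batch, patches, in) -> (batch, patches, out)
        return self.proj(features)


projector = SketchProjector(in_features=768, out_features=768)
# ViT-Base/16 at 224x224 yields 196 patch embeddings plus one [CLS] token.
patch_features = torch.randn(2, 197, 768)
image_tokens = projector(patch_features)
print(image_tokens.shape)
```

Each projected vector then plays the role of one "token" in the LLM's input sequence, which is why the output width has to match the LLM's hidden size.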

5. Combine Everything with AnyModal

from anymodal import MultiModalModel

# Build the multimodal model
multimodal_model = MultiModalModel(
    input_processor=None,
    input_encoder=vision_encoder,
    input_tokenizer=vision_tokenizer,
    language_tokenizer=llm_tokenizer,
    language_model=llm_model,
    input_start_token='<|imstart|>',
    input_end_token='<|imend|>',
    prompt_text="Describe this image: "
)
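
The arguments above suggest how the pieces fit together: the projected image tokens are wrapped in start and end markers and placed in front of the prompt. The sketch below illustrates that idea at the embedding level. It assumes this ordering; `build_inputs`, the token ids, and the small embedding table are all hypothetical stand-ins rather than AnyModal internals.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, vocab = 16, 100
embed = nn.Embedding(vocab, hidden)  # stand-in for the LLM's input embedding table


def build_inputs(image_tokens, start_ids, end_ids, prompt_ids):
    """Concatenate [start] + image tokens + [end] + prompt along the sequence axis."""
    batch = image_tokens.size(0)

    def embed_ids(ids):
        # (length,) -> (batch, length, hidden), shared across the batch
        return embed(ids).unsqueeze(0).expand(batch, -1, -1)

    return torch.cat(
        [embed_ids(start_ids), image_tokens, embed_ids(end_ids), embed_ids(prompt_ids)],
        dim=1,
    )


image_tokens = torch.randn(2, 5, hidden)   # output of the projector (5 patches here)
start_ids = torch.tensor([1])              # hypothetical id for '<|imstart|>'
end_ids = torch.tensor([2])                # hypothetical id for '<|imend|>'
prompt_ids = torch.tensor([10, 11, 12])    # hypothetical ids for the prompt text
inputs_embeds = build_inputs(image_tokens, start_ids, end_ids, prompt_ids)
print(inputs_embeds.shape)
```

With a Hugging Face causal LM, a tensor built this way can be passed through the `inputs_embeds` argument instead of `input_ids`, and the embedding table is available via `get_input_embeddings()`.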

6. Training and Inference

Training involves processing batches of image-text pairs and optimizing the model:

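import torch
from torch.utils.data import DataLoader
from datasets import load_dataset

# Load a dataset of image-caption pairs
# ("image_caption_dataset" is a placeholder; substitute a real dataset name)
dataset = load_dataset("image_caption_dataset", split="train")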

# Prepare DataLoader
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Training Loop
optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=3e-4)

for epoch in range(10):
    for batch in train_loader:
        optimizer.zero_grad()
        logits, loss = multimodal_model(batch)
        loss.backward()
        optimizer.step()

# Generate captions
sample_input = dataset[0]['image']
generated_caption = multimodal_model.generate(sample_input, max_new_tokens=30)
print("Generated Caption:", generated_caption)
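
The loop above receives a `loss` from the model without showing how it might be computed. A common approach for captioning, sketched here as an assumption rather than as AnyModal's implementation, is to compute next-token cross-entropy only over the caption and mask out the image and prompt positions with the label `-100`. All sizes and tensors below are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_prefix, n_caption, vocab = 2, 7, 4, 50

# Stand-ins for the LLM output and the tokenized captions.
logits = torch.randn(batch, n_prefix + n_caption, vocab)
caption_ids = torch.randint(0, vocab, (batch, n_caption))

# Image and prompt positions get -100 so they are excluded from the loss.
labels = torch.cat([torch.full((batch, n_prefix), -100), caption_ids], dim=1)

# Causal shift: the logits at position t are scored against the label at t + 1.
shift_logits = logits[:, :-1, :]
shift_labels = labels[:, 1:]

loss = F.cross_entropy(
    shift_logits.reshape(-1, vocab),
    shift_labels.reshape(-1),
    ignore_index=-100,
)
print(loss.item())
```

Masking the prefix keeps the gradient focused on predicting the caption; without it, the model would also be penalized for failing to "predict" image embeddings that have no vocabulary entry.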

Current Status

AnyModal is currently in its early stages, with the latest version supporting tasks like:

  • LaTeX OCR
  • Chest X-Ray Captioning (in progress)
  • Image Captioning

Planned features include support for visual question answering and audio captioning.

As the framework evolves, I’m focusing on expanding its functionality, refining the codebase, and addressing community feedback to move towards a stable release.


Links


If you’re looking for a way to simplify multimodal AI development, give AnyModal a try. I’d love to hear your feedback or ideas for new features. Contributions are always welcome!

Top comments (1)

Rafik Margaryan

Do you have experience with the Lightning module for PyTorch? Is it supported?