Josef Albers

Part 2: Implementing Su-scaled Rotary Position Embeddings (RoPE) for Phi-3-Vision

Introduction

Welcome to Part 2 of our Phi-3-Vision porting series. In Part 1, we created a basic implementation of the model in MLX, but noted that it struggles with longer sequences. Today, we'll address that limitation by implementing Su-scaled Rotary Position Embeddings (RoPE), which will significantly enhance our model's ability to handle long contexts of up to 128K tokens.

The full implementation of what we'll cover today is available at https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py

1. Understanding Rotary Position Embeddings (RoPE)

Before we delve into Su-scaled RoPE, let's first understand the basics of Rotary Position Embeddings.

RoPE is a technique that injects positional information into the model's token representations without adding extra tokens or increasing the model's parameter count. The key idea is to apply a rotation to each token's embedding based on its position in the sequence.

  1. Frequency Calculation: For each pair of embedding dimensions (indexed by the even offset d), RoPE computes an inverse frequency:

    inv_freq = 1 / (base ** (d / dim))
    

  2. Position-Frequency Interaction: These frequencies are then multiplied by the token positions to create unique sinusoidal patterns for each position.

    freqs = inv_freq @ position_ids.T
    

  3. Rotation Application: The resulting patterns are used to rotate the token embeddings in 2D planes (a short end-to-end sketch follows this list).

    For a token at position pos, RoPE applies the following rotation:

    x_rotated = [x * cos(pos * freq) - y * sin(pos * freq),
                 y * cos(pos * freq) + x * sin(pos * freq)]
    

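To see how these three steps fit together, here is a minimal, self-contained sketch of plain (unscaled) RoPE in MLX. The shapes and values are purely illustrative and are not taken from the Phi-3 configuration.

import mlx.core as mx

# Toy tensor: (batch=1, heads=1, seq_len=4, head_dim=8); values are illustrative only
base, dim, seq_len = 10000.0, 8, 4
x = mx.random.normal((1, 1, seq_len, dim))

# Step 1: one inverse frequency per pair of dimensions
inv_freq = 1.0 / (base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))

# Step 2: outer product of positions and frequencies
positions = mx.arange(seq_len, dtype=mx.float32)
freqs = positions[:, None] * inv_freq[None, :]    # (seq_len, dim/2)
emb = mx.concatenate([freqs, freqs], axis=-1)     # (seq_len, dim)

# Step 3: rotate each 2D pair via the half-split trick
def rotate_half(t):
    half = t.shape[-1] // 2
    return mx.concatenate([-t[..., half:], t[..., :half]], axis=-1)

x_rotated = x * mx.cos(emb) + rotate_half(x) * mx.sin(emb)
print(x_rotated.shape)  # (1, 1, 4, 8)
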
Now that we understand RoPE, let's explore how Su-scaled RoPE builds upon and enhances this concept.

2. Understanding Su-RoPE

Su-RoPE extends RoPE by introducing scaling factors for different sequence-length ranges:

freq = 1 / (su_factor * theta ** (d / dim))

This allows the model to generalize to sequences far longer than those seen during training. Three pieces work together:

  1. Short and Long Factors: Two sets of scaling factors are used, one for shorter sequences and one for longer sequences.

  2. Adaptive Scaling: The choice between short and long factors is made based on the sequence length.

  3. Scaling Factor: An additional scaling factor is applied to the cos/sin values to adjust for the extended maximum position embeddings (see the quick check below).
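
To make the scaling factor concrete, here is a quick back-of-the-envelope check. The context lengths below (4,096 original, 131,072 extended) are assumptions based on the 128K Phi-3-Vision configuration; the formula itself is the one used in the implementation in the next section.

import math

# Assumed config values for the 128K Phi-3-Vision model (check config.json for the real ones)
original_max = 4096     # context length used during pre-training
extended_max = 131072   # extended 128K context length

# Extra scaling applied to cos/sin to compensate for the longer context
scaling_factor = math.sqrt(1 + math.log(extended_max / original_max) / math.log(original_max))
print(round(scaling_factor, 4))  # ~1.1902

# Adaptive choice: switch to the long factors once positions exceed the original window
seq_len = 8000
factors = "long_factor" if seq_len > original_max else "short_factor"
print(factors)  # long_factor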

3. Implementing Su-scaled RoPE

Now that we understand the theory behind Su-scaled RoPE, let's implement it in code. We'll create a SuRoPE class that encapsulates all the functionality we've discussed:

import mlx.core as mx
import mlx.nn as nn
import math

class SuRoPE:
    def __init__(self, config):
        # Per-head dimension; RoPE rotates pairs of dimensions within each head
        self.dim = config.hidden_size // config.num_attention_heads
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.rope_theta = config.rope_theta
        # Extra scaling that compensates for the extended context window
        self.scaling_factor = math.sqrt(1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings) / math.log(config.original_max_position_embeddings))
        # Per-dimension rescaling factors for long and short sequences
        self.long_factor = config.rope_scaling["long_factor"]
        self.short_factor = config.rope_scaling["short_factor"]

    def __call__(self, q, k, position_ids=None):
        # Default to positions 0..seq_len-1 when no explicit positions are given
        position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None] if position_ids is None else position_ids
        cos, sin = self._get_cos_sin(position_ids)
        # Apply the rotation to queries and keys
        q = (q * cos) + (self._rotate_half(q) * sin)
        k = (k * cos) + (self._rotate_half(k) * sin)
        return q, k

    def _get_cos_sin(self, position_ids):
        # Pick the long or short factors depending on how far the positions extend
        su_factor = self.long_factor if mx.max(position_ids) > self.original_max_position_embeddings else self.short_factor
        position_ids_expanded = position_ids[:, None, :]
        # Scaled inverse frequencies: 1 / (su_factor * theta^(d/dim))
        inv_freq = 1.0 / (mx.array(su_factor, dtype=mx.float32) * self.rope_theta**(mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        # Outer product of positions and frequencies -> (batch, seq_len, dim/2)
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        # Broadcast over the heads axis and apply the attention scaling factor
        cos = mx.expand_dims(mx.cos(emb) * self.scaling_factor, axis=1)
        sin = mx.expand_dims(mx.sin(emb) * self.scaling_factor, axis=1)
        return cos, sin

    @staticmethod
    def _rotate_half(x):
        # Swap the two halves of the last dimension and negate the second half
        midpoint = x.shape[-1] // 2
        x1, x2 = x[..., :midpoint], x[..., midpoint:]
        return mx.concatenate([-x2, x1], axis=-1)
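As a quick sanity check, we can run the class on random query/key tensors with a stand-in config object. The values below (hidden size 3072, 32 attention heads, theta of 10000, 4K original and 128K extended context) are assumptions about the Phi-3-Vision configuration, and the all-ones scaling factors are placeholders; the real model loads all of these from its config.json.

from types import SimpleNamespace

# Stand-in config; values are assumed, and the factors are placeholders (48 = head_dim / 2 entries each)
config = SimpleNamespace(
    hidden_size=3072,
    num_attention_heads=32,
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    rope_theta=10000.0,
    rope_scaling={"short_factor": [1.0] * 48, "long_factor": [1.0] * 48},
)

rope = SuRoPE(config)
q = mx.random.normal((1, 32, 16, 96))  # (batch, heads, seq_len, head_dim)
k = mx.random.normal((1, 32, 16, 96))
q, k = rope(q, k)
print(q.shape, k.shape)  # (1, 32, 16, 96) (1, 32, 16, 96)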

4. Integrating Su-scaled RoPE into Phi-3-Vision

Integrating our Su-scaled RoPE implementation into the Phi-3-Vision model is straightforward. We only need to add two lines to our Phi3Attention module:

class Phi3Attention(nn.Module):
    def __init__(self, config):
        # ... (existing initialization from Part 1)
        self.rope = SuRoPE(config)

    def __call__(self, x):
        # ... (project x into q, k, v as in Part 1)
        q, k = self.rope(q, k)
        # ... (rest of the attention computation from Part 1)

And now our ported model can handle contexts of up to 128K tokens!

Conclusion

In this tutorial, we implemented Su-scaled Rotary Position Embeddings (RoPE), enabling our model to handle sequences up to 128K tokens.

The full implementation is available at https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py

Next, we'll explore efficient batching techniques to further optimize our Phi-3-Vision implementation in MLX.
