Introduction
Welcome to Part 2 of our Phi-3-Vision porting series. In Part 1, we created a basic implementation of the model in MLX, but we also noted that it struggles with longer sequences. Today, we'll address this limitation by implementing Su-scaled Rotary Position Embeddings (RoPE), which let our model handle long contexts of up to 128K tokens.
The full implementation of what we'll cover today is available at https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py
1. Understanding Rotary Position Embeddings (RoPE)
Before we delve into Su-scaled RoPE, let's first understand the basics of Rotary Position Embeddings.
RoPE is a technique that injects positional information into the model's token representations without adding extra tokens or increasing the model's parameter count. The key idea is to rotate each token's query and key vectors inside the attention layer by angles that depend on the token's position in the sequence.
- Frequency Calculation: For each even dimension index d = 0, 2, ..., dim - 2 (one frequency per pair of dimensions), RoPE calculates a frequency:
inv_freq = 1 / (base ** (d / dim))
- Position-Frequency Interaction: These frequencies are then multiplied by the token positions (an outer product), giving each position its own set of rotation angles.
freqs = inv_freq @ position_ids.T
- Rotation Application: The resulting angles are used to rotate pairs of coordinates of the query and key vectors in 2D planes. For a token at position pos, RoPE rotates each pair (x, y) that has frequency freq as follows:
x_rotated = [x * cos(pos * freq) - y * sin(pos * freq), y * cos(pos * freq) + x * sin(pos * freq)]
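The three steps above can be sketched in plain Python for a single vector. This is a minimal illustration under our own assumptions, not code from the repository. `rope_inv_freq`, `apply_rope` and `dot` are hypothetical helper names. The sketch rotates adjacent pairs (x[2i], x[2i+1]), matching the formula above, whereas the SuRoPE class in Section 3 pairs dimension i with dimension i + dim/2.

```python
import math


def rope_inv_freq(dim, base=10000.0):
    # One frequency per pair of dimensions: d = 0, 2, 4, ..., dim - 2
    return [1.0 / (base ** (d / dim)) for d in range(0, dim, 2)]


def apply_rope(vec, pos, base=10000.0):
    """Rotate each adjacent pair (vec[2i], vec[2i+1]) by the angle pos * inv_freq[i]."""
    dim = len(vec)
    assert dim % 2 == 0, "RoPE needs an even dimension"
    out = []
    for i, freq in enumerate(rope_inv_freq(dim, base)):
        x, y = vec[2 * i], vec[2 * i + 1]
        angle = pos * freq
        c, s = math.cos(angle), math.sin(angle)
        out += [x * c - y * s, y * c + x * s]
    return out


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


q = [0.3, -1.2, 0.5, 2.0]
k = [1.1, 0.4, -0.7, 0.9]
# The score depends only on the position offset (5 - 3 == 12 - 10)
print(math.isclose(dot(apply_rope(q, 5), apply_rope(k, 3)),
                   dot(apply_rope(q, 12), apply_rope(k, 10))))  # True
```

Rotating q by angle a and k by angle b gives q^T R(b - a) k, so the attention score sees only the relative distance between tokens. The rotation also never changes a vector's length.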
Now that we understand RoPE, let's explore how Su-scaled RoPE builds upon and enhances this concept.
2. Understanding Su-RoPE
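Su-RoPE extends RoPE by dividing each frequency by a per-dimension rescaling factor, with separate sets of factors for short and long sequences:
freq = 1 / (SU_FACTOR * theta ** (d / dim))
Here SU_FACTOR holds one value per frequency, and theta is the RoPE base from above. Rescaling the frequencies this way helps the model generalize to sequences longer than those seen during training.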
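To see what the per-dimension factors do, the sketch below computes Su-RoPE frequencies and the matching wavelengths. A wavelength is the number of positions a pair needs to complete one full turn, 2π / freq. The factor values are made up for illustration; the real ones come from `config.rope_scaling`. `su_inv_freq` and `wavelengths` are hypothetical helpers.

```python
import math


def su_inv_freq(dim, factors, base=10000.0):
    # Same exponent schedule as plain RoPE, but each frequency is divided by its own factor
    assert len(factors) == dim // 2, "one factor per rotated pair"
    return [1.0 / (f * base ** (d / dim)) for f, d in zip(factors, range(0, dim, 2))]


def wavelengths(inv_freq):
    # Number of positions needed for a pair to rotate through 2*pi radians
    return [2 * math.pi / f for f in inv_freq]


plain = su_inv_freq(8, [1.0, 1.0, 1.0, 1.0])     # identical to plain RoPE
scaled = su_inv_freq(8, [1.0, 1.0, 4.0, 8.0])    # hypothetical factors
ratios = [s / p for s, p in zip(wavelengths(scaled), wavelengths(plain))]
print([round(r, 6) for r in ratios])  # [1.0, 1.0, 4.0, 8.0]
```

A factor of 1.0 leaves a frequency unchanged. A factor f > 1 stretches that pair's wavelength by f, so the pair rotates more slowly. Roughly speaking, this keeps a longer range of positions within angles comparable to those the model saw in training.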
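- Short and Long Factors: Two sets of scaling factors are used, one for shorter sequences and one for longer sequences.
- Adaptive Scaling: The choice between them depends on the sequence length. The long factors are used once the sequence exceeds original_max_position_embeddings (the model's original, pre-extension context length), and the short factors are used otherwise.
- Scaling Factor: An additional factor, sqrt(1 + log(max_position_embeddings / original_max_position_embeddings) / log(original_max_position_embeddings)), multiplies cos and sin to adjust for the extended maximum position embeddings.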
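The adaptive choice and the extra scaling factor reduce to a few lines. This sketch treats the sequence length as the largest position id plus one, since ids start at 0. `select_factors` and `attention_scaling` are hypothetical names. The values 131072 and 4096 are assumed as an example of a 128K-token configuration extended from a 4K original context.

```python
import math


def select_factors(max_position_id, original_max, short_factor, long_factor):
    # Position ids start at 0, so the sequence length is the largest id + 1
    seq_len = max_position_id + 1
    return long_factor if seq_len > original_max else short_factor


def attention_scaling(max_position_embeddings, original_max_position_embeddings):
    # sqrt(1 + log(extension ratio) / log(original context length))
    ratio = max_position_embeddings / original_max_position_embeddings
    return math.sqrt(1 + math.log(ratio) / math.log(original_max_position_embeddings))


print(select_factors(4095, 4096, "short", "long"))  # short (exactly 4096 tokens)
print(select_factors(4096, 4096, "short", "long"))  # long (4097 tokens)
print(round(attention_scaling(131072, 4096), 4))    # 1.1902
```

With these numbers the ratio is 32, and log(32) / log(4096) = 5/12, so the factor is sqrt(17/12) ≈ 1.1902.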
3. Implementing Su-scaled RoPE
Now that we understand the theory behind Su-scaled RoPE, let's implement it in code. We'll create a SuRoPE class that encapsulates all the functionality we've discussed:
```python
import math

import mlx.core as mx
import mlx.nn as nn


class SuRoPE:
    def __init__(self, config):
        # Per-head dimension that the rotation is applied to
        self.dim = config.hidden_size // config.num_attention_heads
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.rope_theta = config.rope_theta
        # Multiplier on cos/sin that compensates for the extended context window
        self.scaling_factor = math.sqrt(
            1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings)
            / math.log(config.original_max_position_embeddings)
        )
        # Per-frequency rescaling factors (dim // 2 entries each)
        self.long_factor = config.rope_scaling["long_factor"]
        self.short_factor = config.rope_scaling["short_factor"]

    def __call__(self, q, k, position_ids=None):
        # q, k: (batch, num_heads, seq_len, head_dim)
        if position_ids is None:
            position_ids = mx.arange(q.shape[2], dtype=mx.float32)[None]
        cos, sin = self._get_cos_sin(position_ids)
        q = (q * cos) + (self._rotate_half(q) * sin)
        k = (k * cos) + (self._rotate_half(k) * sin)
        return q, k

    def _get_cos_sin(self, position_ids):
        position_ids = position_ids.astype(mx.float32)
        # Sequence length is the largest position id + 1; switch to long factors past the original context
        seq_len = int(mx.max(position_ids).item()) + 1
        su_factor = self.long_factor if seq_len > self.original_max_position_embeddings else self.short_factor
        # inv_freq: (dim // 2,)
        inv_freq = 1.0 / (
            mx.array(su_factor, dtype=mx.float32)
            * self.rope_theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
        )
        # (batch, dim // 2, 1) @ (batch, 1, seq_len) -> (batch, dim // 2, seq_len) -> (batch, seq_len, dim // 2)
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        position_ids_expanded = position_ids[:, None, :]
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        # Add a head axis so cos/sin broadcast over (batch, num_heads, seq_len, head_dim)
        cos = mx.expand_dims(mx.cos(emb) * self.scaling_factor, axis=1)
        sin = mx.expand_dims(mx.sin(emb) * self.scaling_factor, axis=1)
        return cos, sin

    @staticmethod
    def _rotate_half(x):
        # Split the last axis in half and return [-x2, x1]
        midpoint = x.shape[-1] // 2
        x1, x2 = x[..., :midpoint], x[..., midpoint:]
        return mx.concatenate([-x2, x1], axis=-1)
```
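Why does concatenating `[freqs, freqs]` and using `_rotate_half` implement the rotation from Section 1? It pairs dimension i with dimension i + dim/2 instead of adjacent dimensions. The pure-Python check below compares this rotate-half form with an explicit 2D rotation of each such pair. The helper names are hypothetical. The check omits `scaling_factor`, which multiplies cos and sin equally and does not affect the pairing.

```python
import math


def rotate_half(x):
    # Same operation as SuRoPE._rotate_half, on a plain list: [-x2, x1]
    mid = len(x) // 2
    return [-v for v in x[mid:]] + x[:mid]


def rope_rotate_half(x, pos, inv_freq):
    # Angles duplicated as [freqs, freqs], like emb = concatenate([freqs, freqs]) in _get_cos_sin
    angles = [pos * f for f in inv_freq] * 2
    rx = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rx, angles)]


def rope_pairs(x, pos, inv_freq):
    # Explicit 2D rotation of the pair (x[i], x[i + mid]) by pos * inv_freq[i]
    mid = len(x) // 2
    first, second = [], []
    for i, f in enumerate(inv_freq):
        a, b = x[i], x[i + mid]
        c, s = math.cos(pos * f), math.sin(pos * f)
        first.append(a * c - b * s)
        second.append(b * c + a * s)
    return first + second


x = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75]
inv_freq = [1.0 / (10000.0 ** (d / len(x))) for d in range(0, len(x), 2)]
same = all(math.isclose(a, b, abs_tol=1e-12)
           for a, b in zip(rope_rotate_half(x, 3, inv_freq), rope_pairs(x, 3, inv_freq)))
print(same)  # True
```

The half-split layout lets the whole rotation run as two elementwise multiplies and one concatenation over the last axis. This suits vectorized frameworks such as MLX.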
4. Integrating Su-scaled RoPE into Phi-3-Vision
Integrating our Su-scaled RoPE implementation into the Phi-3-Vision model is straightforward. We only need to add two lines to our Phi3Attention module:
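```python
class Phi3Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # ... existing layers (q/k/v/o projections, etc.)
        self.rope = SuRoPE(config)  # new

    def __call__(self, x):
        # ... existing code that computes and reshapes q, k, v
        q, k = self.rope(q, k)  # new
        # ... existing attention and output projection
```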
And now our ported model can handle up to 128K tokens!
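`SuRoPE(config)` reads only a handful of config fields. Its factor lists must have one entry per rotated pair (head_dim // 2) so they broadcast against the frequency vector. The sketch below is a hypothetical pre-flight check. The `SimpleNamespace` values are assumptions with placeholder factors of 1.0, so read the real values from the model's config.json.

```python
from types import SimpleNamespace


def check_surope_config(config):
    """Hypothetical check of the fields SuRoPE reads; returns head_dim."""
    if config.hidden_size % config.num_attention_heads:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    head_dim = config.hidden_size // config.num_attention_heads
    if head_dim % 2:
        raise ValueError("RoPE rotates pairs, so head_dim must be even")
    if config.max_position_embeddings < config.original_max_position_embeddings:
        raise ValueError("max_position_embeddings should not be smaller than the original context")
    for key in ("short_factor", "long_factor"):
        factors = config.rope_scaling[key]
        # One factor per rotated pair, so it broadcasts against the head_dim // 2 frequencies
        if len(factors) != head_dim // 2:
            raise ValueError(f"{key} has {len(factors)} entries, expected {head_dim // 2}")
    return head_dim


# Assumed example values; the factor lists are placeholders, not the model's real factors
config = SimpleNamespace(
    hidden_size=3072,
    num_attention_heads=32,
    rope_theta=10000.0,
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    rope_scaling={"short_factor": [1.0] * 48, "long_factor": [1.0] * 48},
)
print(check_surope_config(config))  # 96
```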
Conclusion
In this tutorial, we implemented Su-scaled Rotary Position Embeddings (RoPE), enabling our model to handle sequences up to 128K tokens.
The full implementation is available at https://github.com/JosefAlbers/Phi-3-Vision-MLX/tree/main/assets/tutorial_2.py
Next, we'll explore efficient batching techniques to further optimize our Phi-3-Vision implementation in MLX.