BakaLLM, part 17. Dancing with SSMnakes

Hello, fairy dairy diary.

(Image: Cirno snake)

And it finally happened: all it took to tame mamba was adding a good normalization layer and putting mamba after Attn+MLP!
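Roughly, the layer now looks like this. A minimal sketch only: `BakaBlock` and the injected `attn`/`mlp`/`mamba` modules are placeholder names, and plain `nn.LayerNorm` stands in for whatever normalizer you prefer (RMSNorm would work too).

```python
import torch.nn as nn

class BakaBlock(nn.Module):
    """One layer: residual Attn + MLP, then (optionally) a normalized mamba."""
    def __init__(self, dim, attn, mlp, mamba=None):
        super().__init__()
        self.norm_attn = nn.LayerNorm(dim)  # pre-norm residual sub-layers
        self.norm_mlp = nn.LayerNorm(dim)
        self.attn, self.mlp, self.mamba = attn, mlp, mamba
        if mamba is not None:
            # the extra normalizer in front of mamba is what made it behave
            self.norm_mamba = nn.LayerNorm(dim)

    def forward(self, x):
        x = x + self.attn(self.norm_attn(x))  # attention first
        x = x + self.mlp(self.norm_mlp(x))    # then MLP
        if self.mamba is not None:
            x = x + self.mamba(self.norm_mamba(x))  # mamba after Attn+MLP
        return x
```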

| model_type | n_params | loss | ppl | note |
| --- | --- | --- | --- | --- |
| baka_mamba | 213139584 | 3.88385415077209 | 48.611209417978 | after 3 epoch ber 5batch+mamba again |
| baka_mamba | 213139584 | 3.99270844459534 | 54.2014924798642 | after 2 epoch ber 5batch+mamba again |
| baka_mamba | 213139584 | 4.26171875 | 70.9317927632271 | after 1 epoch ber 5batch+mamba again |
| baka_mamba | 201822336 | 3.94270825386047 | 51.5580446647835 | after 3 epoch ber 5batch+mamba-each-4th |
| baka_mamba | 201822336 | 4.05729150772095 | 57.8175005609198 | after 2 epoch ber 5batch+mamba-each-4th |
| baka_mamba | 201822336 | 4.36744785308838 | 78.8421579437085 | after 1 epoch ber 5batch+mamba-each-4th |

(I learned that sqlite can output markdown. Yes, my life is now divided into before and after.)
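(If you also missed it: the sqlite3 shell has a markdown output mode since 3.33, so tables like the one above are one dot-command away. The database and table names here are made up:)

```
$ sqlite3 results.db
sqlite> .mode markdown
sqlite> SELECT model_type, n_params, loss, ppl, note FROM runs ORDER BY loss;
```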

I got a new record after E3: 3.883. It happened when I added mamba to every 2nd layer. Adding it to every 4th layer made results worse, so now I'll look into fitting in more mamba layers, maybe shrinking the MLPs or removing them entirely. Hyperparameter tuning is so tedious when each experiment takes many hours.
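Placement itself is just a stride over the stack. A sketch, reusing the hypothetical `BakaBlock` from above (the `make_*` factory callables are assumptions as well):

```python
import torch.nn as nn

def build_layers(n_layers, dim, make_attn, make_mlp, make_mamba, every=2):
    # every=2 reproduces "mamba each 2nd layer" (the record run),
    # every=4 the weaker 201M-parameter configuration
    return nn.ModuleList(
        BakaBlock(dim, make_attn(), make_mlp(),
                  make_mamba() if (i + 1) % every == 0 else None)
        for i in range(n_layers)
    )
```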

It also beats Llama* (which I reverted), and it's close to the 3.8890 I got after 5 epochs of staged training (1st epoch: train the first 4 layers, then freeze them; 2nd epoch: train layers 5-8, freeze them; 3rd epoch: train layers 9-12, then unfreeze everything; 4th and 5th epochs: train with everything unfrozen).
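That schedule, spelled out as code (a sketch only: it assumes the model exposes a `layers` ModuleList, and layers are 0-indexed here where the text counts 1-12):

```python
def set_trainable(model, epoch):
    # epochs 1-3 each train one third of the stack (the rest stays frozen);
    # from epoch 4 on, everything trains
    schedule = {1: range(0, 4), 2: range(4, 8), 3: range(8, 12)}
    trainable = schedule.get(epoch, range(len(model.layers)))
    for i, layer in enumerate(model.layers):
        for p in layer.parameters():
            p.requires_grad = i in trainable
```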

I also tried different orderings like (mamba(mlp(attn))), but they didn't work as well at E1, so they were all abandoned.

Also, as it turned out, the score of 4.10598945617676 that I celebrated so much last time was measured on incorrect weights after 1.5 epochs (the 2nd epoch broke in the middle, but its weights were saved anyway). Still, mamba showed good results after E1. And good results are good.

Chill out!
