
Commit 6f9376d

Peter Szemraj authored
fix imports, simplify RMSNorm (#1)
* fix recursive imports
* torch-only norms
* 📝 install tips

Signed-off-by: Peter Szemraj <peterszemraj+dev@gmail.com>
Co-authored-by: Peter Szemraj <peterszemraj+dev@gmail.com>
1 parent 8c18e44 commit 6f9376d

File tree

4 files changed: +195 -886 lines changed

README.md (+36)
@@ -4,6 +4,42 @@
 This aims to be a simpler implementation of the [original repo](https://github.com/microsoft/Samba).
 
+## Installation
+
+> [!TIP]
+> While the `pip install` command _should_ install all deps and the package, in practice some of the more CUDA-heavy deps are better installed separately from source. See the section below for more details.
+
+```bash
+git clone https://github.com/pszemraj/samba-pytorch.git
+cd samba-pytorch
+pip install -e .
+```
+
+### Installing custom kernel packages first
+
+After installing `torch`, `xformers`, and `flash-attn`, you may want to install `mamba-ssm`, `causal-conv1d`, and `fla` from source:
+
+```bash
+pip install --upgrade pip ninja
+pip install git+https://github.com/state-spaces/mamba.git --no-build-isolation
+pip install git+https://github.com/Dao-AILab/causal-conv1d.git --no-build-isolation
+pip install git+https://github.com/sustcsonglin/flash-linear-attention@98c176e --no-build-isolation
+```
+
+Then, clone this repo and run the commands above.
+
+## Usage
+
+A basic example of creating a random model from a named config:
+
+```python
+from samba_pytorch import Config, GPT
+cfg = Config.from_name('Samba_421M_1k_window')
+print(cfg)
+model = GPT(cfg)
+model
+```
+
 ## repo structure
 
 ```text
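As a quick smoke test of the usage snippet above, the freshly built model can be run on random token ids. This is a sketch rather than part of the commit: it assumes the litgpt-style forward signature `model(idx) -> logits` and a `vocab_size` field on the config, both inherited from the upstream code this repo simplifies.

```python
import torch

from samba_pytorch import Config, GPT

cfg = Config.from_name("Samba_421M_1k_window")
model = GPT(cfg)
model.eval()

# random token ids; assumes a litgpt-style forward: model(idx) -> logits
idx = torch.randint(0, cfg.vocab_size, (1, 128))
with torch.no_grad():
    logits = model(idx)

print(logits.shape)  # expected: (1, 128, cfg.vocab_size)
```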

samba_pytorch/config.py (+43 -40)
@@ -3,14 +3,13 @@
 
 # Copyright Lightning AI. Licensed under the Apache License 2.0,
 # see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
-
+import warnings
 from dataclasses import dataclass
 from typing import Any, Literal, Optional, Type
 
 import torch
 from typing_extensions import Self
 
-import samba_pytorch.samba
 from samba_pytorch.utils import find_multiple
 

@@ -101,8 +100,9 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
 
     @property
     def mlp_class(self) -> Type:
+        from samba_pytorch import samba
         # `self._mlp_class` cannot be the type to keep the config json serializable
-        return getattr(samba_pytorch.samba, self._mlp_class)
+        return getattr(samba, self._mlp_class)
 
     @property
     def norm_class(self) -> Type:
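This hunk is the actual fix for the recursive import named in the commit title: `samba.py` already imports from `config.py` at module load, so a top-level `import samba_pytorch.samba` in `config.py` closed an import cycle. Deferring the import into the property postpones it until first attribute access, by which point both modules are fully initialized. A condensed, illustrative sketch of the pattern (not the repo's full code):

```python
from typing import Type


class Config:
    """Sketch of the relevant part of samba_pytorch/config.py."""

    _mlp_class: str = "LLaMAMLP"

    @property
    def mlp_class(self) -> Type:
        # Imported here, at attribute-access time, instead of at module
        # load time: samba.py itself imports from config.py, so a
        # top-level import in config.py would create a circular import.
        from samba_pytorch import samba

        return getattr(samba, self._mlp_class)
```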
@@ -112,9 +112,12 @@ def norm_class(self) -> Type:
 
             return RMSNorm
         elif self._norm_class == "FusedRMSNorm":
-            from samba_pytorch.modules.rmsnorm import FusedRMSNorm
+            warnings.warn(
+                "FusedRMSNorm has been removed, using standard torch RMSNorm instead"
+            )
+            from samba_pytorch.modules.rmsnorm import RMSNorm
 
-            return FusedRMSNorm
+            return RMSNorm
         return getattr(torch.nn, self._norm_class)
 
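For context, the torch-only replacement is straightforward because RMSNorm only rescales by the root mean square of the features, `y = x / sqrt(mean(x**2) + eps) * weight`, with no mean subtraction and no bias term. A minimal sketch of such a norm in plain PyTorch (illustrative; the class actually shipped lives in `samba_pytorch/modules/rmsnorm.py` and may differ in details):

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Torch-only RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight."""

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(size))
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # compute in float32 for numerical stability, then cast back
        dtype = x.dtype
        x = x.float()
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return (self.weight * x_normed).to(dtype=dtype)
```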

@@ -133,7 +136,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -150,7 +153,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -168,7 +171,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         full_per_layer=2,
         _mlp_class="LLaMAMLP",
@@ -187,7 +190,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -206,7 +209,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -225,7 +228,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -244,7 +247,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -263,7 +266,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -280,7 +283,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -298,7 +301,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -316,7 +319,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -335,7 +338,7 @@ def norm_class(self) -> Type:
         parallel_residual=True,
         shared_attention_norm=True,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -354,7 +357,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -373,7 +376,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -393,7 +396,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -412,7 +415,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -431,7 +434,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -450,7 +453,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -469,7 +472,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -489,7 +492,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -510,7 +513,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -531,7 +534,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -552,7 +555,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4608,
@@ -573,7 +576,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -592,7 +595,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -612,7 +615,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -632,7 +635,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=4096,
@@ -653,7 +656,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -673,7 +676,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -693,7 +696,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -712,7 +715,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -731,7 +734,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -750,7 +753,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -769,7 +772,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=6144,
@@ -787,7 +790,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="FusedRMSNorm",
+        _norm_class="RMSNorm",
         norm_eps=1e-5,
         _mlp_class="LLaMAMLP",
         intermediate_size=8192,
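A side effect of the `norm_class` change above is graceful degradation: configs that still say `"FusedRMSNorm"` resolve to the plain `RMSNorm` with a warning rather than failing. A quick check of that fallback path (a sketch; it assumes `Config` can be constructed from defaults with only `_norm_class` overridden):

```python
import warnings

from samba_pytorch.config import Config

cfg = Config(_norm_class="FusedRMSNorm")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    norm_cls = cfg.norm_class  # hits the removed-FusedRMSNorm branch

print(norm_cls.__name__)       # RMSNorm
print(str(caught[0].message))  # "FusedRMSNorm has been removed, ..."
```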

0 commit comments