@@ -3,14 +3,13 @@

# Copyright Lightning AI. Licensed under the Apache License 2.0,
# see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
-
+ import warnings
from dataclasses import dataclass
from typing import Any, Literal, Optional, Type

import torch
from typing_extensions import Self

- import samba_pytorch.samba
from samba_pytorch.utils import find_multiple

@@ -101,8 +100,9 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:

@property
def mlp_class(self) -> Type:
+ from samba_pytorch import samba
# `self._mlp_class` cannot be the type to keep the config json serializable
- return getattr(samba_pytorch.samba, self._mlp_class)
+ return getattr(samba, self._mlp_class)

@property
def norm_class(self) -> Type:
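The hunk above replaces the module-level `import samba_pytorch.samba` with a deferred import inside the property: `samba_pytorch.samba` imports this config module itself, so resolving it lazily breaks the circular dependency, while storing `_mlp_class` as a string keeps the config JSON-serializable. A minimal runnable sketch of the same pattern, with `torch.nn` standing in for `samba_pytorch.samba` as the class registry:

import json
from dataclasses import asdict, dataclass
from typing import Type


@dataclass
class DemoConfig:
    # Stored as a string, not a class, so the dataclass round-trips
    # through dataclasses.asdict() and json.dumps() cleanly.
    _mlp_class: str = "Linear"

    @property
    def mlp_class(self) -> Type:
        # Deferred import: resolved on first property access rather than
        # at module import time, which is what breaks an import cycle.
        import torch.nn

        return getattr(torch.nn, self._mlp_class)


cfg = DemoConfig()
print(cfg.mlp_class)            # <class 'torch.nn.modules.linear.Linear'>
print(json.dumps(asdict(cfg)))  # {"_mlp_class": "Linear"}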
@@ -112,9 +112,12 @@ def norm_class(self) -> Type:

return RMSNorm
elif self._norm_class == "FusedRMSNorm":
- from samba_pytorch.modules.rmsnorm import FusedRMSNorm
+ warnings.warn(
+ "FusedRMSNorm has been removed, using standard torch RMSNorm instead"
+ )
+ from samba_pytorch.modules.rmsnorm import RMSNorm

- return FusedRMSNorm
+ return RMSNorm
return getattr(torch.nn, self._norm_class)

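The `norm_class` property keeps accepting the old "FusedRMSNorm" string, but now warns and substitutes the plain RMSNorm implementation. The same fallback behaviour in isolation, as a small sketch (`torch.nn.RMSNorm`, available in torch >= 2.4, stands in for `samba_pytorch.modules.rmsnorm.RMSNorm`):

import warnings

import torch


def resolve_norm_class(name: str) -> type:
    # Mirrors the fallback introduced above: the removed FusedRMSNorm
    # resolves to the standard RMSNorm after emitting a warning.
    if name == "FusedRMSNorm":
        warnings.warn(
            "FusedRMSNorm has been removed, using standard torch RMSNorm instead"
        )
        name = "RMSNorm"
    return getattr(torch.nn, name)


print(resolve_norm_class("FusedRMSNorm"))  # <class '...RMSNorm'>, plus the warning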
@@ -133,7 +136,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -150,7 +153,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -168,7 +171,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
full_per_layer=2,
_mlp_class="LLaMAMLP",
@@ -187,7 +190,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
@@ -206,7 +209,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
@@ -225,7 +228,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
@@ -244,7 +247,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
@@ -263,7 +266,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -280,7 +283,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -298,7 +301,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -316,7 +319,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -335,7 +338,7 @@ def norm_class(self) -> Type:
parallel_residual=True,
shared_attention_norm=True,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -354,7 +357,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -373,7 +376,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -393,7 +396,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -412,7 +415,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -431,7 +434,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -450,7 +453,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -469,7 +472,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -489,7 +492,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
@@ -510,7 +513,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
@@ -531,7 +534,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
@@ -552,7 +555,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4608,
@@ -573,7 +576,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -592,7 +595,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -612,7 +615,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -632,7 +635,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=4096,
@@ -653,7 +656,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
@@ -673,7 +676,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
@@ -693,7 +696,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
@@ -712,7 +715,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
@@ -731,7 +734,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
@@ -750,7 +753,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
@@ -769,7 +772,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=6144,
@@ -787,7 +790,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="FusedRMSNorm",
+ _norm_class="RMSNorm",
norm_eps=1e-5,
_mlp_class="LLaMAMLP",
intermediate_size=8192,
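The remaining hunks are the mechanical rename across every preset: each Config entry that pinned `_norm_class="FusedRMSNorm"` now reads `"RMSNorm"`. Older serialized configs that still carry the `"FusedRMSNorm"` string keep loading and only emit the deprecation warning on first access; a hypothetical smoke test of that behaviour (the `samba_pytorch.config.Config` import path and default-constructibility are assumptions, not shown in this diff):

import warnings

from samba_pytorch.config import Config  # import path assumed, not shown in the diff

cfg = Config(_norm_class="FusedRMSNorm")  # e.g. loaded from an old JSON config
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    norm_cls = cfg.norm_class  # triggers the warn-and-substitute branch
print(norm_cls.__name__)       # RMSNorm
print(str(caught[0].message))  # FusedRMSNorm has been removed, ...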