
Commit 146d124

♻️ match updated module name

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

1 parent: 7b3481a

File tree

2 files changed: +11 −11 lines changed

samba_pytorch/config.py
samba_pytorch/samba.py


samba_pytorch/config.py  (6 additions, 5 deletions)

@@ -7,11 +7,12 @@
 from dataclasses import dataclass
 from typing import Any, Literal, Optional, Type

-import lit_gpt.model
 import torch
-from lit_gpt.utils import find_multiple
 from typing_extensions import Self

+import samba_pytorch.samba
+from samba_pytorch.utils import find_multiple
+

 @dataclass
 class Config:
@@ -101,17 +102,17 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
     @property
     def mlp_class(self) -> Type:
         # `self._mlp_class` cannot be the type to keep the config json serializable
-        return getattr(lit_gpt.model, self._mlp_class)
+        return getattr(samba_pytorch.samba, self._mlp_class)

     @property
     def norm_class(self) -> Type:
         # `self._norm_class` cannot be the type to keep the config json serializable
         if self._norm_class == "RMSNorm":
-            from lit_gpt.rmsnorm import RMSNorm
+            from samba_pytorch.modules.rmsnorm import RMSNorm

             return RMSNorm
         elif self._norm_class == "FusedRMSNorm":
-            from lit_gpt.rmsnorm import FusedRMSNorm
+            from samba_pytorch.modules.rmsnorm import FusedRMSNorm

             return FusedRMSNorm
         return getattr(torch.nn, self._norm_class)
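
The getattr-based lookup touched in this hunk keeps Config JSON-serializable: the config stores class names as plain strings and resolves them to real classes only on attribute access. A minimal sketch of that pattern, using an illustrative TinyConfig dataclass (not part of this repository) and torch.nn as the lookup namespace:

# Sketch of the string-to-class resolution pattern; TinyConfig is
# illustrative only and does not exist in samba_pytorch.
from dataclasses import dataclass
from typing import Type

import torch


@dataclass
class TinyConfig:
    _norm_class: str = "LayerNorm"  # stored as a string, so the dataclass stays JSON-serializable

    @property
    def norm_class(self) -> Type:
        # Resolve the stored name against torch.nn only when it is needed.
        return getattr(torch.nn, self._norm_class)


cfg = TinyConfig()
norm = cfg.norm_class(128)      # equivalent to torch.nn.LayerNorm(128)
print(type(norm).__name__)      # "LayerNorm"

The commit only changes the namespace the names are resolved against (lit_gpt.model becomes samba_pytorch.samba); the resolution mechanism itself is unchanged.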

samba_pytorch/samba.py  (5 additions, 6 deletions)

@@ -10,23 +10,22 @@

 import torch
 import torch.nn as nn
-from lit_gpt.config import Config
 from torch import Tensor
 from typing_extensions import Self
 from xformers.ops import SwiGLU

-from .fused_rotary_embedding import apply_rotary_emb_func
-from .mamba_simple import Mamba
-
 try:
     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
 except ImportError:
     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
 from causal_conv1d import causal_conv1d_fn
 from einops import rearrange

-from .gla import GatedLinearAttention
-from .multiscale_retention import MultiScaleRetention
+from samba_pytorch.config import Config
+from samba_pytorch.modules.fused_rotary_embedding import apply_rotary_emb_func
+from samba_pytorch.modules.gla import GatedLinearAttention
+from samba_pytorch.modules.mamba_simple import Mamba
+from samba_pytorch.modules.multiscale_retention import MultiScaleRetention

 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KVCache = Tuple[torch.Tensor, torch.Tensor]
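
The try/except block kept in this hunk is the usual optional-dependency pattern: the fused Triton kernels are imported when mamba_ssm is installed, and the names fall back to None otherwise so callers can branch on availability. A minimal sketch of that pattern, with a hypothetical has_fused_rms_norm helper added purely for illustration:

# Sketch of optional-import fallback; has_fused_rms_norm is a hypothetical
# helper, not a function defined in samba_pytorch.
try:
    from mamba_ssm.ops.triton.layernorm import rms_norm_fn
except ImportError:
    # mamba_ssm not installed: leave the kernel handle unset.
    rms_norm_fn = None


def has_fused_rms_norm() -> bool:
    """Report whether the fused Triton RMSNorm kernel is importable."""
    return rms_norm_fn is not None

This keeps the module importable on machines without the Triton kernels while still using them when available.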
