
Commit 8c18e44

update package inits
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
1 parent 5bd9f71 commit 8c18e44

File tree

2 files changed, +54 −0 lines changed


samba_pytorch/__init__.py

+30
@@ -0,0 +1,30 @@
"""Minimal implementation of Samba by Microsoft in PyTorch."""

from samba_pytorch.config import Config
from samba_pytorch.samba import GPT, Block, CausalSelfAttention, LLaMAMLP
from samba_pytorch.tokenizer import Tokenizer
from samba_pytorch.utils import (
    chunked_cross_entropy,
    get_default_supported_precision,
    lazy_load,
    num_parameters,
)

try:
    from samba_pytorch._version import version as __version__
except ImportError:
    __version__ = "0.0.0"

__all__ = [
    "Config",
    "GPT",
    "Block",
    "CausalSelfAttention",
    "LLaMAMLP",
    "Tokenizer",
    "chunked_cross_entropy",
    "get_default_supported_precision",
    "lazy_load",
    "num_parameters",
    "__version__",
]
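
With this __init__.py in place, the package's main entry points can be imported directly from the top level rather than from individual submodules. An illustrative usage sketch, relying only on the names re-exported above and assuming the package is installed:

import samba_pytorch
from samba_pytorch import Config, GPT, Tokenizer, chunked_cross_entropy

# __version__ falls back to "0.0.0" when samba_pytorch._version has not been generated.
print(samba_pytorch.__version__)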

samba_pytorch/modules/__init__.py

+24
@@ -0,0 +1,24 @@
"""Core model component modules."""

from samba_pytorch.modules.fused_rotary_embedding import (
    ApplyRotaryEmb,
    apply_rotary_emb_func,
)
from samba_pytorch.modules.gla import GatedLinearAttention
from samba_pytorch.modules.mamba_simple import Mamba
from samba_pytorch.modules.multiscale_retention import MultiScaleRetention
from samba_pytorch.modules.rmsnorm import FusedRMSNorm, RMSNorm, rms_norm
from samba_pytorch.modules.rotary import RotaryEmbedding, apply_rotary_emb

__all__ = [
    "apply_rotary_emb_func",
    "ApplyRotaryEmb",
    "GatedLinearAttention",
    "Mamba",
    "MultiScaleRetention",
    "FusedRMSNorm",
    "RMSNorm",
    "rms_norm",
    "apply_rotary_emb",
    "RotaryEmbedding",
]
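
Likewise, the building blocks of the modules subpackage are now importable from a single place. An illustrative sketch using only the names re-exported above (constructor arguments live in the respective modules and are not assumed here):

from samba_pytorch.modules import Mamba, MultiScaleRetention, RMSNorm, RotaryEmbedding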
