
Commit 7b3481a

format and lint
Signed-off-by: peter szemraj <peterszemraj@gmail.com>
1 parent 3e30c55 commit 7b3481a
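The diffs below are consistent with an import-sorting and formatting pass (for example isort/ruff style; the commit message does not name the tool, so treat the exact convention as an assumption): plain `import` statements precede `from ... import ...` statements, imports are grouped into standard-library, third-party, and local-package sections, and stray blank lines plus duplicate or unused imports are dropped. A minimal sketch of the resulting layout, not taken from the repository itself:

# Illustrative sketch only; assumes isort-style grouping, which the
# commit message ("format and lint") does not explicitly confirm.

# 1) standard library: plain imports first, then from-imports, alphabetized
import math
from functools import partial

# 2) third-party packages, same ordering rule within the group
import torch
import torch.nn as nn
from einops import rearrange
from typing_extensions import Self

# 3) local (first-party) package imports last; absolute form shown here,
#    mirroring the relative "from .mamba_simple import Mamba" used in samba.py
from samba_pytorch.modules.mamba_simple import Mamba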

8 files changed: +16 -27 lines


samba_pytorch/config.py

Lines changed: 2 additions & 3 deletions
@@ -7,11 +7,10 @@
 from dataclasses import dataclass
 from typing import Any, Literal, Optional, Type
 
-import torch
-from typing_extensions import Self
-
 import lit_gpt.model
+import torch
 from lit_gpt.utils import find_multiple
+from typing_extensions import Self
 
 
 @dataclass

samba_pytorch/modules/fused_rotary_embedding.py

Lines changed: 1 addition & 3 deletions
@@ -3,12 +3,10 @@
 
 # Copyright (c) 2023, Tri Dao.
 
-import math
-from typing import Optional, Tuple
 
 import rotary_emb
 import torch
-from einops import rearrange, repeat
+from einops import rearrange
 
 
 class ApplyRotaryEmb(torch.autograd.Function):

samba_pytorch/modules/gla.py

Lines changed: 2 additions & 4 deletions
@@ -13,15 +13,13 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache
-
 from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
 from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache
 
 
 class GatedLinearAttention(nn.Module):
-
     def __init__(
         self,
         mode: str = "fused_chunk",

samba_pytorch/modules/mamba_simple.py

Lines changed: 2 additions & 2 deletions
@@ -4,12 +4,12 @@
 # Copyright (c) 2023, Tri Dao, Albert Gu.
 
 import math
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
+from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
 
 try:
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

samba_pytorch/modules/multiscale_retention.py

Lines changed: 2 additions & 4 deletions
@@ -12,9 +12,6 @@
 import torch
 import torch.nn as nn
 from einops import rearrange
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache
-
 from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
 from fla.modules.rotary import RotaryEmbedding
 from fla.ops.retention import (
@@ -23,10 +20,11 @@
     fused_recurrent_retention,
     parallel_retention,
 )
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache
 
 
 class MultiScaleRetention(nn.Module):
-
     def __init__(
         self,
         mode: str = "fused_chunk",

samba_pytorch/modules/rmsnorm.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 # Copyright (c) 2022, Tri Dao.
 # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16
 
-import torch
 import dropout_layer_norm
 import torch
 from torch.nn import init

samba_pytorch/modules/rotary.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@
 from typing import Optional, Union
 
 import torch
-
 import triton
 import triton.language as tl
 
samba_pytorch/samba.py

Lines changed: 7 additions & 9 deletions
@@ -5,29 +5,28 @@
 # see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
 
 import math
+from functools import partial
 from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from typing_extensions import Self
 from lit_gpt.config import Config
+from torch import Tensor
+from typing_extensions import Self
 from xformers.ops import SwiGLU
+
 from .fused_rotary_embedding import apply_rotary_emb_func
-from torch import Tensor
 from .mamba_simple import Mamba
-from functools import partial
 
 try:
     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
 except ImportError:
     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
-from .gla import GatedLinearAttention
-from .multiscale_retention import MultiScaleRetention
-from einops import rearrange
-import torch.nn.functional as F
-
 from causal_conv1d import causal_conv1d_fn
+from einops import rearrange
 
+from .gla import GatedLinearAttention
+from .multiscale_retention import MultiScaleRetention
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -396,7 +395,6 @@ def forward(
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
     ) -> Tuple[torch.Tensor, Optional[KVCache]]:
-
        n_1 = self.norm_1(x)
 
        if self.use_mamba:
