@@ -22,9 +22,15 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
+        max_context_len,
+        enable_dynamic_shape,
+        use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
+        self.max_context_len = max_context_len
+        self.use_attention_mask = use_attention_mask
+        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -36,6 +42,16 @@ def forward(
         seqlen,
         mask,
     ):
+        if self.use_attention_mask:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                mask = mask[input_pos]
+
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
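For context on the new mask handling in `forward()`, here is a standalone sketch of what the dynamic-shape branch produces. It assumes the attention module caches a full `(max_context_len, max_context_len)` additive causal mask and that `input_pos` carries the start position of the tokens in the current call; neither assumption is visible in this diff.

```python
import torch

# Illustrative only: a cached additive causal mask of shape
# (max_context_len, max_context_len), as the sliced-mask path assumes.
max_context_len = 6
mask = torch.triu(
    torch.full((max_context_len, max_context_len), float("-inf")), diagonal=1
)

# Prefill of 3 tokens starting at position 0; input_pos holds the start index.
input_pos = torch.tensor([0])
seq_length = 3

# Dynamic-shape branch from the diff: slice out the rows for the current tokens.
start_pos = input_pos[-1].item()
sliced = mask.narrow(0, start_pos, seq_length)
print(sliced)
# tensor([[0., -inf, -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf, -inf]])
```

The sliced rows reproduce causal masking for exactly the positions being processed, which is why the masked path below can call `custom_sdpa` with `is_causal=False`.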
@@ -47,34 +63,54 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)

-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                mask,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


-def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+def _replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(
+                    child.dim,
+                    child.max_context_len,
+                    child.enable_dynamic_shape,
+                    use_attention_mask=use_attention_mask,
+                ),
             )
         else:
-            _replace_sdpa_with_custom_op(child)
+            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)


-def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+def replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa

-    _replace_sdpa_with_custom_op(module)
+    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
     return module
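A minimal usage sketch of the new keyword argument. Assumptions not shown in this diff: the import path below, and the `build_llama_model()` placeholder standing in for however the eager model is constructed.

```python
# Hypothetical import path: adjust to where this file lives in your tree.
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

# Placeholder for however the eager llama-style model is built; its attention
# layers must use the SDPA module (with dim, max_context_len and
# enable_dynamic_shape attributes) that this pass swaps out.
model = build_llama_model()

# With use_attention_mask=True, every SDPACustom slices the caller-provided
# mask per call and passes it to custom_sdpa with is_causal=False; leaving it
# at the default False keeps the previous behaviour (no mask, is_causal=True).
model = replace_sdpa_with_custom_op(model, use_attention_mask=True)
```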