File tree 1 file changed +13
-1
lines changed
1 file changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -1227,10 +1227,22 @@ def _get_source_transforms( # noqa
1227
1227
if args .expand_rope_table :
1228
1228
transforms .append (materialze_broadcast_of_rope_freq_cis )
1229
1229
1230
+ use_attention_mask_for_custom_sdpa = False
1231
+ if isinstance (args , argparse .Namespace ):
1232
+ if getattr (args , "use_custom_sdpa_with_attention_mask" , None ):
1233
+ use_attention_mask_for_custom_sdpa = True
1234
+
1230
1235
if args .use_sdpa_with_kv_cache :
1231
1236
transforms .append (replace_kv_cache_with_custom_kv_cache )
1232
1237
# todo: do this optionally
1233
- transforms .append (replace_sdpa_with_custom_op )
1238
+ # if use attention mask instead of causal attention
1239
+ # then create partial function that sets use_attention_mask=True
1240
+ if use_attention_mask_for_custom_sdpa :
1241
+ transforms .append (
1242
+ partial (replace_sdpa_with_custom_op , use_attention_mask = True )
1243
+ )
1244
+ else :
1245
+ transforms .append (replace_sdpa_with_custom_op )
1234
1246
1235
1247
if args .quantize_kv_cache :
1236
1248
assert args .use_kv_cache , "quantize_kv_cache requires use_kv_cache=True"
You can’t perform that action at this time.
0 commit comments