
Commit cb0b643

kimishpatel authored and kirklandsign committed

[Executorch][llama] Hook up use_attention_mask option in the source transforms inside llm manager

Pull Request resolved: #10286
ghstack-source-id: 279292327
Differential Revision: [D73222734](https://our.internmc.facebook.com/intern/diff/D73222734/)
1 parent 8eebbcd commit cb0b643

File tree

1 file changed: +13 -1 lines changed


examples/models/llama/export_llama_lib.py

+13 -1
@@ -1227,10 +1227,22 @@ def _get_source_transforms(  # noqa
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
+    use_attention_mask_for_custom_sdpa = False
+    if isinstance(args, argparse.Namespace):
+        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
+            use_attention_mask_for_custom_sdpa = True
+
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
-        transforms.append(replace_sdpa_with_custom_op)
+        # if use attention mask instead of causal attention
+        # then create partial function that sets use_attention_mask=True
+        if use_attention_mask_for_custom_sdpa:
+            transforms.append(
+                partial(replace_sdpa_with_custom_op, use_attention_mask=True)
+            )
+        else:
+            transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
         assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
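The mechanism in this change is the functools.partial call: it pre-binds use_attention_mask=True on replace_sdpa_with_custom_op so the bound transform keeps the same single-argument call shape as the other entries in the transforms list. The sketch below is a minimal, self-contained illustration of that pattern, not the actual ExecuTorch code: the dict standing in for a model, the stub body of replace_sdpa_with_custom_op, and the build_transforms helper are assumptions made for demonstration; only the function and flag names mirror the diff above.

# Minimal sketch (assumptions noted above): functools.partial binds
# use_attention_mask=True so every transform is applied with one argument.
from functools import partial


def replace_sdpa_with_custom_op(model, use_attention_mask=False):
    # Stub for illustration only; the real transform swaps SDPA modules for a
    # custom op. Here we just record which variant would have been installed.
    model["sdpa"] = "custom_sdpa(mask)" if use_attention_mask else "custom_sdpa(causal)"
    return model


def build_transforms(use_attention_mask_for_custom_sdpa):
    transforms = []
    if use_attention_mask_for_custom_sdpa:
        # Pre-bind the keyword so the callable still takes just the model.
        transforms.append(partial(replace_sdpa_with_custom_op, use_attention_mask=True))
    else:
        transforms.append(replace_sdpa_with_custom_op)
    return transforms


model = {"sdpa": "default_sdpa"}
for transform in build_transforms(use_attention_mask_for_custom_sdpa=True):
    model = transform(model)
print(model)  # -> {'sdpa': 'custom_sdpa(mask)'}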
