From f54d15f720e8d3c2993f41a35f4694be9d7a58d1 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 17 Apr 2025 16:06:03 -0700 Subject: [PATCH] [Executorch][llama] Hookup use_attention_mask option in the source transforms inside llm mananger Differential Revision: [D73222734](https://our.internmc.facebook.com/intern/diff/D73222734/) [ghstack-poisoned] --- examples/models/llama/export_llama_lib.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 64cbc9e23af..5c181195f7e 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -1222,10 +1222,22 @@ def _get_source_transforms( # noqa if args.expand_rope_table: transforms.append(materialze_broadcast_of_rope_freq_cis) + use_attention_mask_for_custom_sdpa = False + if isinstance(args, argparse.Namespace): + if getattr(args, "use_custom_sdpa_with_attention_mask", None): + use_attention_mask_for_custom_sdpa = True + if args.use_sdpa_with_kv_cache: transforms.append(replace_kv_cache_with_custom_kv_cache) # todo: do this optionally - transforms.append(replace_sdpa_with_custom_op) + # if use attention mask instead of causal attention + # then create partial function that sets use_attention_mask=True + if use_attention_mask_for_custom_sdpa: + transforms.append( + partial(replace_sdpa_with_custom_op, use_attention_mask=True) + ) + else: + transforms.append(replace_sdpa_with_custom_op) if args.quantize_kv_cache: assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"