@@ -280,7 +280,7 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
280
280
ORT_RETURN_IF_ERROR (Recurse (node, modified, graph_level, logger));
281
281
282
282
if (node.GetOutputEdgesCount () == 4 &&
283
- graph_utils::IsSupportedOptypeVersionAndDomain (node, " LayerNormalization" , {9 }, kOnnxDomain ) &&
283
+ graph_utils::IsSupportedOptypeVersionAndDomain (node, " LayerNormalization" , {1 }, kOnnxDomain ) &&
284
284
graph_utils::IsSupportedProvider (node, GetCompatibleExecutionProviders ())) {
285
285
// Get hidden size from layer norm bias tensor shape.
286
286
const NodeArg& layer_norm_bias = *(node.InputDefs ()[2 ]);
@@ -389,7 +389,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
389
389
{0 , 0 , " Reshape" , {5 }, kOnnxDomain },
390
390
{0 , 0 , " Add" , {7 }, kOnnxDomain },
391
391
{0 , 0 , " MatMul" , {1 , 9 }, kOnnxDomain },
392
- {0 , 0 , " LayerNormalization" , {9 }, kOnnxDomain }};
392
+ {0 , 0 , " LayerNormalization" , {1 }, kOnnxDomain }};
393
393
394
394
std::vector<const Node::EdgeEnd*> edges;
395
395
if (!graph_utils::FindPath (add_after_layer_norm, true , parent_path, edges, logger)) {
@@ -532,7 +532,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
532
532
{0 , 0 , " Reshape" , {5 }, kOnnxDomain },
533
533
{0 , 0 , " Add" , {7 }, kOnnxDomain },
534
534
{0 , 0 , " MatMul" , {1 , 9 }, kOnnxDomain },
535
- {0 , 0 , " LayerNormalization" , {9 }, kOnnxDomain }};
535
+ {0 , 0 , " LayerNormalization" , {1 }, kOnnxDomain }};
536
536
537
537
if (!graph_utils::FindPath (mask_add, true , q_path, edges, logger)) {
538
538
DEBUG_LOG (" Failed to find path for q" );
@@ -583,7 +583,7 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, const Node& add_after_layer
583
583
{0 , 0 , " Reshape" , {5 }, kOnnxDomain },
584
584
{0 , 0 , " Add" , {7 }, kOnnxDomain },
585
585
{0 , 0 , " MatMul" , {1 , 9 }, kOnnxDomain },
586
- {0 , 0 , " LayerNormalization" , {9 }, kOnnxDomain }};
586
+ {0 , 0 , " LayerNormalization" , {1 }, kOnnxDomain }};
587
587
588
588
if (!graph_utils::FindPath (qk_matmul, true , k_path, edges, logger)) {
589
589
DEBUG_LOG (" Failed to find path for k" );
0 commit comments