Commit ecdcd68

NCHWc ReorderOutput->Transpose(NHWC) fusion (#3035)
Add support to fuse ReorderOutput+Transpose(NHWC). Converting from NCHWc to NHWC tensors is a trivial copy of the data and avoids the cost of a Transpose node.
1 parent 71ca43b commit ecdcd68
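
Note (not part of the commit message): the reason this fusion is cheap is visible from the layouts. NCHWc stores a tensor as [N, C/block, H, W, block], so the `block` channels at each (n, h, w) position are already contiguous exactly as NHWC wants them; producing NHWC is a series of small contiguous copies rather than the strided element gathers a separate Transpose node would perform. A minimal sketch, with a hypothetical helper name and assuming the channel count is already padded to a multiple of the block size (the real work is done by MlasReorderOutputNhwc):

```cpp
// Illustrative sketch only -- not the MLAS kernel. Copies an NCHWc tensor laid out as
// [N, C/block, H, W, block] into an NHWC tensor [N, H, W, C], where C is the padded
// (block-aligned) channel count. Each inner run of `block` channels is contiguous in
// both layouts, so the reorder reduces to small contiguous copies.
#include <cstddef>
#include <cstring>

void ReorderNchwcToNhwcSketch(const float* src, float* dst,
                              size_t N, size_t C, size_t H, size_t W,
                              size_t block) {
  const size_t channel_blocks = C / block;  // C assumed to be a multiple of block
  for (size_t n = 0; n < N; ++n) {
    for (size_t cb = 0; cb < channel_blocks; ++cb) {
      for (size_t hw = 0; hw < H * W; ++hw) {
        // Source: block of channels at (n, cb, h, w); destination: same block inside
        // the channel run of the (n, h, w) pixel in NHWC.
        const float* s = src + (((n * channel_blocks + cb) * H * W) + hw) * block;
        float* d = dst + (n * H * W + hw) * C + cb * block;
        std::memcpy(d, s, block * sizeof(float));
      }
    }
  }
}
```

The actual ReorderOutput kernel additionally trims the padded channel count down to its `channels` attribute; with the new `channels_last` attribute it can emit NHWC directly instead of emitting NCHW and leaving a Transpose in the graph.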

File tree

11 files changed (+534, -195 lines)

onnxruntime/contrib_ops/cpu/nchwc_ops.cc (+26, -12)
@@ -16,15 +16,15 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
     float,
     KernelDefBuilder()
         .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
-    ReorderInput<float>);
+    ReorderInput);
 
 ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
     ReorderOutput,
     1,
     float,
     KernelDefBuilder()
         .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
-    ReorderOutput<float>);
+    ReorderOutput);
 
 ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
     Conv,
@@ -67,27 +67,41 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
         .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
     NchwcAveragePool);
 
-template <typename T>
-Status ReorderInput<T>::Compute(OpKernelContext* context) const {
+Status ReorderInput::Compute(OpKernelContext* context) const {
   const auto* X = context->Input<Tensor>(0);
   const auto& X_shape = X->Shape();
   ORT_ENFORCE(X_shape.NumDimensions() == 4);
   ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
   auto* Y = context->Output(0, X_shape);
-  MlasReorderInput(X_shape.GetDims().data(), X->template Data<T>(), Y->template MutableData<T>());
+  MlasReorderInput(X_shape.GetDims().data(), X->template Data<float>(), Y->template MutableData<float>());
   return Status::OK();
 }
 
-template <typename T>
-Status ReorderOutput<T>::Compute(OpKernelContext* context) const {
+Status ReorderOutput::Compute(OpKernelContext* context) const {
   const auto* X = context->Input<Tensor>(0);
   const auto& X_shape = X->Shape();
-  ORT_ENFORCE(X_shape.NumDimensions() == 4);
-  std::vector<int64_t> Y_shape(X_shape.GetDims());
-  ORT_ENFORCE(channels_ <= Y_shape[1]);
-  Y_shape[1] = channels_;
+  const auto X_rank = X_shape.NumDimensions();
+  ORT_ENFORCE(X_rank == 4);
+  ORT_ENFORCE(channels_ <= X_shape[1]);
+
+  // Build the output shape in NCHW or NHWC order.
+  std::vector<int64_t> Y_shape(X_rank);
+  Y_shape[0] = X_shape[0];
+  Y_shape[channels_last_ ? X_rank - 1 : 1] = channels_;
+  auto* Y_spatial_dims = Y_shape.data() + (channels_last_ ? 1 : 2);
+  for (size_t i = 0; i < X_rank - 2; i++) {
+    Y_spatial_dims[i] = X_shape[2 + i];
+  }
   auto* Y = context->Output(0, Y_shape);
-  MlasReorderOutput(Y_shape.data(), X->template Data<T>(), Y->template MutableData<T>());
+
+  const auto* x_data = X->template Data<float>();
+  auto* y_data = Y->template MutableData<float>();
+  if (channels_last_) {
+    MlasReorderOutputNhwc(Y_shape.data(), x_data, y_data);
+  } else {
+    MlasReorderOutputNchw(Y_shape.data(), x_data, y_data);
+  }
+
   return Status::OK();
 }
 
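Worked example (illustrative sizes, not taken from the diff): for an NCHWc input of shape {1, 64, 28, 28} (64 = channels padded up to the block size) and channels = 60, the code above yields Y_shape = {1, 60, 28, 28} when channels_last_ is 0 and {1, 28, 28, 60} when it is 1. The same shape construction as a standalone sketch (hypothetical function name):

```cpp
// Standalone restatement of the output-shape construction in ReorderOutput::Compute.
#include <cstdint>
#include <vector>

std::vector<int64_t> BuildReorderOutputShape(const std::vector<int64_t>& x_shape,
                                             int64_t channels, bool channels_last) {
  const size_t rank = x_shape.size();  // the kernel enforces rank == 4
  std::vector<int64_t> y_shape(rank);
  y_shape[0] = x_shape[0];                           // batch
  y_shape[channels_last ? rank - 1 : 1] = channels;  // channel dim: last for NHWC, second for NCHW
  int64_t* spatial = y_shape.data() + (channels_last ? 1 : 2);
  for (size_t i = 0; i < rank - 2; ++i) {
    spatial[i] = x_shape[2 + i];                     // copy H and W
  }
  return y_shape;
}
```

Emitting the NHWC shape directly is what allows a following NCHW-to-NHWC Transpose (perm = [0, 2, 3, 1]) to be folded into ReorderOutput; that graph-transformer change is in one of the files not shown in this excerpt.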
onnxruntime/contrib_ops/cpu/nchwc_ops.h (+2, -2)
@@ -12,7 +12,6 @@
 namespace onnxruntime {
 namespace contrib {
 
-template <typename T>
 class ReorderInput : public OpKernel {
  public:
   ReorderInput(const OpKernelInfo& info) : OpKernel(info) {
@@ -21,18 +20,19 @@ class ReorderInput : public OpKernel {
   Status Compute(OpKernelContext* context) const override;
 };
 
-template <typename T>
 class ReorderOutput : public OpKernel {
  public:
   ReorderOutput(const OpKernelInfo& info) : OpKernel(info) {
     ORT_ENFORCE(info.GetAttr<int64_t>("channels", &channels_).IsOK());
     ORT_ENFORCE(channels_ > 0, "invalid channel count");
+    ORT_ENFORCE(info.GetAttr<int64_t>("channels_last", &channels_last_).IsOK());
   }
 
   Status Compute(OpKernelContext* context) const override;
 
  private:
   int64_t channels_;
+  int64_t channels_last_;
 };
 
 class NchwcConv : public OpKernel {

onnxruntime/core/graph/contrib_ops/contrib_defs.cc (+8, -165)
@@ -5,6 +5,7 @@
 #include "core/graph/constants.h"
 #include "core/graph/contrib_ops/attn_lstm_schema_defs.h"
 #include "core/graph/contrib_ops/contrib_defs.h"
+#include "core/graph/contrib_ops/nchwc_schema_defs.h"
 #include "core/graph/contrib_ops/range_schema_defs.h"
 #include "core/graph/op.h"
 #include "onnx/defs/schema.h"
@@ -18,7 +19,6 @@ void convPoolShapeInference(
     bool use_dilation, bool require_kernel_shape,
     int input1Idx,
     int input2Idx);
-void globalPoolTypeShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);
 void matmulShapeInference(
     ONNX_NAMESPACE::InferenceContext& ctx,
     int input1Idx,
@@ -166,37 +166,6 @@ using ONNX_NAMESPACE::AttributeProto;
 using ONNX_NAMESPACE::OpSchema;
 using ONNX_NAMESPACE::OPTIONAL;
 
-void NchwcPoolOpSchemaGenerator(OpSchema& schema) {
-  schema.SetDomain(kMSNchwcDomain);
-  schema.SinceVersion(1);
-  schema.SetDoc(R"DOC(For internal use.)DOC");
-  schema.Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET"));
-  schema.Attr("kernel_shape", "", AttributeProto::INTS);
-  schema.Attr("dilations", "", AttributeProto::INTS, OPTIONAL);
-  schema.Attr("strides", "", AttributeProto::INTS, OPTIONAL);
-  schema.Attr("pads", "", AttributeProto::INTS, OPTIONAL);
-  schema.Attr("ceil_mode", "", AttributeProto::INT, static_cast<int64_t>(0));
-  schema.Input(0, "X", "", "T");
-  schema.Output(0, "Y", "", "T");
-  schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors");
-  schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
-    ONNX_NAMESPACE::convPoolShapeInference(ctx, true, true, 0, 1);
-  });
-}
-
-void NchwcGlobalPoolOpSchemaGenerator(OpSchema& schema) {
-  schema.SetDomain(kMSNchwcDomain);
-  schema.SinceVersion(1);
-  schema.SetDoc(R"DOC(For internal use.)DOC");
-  schema.Input(0, "X", "", "T");
-  schema.Output(0, "Y", "", "T");
-  schema.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors");
-  schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-    ONNX_NAMESPACE::globalPoolTypeShapeInference(ctx);
-  });
-}
-
 void ValidateTypeAndShapeForScaleAndZP(ONNX_NAMESPACE::InferenceContext& ctx, int index, ::google::protobuf::int32 expectedType, bool isScalar, int expectedTensorSize = 0) {
   if (ctx.getNumInputs() > static_cast<size_t>(index)) {
     auto data_type = ctx.getInputType(index);
@@ -320,132 +289,6 @@ const char* contrib_ops_auto_pad_doc =
     "In case of odd number add the extra padding at the end for SAME_UPPER and at the "
     "beginning for SAME_LOWER. VALID mean no padding.";
 
-void RegisterNchwcSchemas() {
-  ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderInput)
-      .SetDomain(kMSNchwcDomain)
-      .SinceVersion(1)
-      .SetDoc(R"DOC(For internal use.)DOC")
-      .Input(0, "X", "", "T")
-      .Output(0, "Y", "", "T")
-      .TypeConstraint(
-          "T",
-          {"tensor(float)", "tensor(int8)", "tensor(uint8)"},
-          "Constrain input and output types to float/quantized tensors")
-      .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
-
-  ONNX_CONTRIB_OPERATOR_SCHEMA(ReorderOutput)
-      .SetDomain(kMSNchwcDomain)
-      .SinceVersion(1)
-      .SetDoc(R"DOC(For internal use.)DOC")
-      .Attr(
-          "channels",
-          "",
-          AttributeProto::INT,
-          static_cast<int64_t>(0))
-      .Input(0, "X", "", "T")
-      .Output(0, "Y", "", "T")
-      .TypeConstraint(
-          "T",
-          {"tensor(float)", "tensor(int8)", "tensor(uint8)"},
-          "Constrain input and output types to float/quantized tensors")
-      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-        propagateElemTypeFromInputToOutput(ctx, 0, 0);
-        if (!hasNInputShapes(ctx, 1)) {
-          return;
-        }
-        propagateShapeFromInputToOutput(ctx, 0, 0);
-
-        // Update the output shape with the actual number of channels.
-        auto channels = getAttribute(ctx, "channels", 0);
-        if (channels <= 0) {
-          fail_shape_inference("invalid channel count");
-        }
-        auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
-        if (output_shape->dim_size() < 2) {
-          fail_shape_inference("tensor rank too small");
-        }
-        auto* channels_dim = output_shape->mutable_dim(1);
-        channels_dim->clear_dim_param();
-        channels_dim->set_dim_value(channels);
-      });
-
-  ONNX_CONTRIB_OPERATOR_SCHEMA(Conv)
-      .SetDomain(kMSNchwcDomain)
-      .SinceVersion(1)
-      .SetDoc(R"DOC(For internal use.)DOC")
-      .Attr(
-          "auto_pad",
-          "",
-          AttributeProto::STRING,
-          std::string("NOTSET"))
-      .Attr(
-          "kernel_shape",
-          "",
-          AttributeProto::INTS,
-          OPTIONAL)
-      .Attr(
-          "dilations",
-          "",
-          AttributeProto::INTS,
-          OPTIONAL)
-      .Attr(
-          "strides",
-          "",
-          AttributeProto::INTS,
-          OPTIONAL)
-      .Attr(
-          "pads",
-          "",
-          AttributeProto::INTS, OPTIONAL)
-      .Attr(
-          "group",
-          "",
-          AttributeProto::INT,
-          static_cast<int64_t>(1))
-      .Attr(
-          "activation",
-          "",
-          AttributeProto::STRING,
-          OPTIONAL)
-      .Attr(
-          "activation_params",
-          "",
-          AttributeProto::FLOATS,
-          OPTIONAL)
-      .Input(0, "X", "", "T")
-      .Input(1, "W", "", "T")
-      .Input(2, "B", "", "T", OpSchema::Optional)
-      .Input(3, "Sum", "", "T", OpSchema::Optional)
-      .Output(0, "Y", "", "T")
-      .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors")
-      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
-        ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
-        ONNX_NAMESPACE::convPoolShapeInference(ctx, true, false, 0, 1);
-      });
-
-  ONNX_CONTRIB_OPERATOR_SCHEMA(MaxPool)
-      .FillUsing(NchwcPoolOpSchemaGenerator)
-      .Attr(
-          "storage_order",
-          "",
-          AttributeProto::INT,
-          static_cast<int64_t>(0));
-
-  ONNX_CONTRIB_OPERATOR_SCHEMA(AveragePool)
-      .FillUsing(NchwcPoolOpSchemaGenerator)
-      .Attr(
-          "count_include_pad",
-          "",
-          AttributeProto::INT,
-          static_cast<int64_t>(0));
-
-  ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalMaxPool)
-      .FillUsing(NchwcGlobalPoolOpSchemaGenerator);
-
-  ONNX_CONTRIB_OPERATOR_SCHEMA(GlobalAveragePool)
-      .FillUsing(NchwcGlobalPoolOpSchemaGenerator);
-}
-
 void RegisterBertSchemas() {
   ONNX_CONTRIB_OPERATOR_SCHEMA(Attention)
       .SetDomain(kMSDomain)
@@ -1383,8 +1226,8 @@ activation and leaky_relu_alpha.)DOC")
   ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema);
 
   static const char* QuantizeLinear_ver1_doc = R"DOC(
-The linear quantization operator. It consumes a full precision data, a scale, a zero point and computes the quantized data. 
-The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to arg (in floating-point format), 
+The linear quantization operator. It consumes a full precision data, a scale, a zero point and computes the quantized data.
+The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to arg (in floating-point format),
 rounding halfway cases away from zero. Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC";
 
   ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear)
@@ -1440,8 +1283,8 @@ The quantization formula is y = (x / y_scale),
       });
 
   static const char* DequantizeLinear_ver1_doc = R"DOC(
-The linear dequantization operator. It consumes a quantized data, a scale, a zero point and computes the full precision data. 
-The dequantization formula is y = (x - x_zero_point) * x_scale. 
+The linear dequantization operator. It consumes a quantized data, a scale, a zero point and computes the full precision data.
+The dequantization formula is y = (x - x_zero_point) * x_scale.
 Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC";
 
   ONNX_CONTRIB_OPERATOR_SCHEMA(DequantizeLinear)
@@ -1682,7 +1525,7 @@ Computes the mean of the low-precision input tensor's element along the provided
 The resulting tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0,
 then the resulting tensor have the reduced dimension pruned. The above behavior is similar to numpy,
 with the exception that numpy default keepdims to False instead of True.
-Input and Output scales and zero points are used to requantize the output in a new range. 
+Input and Output scales and zero points are used to requantize the output in a new range.
 This helps to improve accuracy as after ReduceMean operation the range of the output is expected to decrease.
 
 ```
@@ -1861,7 +1704,7 @@ C (int32) = (A - A_zero_point) * (B - B_zero_point)
 ```
 pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + kernel_spatial_shape[i] - input_spatial_shape[i]
 ```
- 
+
 The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero).
 
 Input and output scales and zero points are used to convert the output to a new quantization range.
@@ -2448,7 +2291,7 @@ Example 4:
       R"DOC(Gaussian Error Linear Unit.
 A high-performing neural network activation function.The GELU nonlinearity is
 the expected transformation of a stochastic regularizer which randomly applies
-the identity or zero map to a neuron's input. The GELU nonlinearity weights 
+the identity or zero map to a neuron's input. The GELU nonlinearity weights
 inputs by their magnitude, rather than gates inputs by their sign as in ReLUs.)DOC";
 
   ONNX_CONTRIB_OPERATOR_SCHEMA(Gelu)
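
The NCHWc schema registrations removed from contrib_defs.cc above are moved into the new nchwc_schema_defs files pulled in by the added include (those files are among the 11 changed but are not shown in this excerpt). Purely as a hedged sketch, assuming the relocated ReorderOutput schema gained a channels_last attribute mirroring the kernel change, its shape inference could look roughly like the function below; this is a guess with a hypothetical function name, not a quote of the moved code:

```cpp
// Hypothetical sketch: shape inference for the NCHWc ReorderOutput schema with a
// channels_last attribute, written as a free function that could be passed to
// .TypeAndShapeInferenceFunction(...). Helpers are the same ONNX ones used by the
// removed code above.
#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"

static void ReorderOutputShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
  using namespace ONNX_NAMESPACE;
  propagateElemTypeFromInputToOutput(ctx, 0, 0);
  if (!hasNInputShapes(ctx, 1)) {
    return;
  }
  auto channels = getAttribute(ctx, "channels", 0);
  if (channels <= 0) {
    fail_shape_inference("invalid channel count");
  }
  const bool channels_last = getAttribute(ctx, "channels_last", 0) != 0;
  const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
  if (input_shape.dim_size() < 2) {
    fail_shape_inference("tensor rank too small");
  }
  auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
  *output_shape->add_dim() = input_shape.dim(0);       // batch dim propagated
  if (!channels_last) {
    output_shape->add_dim()->set_dim_value(channels);  // NCHW: channels second
  }
  for (int i = 2; i < input_shape.dim_size(); ++i) {
    *output_shape->add_dim() = input_shape.dim(i);     // spatial dims propagated
  }
  if (channels_last) {
    output_shape->add_dim()->set_dim_value(channels);  // NHWC: channels last
  }
}
```

When channels_last is 0 this reduces to the old behavior (propagate the input shape and overwrite dim 1 with channels); when it is 1 the spatial dimensions shift left and the channel dimension moves to the end, matching the kernel's output shape above.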
