[ONNX] Add OnnxToTorch Lowering for GroupNormalization op (#3458)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3461/merge
Vivek Khandelwal 2024-06-14 21:48:53 +05:30 committed by GitHub
parent 04c6479350
commit 2ea2bc3948
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 96 additions and 1 deletions

View File

@ -1818,6 +1818,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) || binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) ||
binder.s64IntegerAttr(stashType, "stash_type", 1)) binder.s64IntegerAttr(stashType, "stash_type", 1))
return failure(); return failure();
// Since the support for `stash_type` arg does not exist in
// the torch op so we just check for the stash_type to be same
// as the input dtype since that won't require us to do any
// input type conversion and hence can be supported.
auto xType = cast<Torch::ValueTensorType>(x.getType());
std::optional<int64_t> stashTypeIntTorch =
onnxDtypeIntToTorchDtypeInt(stashType);
if (!stashTypeIntTorch.has_value())
return rewriter.notifyMatchFailure(
binder.op, "unimplemented support for the given stash_type");
FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
(torch_upstream::ScalarType)stashTypeIntTorch.value());
if (failed(stashDtype))
return failure();
if (*stashDtype != xType.getOptionalDtype())
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: stash_type should be same "
"as the input dtype");
Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>( Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(epsilon)); rewriter.getF64FloatAttr(epsilon));
@ -1826,7 +1848,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rank = *maybeRank; rank = *maybeRank;
SmallVector<Value> normalized; SmallVector<Value> normalized;
axis = Torch::toPositiveDim(axis, rank); axis = Torch::toPositiveDim(axis, rank);
auto xType = cast<Torch::ValueTensorType>(x.getType());
if (!xType.hasSizes()) { if (!xType.hasSizes()) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "Expected input (X) to have sizes"); binder.op, "Expected input (X) to have sizes");
@ -2444,4 +2465,53 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
paddingList); paddingList);
return success(); return success();
}); });
patterns.onOp(
"GroupNormalization", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input, scale, bias;
int64_t numGroups, stashType;
float epsilon;
if (binder.tensorOperandAtIndex(input, 0) ||
binder.tensorOperandAtIndex(scale, 1) ||
binder.tensorOperandAtIndex(bias, 2) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(numGroups, "num_groups") ||
binder.f32FloatAttr(epsilon, "epsilon", 1e-5) ||
binder.s64IntegerAttr(stashType, "stash_type", 1))
return failure();
// Since the support for `stash_type` arg does not exist in
// the torch op so we just check for the stash_type to be same
// as the input dtype since that won't require us to do any
// input type conversion and hence can be supported.
std::optional<int64_t> stashTypeIntTorch =
onnxDtypeIntToTorchDtypeInt(stashType);
if (!stashTypeIntTorch.has_value())
return rewriter.notifyMatchFailure(
binder.op, "unimplemented support for the given stash_type");
FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
binder.op->getContext(),
(torch_upstream::ScalarType)stashTypeIntTorch.value());
if (failed(stashDtype))
return failure();
auto inputDtype =
cast<Torch::ValueTensorType>(input.getType()).getOptionalDtype();
if (*stashDtype != inputDtype)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: stash_type != input dtype");
Value cstEpsilon = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr((double)epsilon));
Value cstNumGroups = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(numGroups));
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<Torch::AtenGroupNormOp>(
binder.op, resultType, input, cstNumGroups, scale, bias, cstEpsilon,
/*cudnn_enabled=*/cstFalse);
return success();
});
} }

View File

@ -1292,3 +1292,28 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
%0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
return %0 : !torch.vtensor<[1,1,4,4,4],f32> return %0 : !torch.vtensor<[1,1,4,4,4],f32>
} }
// -----
// CHECK-LABEL: func.func @test_group_normalization
func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
return %0 : !torch.vtensor<[3,4,2,2],f32>
}
// -----
func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
return %0 : !torch.vtensor<[3,4,2,2],f32>
}