mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch Lowering for GroupNormalization op (#3458)
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3461/merge
parent
04c6479350
commit
2ea2bc3948
|
@ -1818,6 +1818,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) ||
|
||||
binder.s64IntegerAttr(stashType, "stash_type", 1))
|
||||
return failure();
|
||||
|
||||
// Since the support for `stash_type` arg does not exist in
|
||||
// the torch op so we just check for the stash_type to be same
|
||||
// as the input dtype since that won't require us to do any
|
||||
// input type conversion and hence can be supported.
|
||||
auto xType = cast<Torch::ValueTensorType>(x.getType());
|
||||
std::optional<int64_t> stashTypeIntTorch =
|
||||
onnxDtypeIntToTorchDtypeInt(stashType);
|
||||
if (!stashTypeIntTorch.has_value())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented support for the given stash_type");
|
||||
|
||||
FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
|
||||
binder.op->getContext(),
|
||||
(torch_upstream::ScalarType)stashTypeIntTorch.value());
|
||||
if (failed(stashDtype))
|
||||
return failure();
|
||||
if (*stashDtype != xType.getOptionalDtype())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: stash_type should be same "
|
||||
"as the input dtype");
|
||||
|
||||
Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getF64FloatAttr(epsilon));
|
||||
|
@ -1826,7 +1848,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
rank = *maybeRank;
|
||||
SmallVector<Value> normalized;
|
||||
axis = Torch::toPositiveDim(axis, rank);
|
||||
auto xType = cast<Torch::ValueTensorType>(x.getType());
|
||||
if (!xType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected input (X) to have sizes");
|
||||
|
@ -2444,4 +2465,53 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
paddingList);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"GroupNormalization", 18,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value input, scale, bias;
|
||||
int64_t numGroups, stashType;
|
||||
float epsilon;
|
||||
if (binder.tensorOperandAtIndex(input, 0) ||
|
||||
binder.tensorOperandAtIndex(scale, 1) ||
|
||||
binder.tensorOperandAtIndex(bias, 2) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.s64IntegerAttr(numGroups, "num_groups") ||
|
||||
binder.f32FloatAttr(epsilon, "epsilon", 1e-5) ||
|
||||
binder.s64IntegerAttr(stashType, "stash_type", 1))
|
||||
return failure();
|
||||
|
||||
// Since the support for `stash_type` arg does not exist in
|
||||
// the torch op so we just check for the stash_type to be same
|
||||
// as the input dtype since that won't require us to do any
|
||||
// input type conversion and hence can be supported.
|
||||
std::optional<int64_t> stashTypeIntTorch =
|
||||
onnxDtypeIntToTorchDtypeInt(stashType);
|
||||
if (!stashTypeIntTorch.has_value())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented support for the given stash_type");
|
||||
|
||||
FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
|
||||
binder.op->getContext(),
|
||||
(torch_upstream::ScalarType)stashTypeIntTorch.value());
|
||||
if (failed(stashDtype))
|
||||
return failure();
|
||||
auto inputDtype =
|
||||
cast<Torch::ValueTensorType>(input.getType()).getOptionalDtype();
|
||||
if (*stashDtype != inputDtype)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: stash_type != input dtype");
|
||||
|
||||
Value cstEpsilon = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getF64FloatAttr((double)epsilon));
|
||||
Value cstNumGroups = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(numGroups));
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||
binder.getLoc(), rewriter.getBoolAttr(false));
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenGroupNormOp>(
|
||||
binder.op, resultType, input, cstNumGroups, scale, bias, cstEpsilon,
|
||||
/*cudnn_enabled=*/cstFalse);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1292,3 +1292,28 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
|
|||
%0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
|
||||
return %0 : !torch.vtensor<[1,1,4,4,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_group_normalization
|
||||
func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
|
||||
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
|
||||
return %0 : !torch.vtensor<[3,4,2,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821
|
||||
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32>
|
||||
%0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[3,4,2,2],f32>
|
||||
return %0 : !torch.vtensor<[3,4,2,2],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue