From cd21dda867b9e5783795864cd00118b33aa91dc7 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Mon, 14 Feb 2022 20:16:44 +0530
Subject: [PATCH] [LINALG] Add E2E support for `aten.Hardsigmoid` op

This commit adds lowering of `aten.Hardsigmoid` op.

Signed-Off-by: Gaurav Shukla
---
 e2e_testing/torchscript/basic.py              | 39 +++++++++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      | 28 +++++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 42 +++++++++++++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  3 +-
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 test/Dialect/Torch/decompose-complex-ops.mlir | 33 +++++++++++++++
 6 files changed, 145 insertions(+), 1 deletion(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index 87032b123..279994811 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -666,6 +666,45 @@ def _LogSoftmaxModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4))

 # ==============================================================================
+
+class HardsigmoidModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.hardsigmoid(x)
+
+
+@register_test_case(module_factory=lambda: HardsigmoidModule())
+def HardsigmoidModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]]))
+
+# ==============================================================================
+
+class HardsigmoidRandomModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.hardsigmoid(x)
+
+
+@register_test_case(module_factory=lambda: HardsigmoidRandomModule())
+def HardsigmoidRandomModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, low=-10, high=10))
+
+# ==============================================================================
+
 class BroadcastToModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index ddfc4d565..4178a345d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -158,6 +158,34 @@ def Torch_AtenSigmoid_Op : Torch_Op<"aten.sigmoid_", [
   let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
 }

+def Torch_AtenHardsigmoidOp : Torch_Op<"aten.hardsigmoid", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::hardsigmoid : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
+}
+
+def Torch_AtenHardsigmoid_Op : Torch_Op<"aten.hardsigmoid_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::hardsigmoid_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
+}
+
 def Torch_AtenSinOp : Torch_Op<"aten.sin", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index cbe2325bc..689c1bfae 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -716,6 +716,46 @@ public:
 };
 } // namespace

+// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
+namespace {
+class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHardsigmoidOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.self();
+    Type inputType = input.getType();
+
+    // outputTensor = (input + 3) / 6.
+    Value constantOne = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value constantThree = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(3));
+    Value constantSix = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(6));
+    Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
+        loc, inputType, input, constantThree, /*alpha=*/constantOne);
+    Value outputTensor = rewriter.create<AtenDivScalarOp>(
+        loc, inputType, inputPlusThree, constantSix);
+
+    // result = max(0, min(1, (input+3)/6))
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value zeroTensor = rewriter.create<AtenZerosLikeOp>(
+        loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
+        /*pin_memory=*/none, /*memory_format=*/none);
+    Value oneTensor = rewriter.create<AtenOnesLikeOp>(
+        loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
+        /*pin_memory=*/none, /*memory_format=*/none);
+    Value minResult =
+        rewriter.create<AtenMinimumOp>(loc, inputType, oneTensor, outputTensor);
+    rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), zeroTensor,
+                                               minResult);
+    return success();
+  }
+};
+} // namespace
+
 // Returns a tensor with bernoulli(p) distribution.
 // Decompose aten.bernoulli(x, p) to aten.gtTensor(aten.uniform(x), p).
 static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
@@ -1157,6 +1197,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenHardsigmoidOp>(context);
+    target.addIllegalOp<AtenHardsigmoidOp>();

     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index a338ce031..fe4ceed93 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -229,7 +229,8 @@ public:
             AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
             AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
             AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
-            PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp>(op)) {
+            PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
+            AtenHardsigmoidOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }

diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index b65a2ebbc..131098811 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -450,6 +450,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
         "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
         "aten::log : (Tensor) -> (Tensor)",
         "aten::sigmoid : (Tensor) -> (Tensor)",
+        "aten::hardsigmoid : (Tensor) -> (Tensor)",
         "aten::sin : (Tensor) -> (Tensor)",
         "aten::exp : (Tensor) -> (Tensor)",
         "aten::cos : (Tensor) -> (Tensor)",
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index b5237ce5a..6b31fdc86 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -414,3 +414,36 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
   %0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
   return %0 : !torch.vtensor<[?],si64>
 }
+
+// -----
+// CHECK-LABEL: func @torch.aten.hardsigmoid(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[INT3:.*]] = torch.constant.int 3
+// CHECK: %[[INT6:.*]] = torch.constant.int 6
+// CHECK: %[[ADD3:.*]] = torch.aten.add.Scalar %[[ARG]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[OUT:.*]] = torch.aten.div.Scalar %[[ADD3]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[IND0:.*]] = torch.constant.int 0
+// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[ARG]], %[[IND0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[IND1:.*]] = torch.constant.int 1
+// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[ARG]], %[[IND1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[FILL0:.*]] = torch.constant.int 0
+// CHECK: %[[ZERO:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[FILL0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[IND0_2:.*]] = torch.constant.int 0
+// CHECK: %[[DIM0_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND0_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[IND1_2:.*]] = torch.constant.int 1
+// CHECK: %[[DIM1_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND1_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[SIZES_2:.*]] = torch.prim.ListConstruct %[[DIM0_2]], %[[DIM1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EMPTY_2:.*]] = torch.aten.empty.memory_format %[[SIZES_2]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[FILL1:.*]] = torch.constant.int 1
+// CHECK: %[[ONE:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_2]], %[[FILL1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ONE]], %[[OUT]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[RES:.*]] = torch.aten.maximum %[[ZERO]], %[[MIN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
+func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
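
As a quick, illustrative cross-check of the decomposition introduced above (this snippet is not part of the patch), the rewrite turns aten.hardsigmoid into max(0, min(1, (x + 3) / 6)); the standalone PyTorch sketch below compares that formula against the eager-mode kernel on the same values used in HardsigmoidModule_basic:

    # Sanity check (illustrative only, not part of the patch): confirm that
    # max(0, min(1, (x + 3) / 6)) matches PyTorch's own aten.hardsigmoid.
    import torch

    x = torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]])
    decomposed = torch.maximum(torch.zeros_like(x),
                               torch.minimum(torch.ones_like(x), (x + 3) / 6))
    reference = torch.ops.aten.hardsigmoid(x)
    assert torch.allclose(decomposed, reference)
    # decomposed -> [[1.0000, 0.0000, 1.0000], [0.9833, 0.2500, 0.0000]]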