From ca662dc9ccfe50a2366edc6e807420cfc5efe17d Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Tue, 4 Jan 2022 12:36:15 +0000
Subject: [PATCH] [MLIR][TORCH] Add E2E support for aten.threshold, aten.threshold_backward op

This commit adds lowering of the `aten.threshold` and
`aten.threshold_backward` ops.

Signed-Off By: Vivek Khandelwal
---
 e2e_testing/torchscript/main.py               |   1 +
 e2e_testing/torchscript/threshold.py          | 291 ++++++++++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      |  48 +++
 .../TorchToLinalg/TorchToLinalg.cpp           |  50 ++-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |   5 +-
 .../jit_ir/build_tools/torch_ods_gen.py       |   2 +
 6 files changed, 392 insertions(+), 5 deletions(-)
 create mode 100644 e2e_testing/torchscript/threshold.py

diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py
index 22428dceb..018a5c682 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/torchscript/main.py
@@ -47,6 +47,7 @@ from . import nll_loss
 from . import index_select
 from . import arange
 from . import constant_alloc
+from . import threshold
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
diff --git a/e2e_testing/torchscript/threshold.py b/e2e_testing/torchscript/threshold.py
new file mode 100644
index 000000000..4408f8441
--- /dev/null
+++ b/e2e_testing/torchscript/threshold.py
@@ -0,0 +1,291 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class Threshold1dIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int64, True),
+    ])
+
+    def forward(self, input):
+        return torch.ops.aten.threshold(input, 1, 2)
+
+@register_test_case(module_factory=lambda: Threshold1dIntModule())
+def Threshold1dIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (4,)))
+
+
+class Threshold2dIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+
+    def forward(self, input):
+        return torch.ops.aten.threshold(input, 0.5, 2)
+
+@register_test_case(module_factory=lambda: Threshold2dIntModule())
+def Threshold2dIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (4, 5)))
+
+
+class Threshold3dIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int64, True),
+    ])
+
+    def forward(self, input):
+        return torch.ops.aten.threshold(input, 1, 2.2)
+
+@register_test_case(module_factory=lambda: Threshold3dIntModule())
+def Threshold3dIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (4, 5, 6)))
+
+
+class Threshold1dFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+
+    def forward(self, input):
+        return torch.ops.aten.threshold(input, 1, 2)
+
+@register_test_case(module_factory=lambda: Threshold1dFloatModule())
+def Threshold1dFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4))
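For reference, every Threshold*Module above and below exercises the same eager-mode behaviour: `aten.threshold` keeps elements that are strictly greater than the threshold and replaces everything else with `value`. A minimal standalone sketch of those semantics (illustrative only, separate from the patch, using stock PyTorch):

    import torch

    x = torch.tensor([0.0, 1.0, 3.0])
    # threshold=1, value=2: entries <= 1 become 2, entries > 1 pass through.
    print(torch.ops.aten.threshold(x, 1, 2))  # tensor([2., 2., 3.])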
+
+
+class Threshold2dFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, input):
+        return torch.ops.aten.threshold(input, 0.5, 2)
+
+@register_test_case(module_factory=lambda: Threshold2dFloatModule())
+def Threshold2dFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5))
+
+
+class Threshold3dFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+
+    def forward(self, input):
+        return torch.ops.aten.threshold(input, 1.4, 2.0)
+
+@register_test_case(module_factory=lambda: Threshold3dFloatModule())
+def Threshold3dFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6))
+
+
+class ThresholdBackward1dIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int64, True),
+        ([-1], torch.int64, True),
+    ])
+
+    def forward(self, grad, input):
+        return torch.ops.aten.threshold_backward(grad, input, 1)
+
+@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
+def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (4,)), torch.randint(8, (4,)))
+
+
+class ThresholdBackward2dIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int64, True),
+    ])
+
+    def forward(self, grad, input):
+        return torch.ops.aten.threshold_backward(grad, input, 0.5)
+
+@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
+def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (4, 5)), torch.randint(8, (4, 5)))
+
+
+class ThresholdBackward3dIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
+    ])
+
+    def forward(self, grad, input):
+        return torch.ops.aten.threshold_backward(grad, input, 1)
+
+@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
+def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (4, 5, 6)), torch.randint(8, (4, 5, 6)))
+
+
+class ThresholdBackward1dFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+
+    def forward(self, grad, input):
+        return torch.ops.aten.threshold_backward(grad, input, 1)
+
+@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
+def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4), torch.randn(4))
+
+
+class ThresholdBackward2dFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, grad, input):
+        return torch.ops.aten.threshold_backward(grad, input, 0.5)
+
+@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
+def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5), torch.randn(4, 5))
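The ThresholdBackward*Module tests check the matching gradient rule: `aten.threshold_backward(grad, input, threshold)` passes the incoming gradient through wherever the input was strictly above the threshold and zeroes it elsewhere. A minimal sketch of that behaviour (illustrative only, separate from the patch):

    import torch

    grad = torch.tensor([10.0, 20.0, 30.0])
    inp = torch.tensor([0.0, 1.0, 3.0])
    # Only inp[2] > 1, so only grad[2] survives.
    print(torch.ops.aten.threshold_backward(grad, inp, 1))  # tensor([ 0.,  0., 30.])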
+
+
+class ThresholdBackward3dFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+
+    def forward(self, grad, input):
+        return torch.ops.aten.threshold_backward(grad, input, 1.4)
+
+@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
+def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6), torch.randn(4, 5, 6))
+
+
+class ThresholdBackward1dMixedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+    ])
+
+    def forward(self, grad, input):
+        return torch.ops.aten.threshold_backward(grad, input, 1)
+
+@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
+def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4), torch.randint(10, (4,)))
+
+
+class ThresholdBackward2dMixedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.float32, True),
+    ])
+
+    def forward(self, grad, input):
+        return torch.ops.aten.threshold_backward(grad, input, 0.5)
+
+@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
+def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(20, (4, 5)), torch.randn(4, 5))
+
+
+class ThresholdBackward3dMixedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.int64, True),
+    ])
+
+    def forward(self, grad, input):
+        return torch.ops.aten.threshold_backward(grad, input, 1.4)
+
+@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
+def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(4, 5, 6), torch.randint(10, (4, 5, 6)))
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index d22096c23..8fc76aa24 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1140,6 +1140,38 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
   let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
 }
 
+def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$threshold,
+    AnyTorchScalarType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
+}
+
+def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::threshold_ : (Tensor, Scalar, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$threshold,
+    AnyTorchScalarType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
+}
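The trailing-underscore definition just above models the in-place variant of the op (hence the `IsTrailingUnderscoreInplaceVariant` trait). In eager PyTorch the corresponding overload mutates its argument instead of returning a fresh tensor; a small illustrative sketch, assuming the usual `torch.ops.aten.threshold_` eager binding:

    import torch

    x = torch.tensor([0.0, 1.0, 3.0])
    torch.ops.aten.threshold_(x, 1, 2)  # in-place: x is overwritten
    print(x)  # tensor([2., 2., 3.])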
+
 def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
     AllowsTypeRefinement,
     HasValueSemantics
@@ -1249,6 +1281,22 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
   let assemblyFormat = "$self `,` $exponent attr-dict `:` type($self) `,` type($exponent) `->` type($result)";
 }
 
+def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$threshold
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$grad_output `,` $self `,` $threshold attr-dict `:` type($grad_output) `,` type($self) `,` type($threshold) `->` type($result)";
+}
+
 def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index c3944f9e5..a44942616 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -2061,6 +2061,49 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
     return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
   }
+  if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
+    // The approach used here is as follows:
+    //     result = self <= threshold ? value : self
+    AtenThresholdOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(thresholdOp.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+
+    Value self = payloadArgs[0];
+    Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
+    Value value = convertScalarToDtype(b, loc, adaptor.value(), dtype);
+
+    Value predicate;
+    if (dtype.isa<mlir::FloatType>())
+      predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
+                                          threshold);
+    else
+      predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
+                                          threshold);
+    return b.create<SelectOp>(loc, predicate, value, self);
+  }
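The recipe in the comment above (result = self <= threshold ? value : self) can be sanity-checked against eager PyTorch with a torch.where formulation of the same compare-and-select (illustrative sketch, independent of the generated linalg):

    import torch

    x = torch.tensor([0.0, 1.0, 3.0])
    threshold, value = 1.0, 2.0
    expected = torch.ops.aten.threshold(x, threshold, value)
    selected = torch.where(x <= threshold, torch.full_like(x, value), x)
    assert torch.equal(expected, selected)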
+  if (auto thresholdBackward = dyn_cast<AtenThresholdBackwardOp>(op)) {
+    // The approach used here is as follows:
+    //     result = self <= threshold ? 0 : grad
+    AtenThresholdBackwardOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(thresholdBackward.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+
+    Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
+    Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
+
+    Value predicate;
+    if (dtype.isa<mlir::FloatType>())
+      predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
+                                          threshold);
+    else
+      predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
+                                          threshold);
+    return b.create<SelectOp>(loc, predicate, constantZero, grad);
+  }
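The backward block follows the same pattern with a zero constant in place of `value`: the gradient flows through where self > threshold and is zeroed elsewhere. A matching eager-mode check, including the float-grad/int-input combination exercised by the ThresholdBackward*MixedModule tests (illustrative sketch; it assumes eager promotes the result to the gradient's dtype here, which is what those mixed tests rely on):

    import torch

    grad = torch.tensor([10.0, 20.0, 30.0])   # float32 gradient
    inp = torch.tensor([0, 1, 3])             # int64 input
    threshold = 1
    expected = torch.ops.aten.threshold_backward(grad, inp, threshold)
    selected = torch.where(inp <= threshold, torch.zeros_like(grad), grad)
    assert torch.equal(expected, selected)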
 
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
@@ -2280,8 +2323,8 @@ struct ConvertElementwiseOp : ConversionPattern {
         AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
         AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp, AtenLtScalarOp,
         AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
-        AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp>(
-            op))
+        AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
+        AtenThresholdOp, AtenThresholdBackwardOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -4163,7 +4206,8 @@ public:
         AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
         AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
         AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
-        AtenEqTensorOp, AtenLtTensorOp>();
+        AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
+        AtenThresholdBackwardOp>();
     patterns.add<ConvertElementwiseOp>(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index fa4deb103..c4e4af2fe 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -242,7 +242,7 @@ public:
         AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp, AtenLog2Op,
         Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
         AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
-        AtenAbsOp>(op)) {
+        AtenAbsOp, AtenThresholdOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
@@ -318,7 +318,8 @@ public:
       return visitBinaryTensorScalarOp(op, operands);
     } else if (isa<AtenMinimumOp,
-                   AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
+                   AtenMaximumOp, AtenBitwiseAndTensorOp,
+                   AtenThresholdBackwardOp>(op)) {
       return visitBinaryBroadcastingOp(op, operands);
     } else if (isa(op)) {
       return visitBinaryBroadcastingComparisonOp(op, operands);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 15c9c5601..b7a518c4f 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -480,6 +480,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
         "aten::abs : (Tensor) -> (Tensor)",
         "aten::reciprocal : (Tensor) -> (Tensor)",
         "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
+        "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
     ]:
         emit_with_mutating_variants(key)
 
@@ -492,6 +493,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::gelu : (Tensor) -> (Tensor)")
     emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
+    emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
     emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")