[MLIR][TORCH] Add E2E support for aten.threshold, aten.threshold_backward op

This commit adds lowering of the `aten.threshold` and `aten.threshold_backward` ops.
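For reference, `aten::threshold` keeps elements strictly greater than the threshold and replaces the rest with `value`, and `aten::threshold_backward` zeroes the incoming gradient wherever the forward input did not exceed the threshold. A minimal eager-PyTorch sketch of those semantics (illustrative only; it mirrors the lowering below in converting the scalar operands to the result element type):

import torch

def threshold_ref(self, threshold, value):
    # aten::threshold: keep elements strictly greater than `threshold`,
    # replace the rest with `value`.
    return torch.where(self <= threshold,
                       torch.scalar_tensor(value, dtype=self.dtype), self)

def threshold_backward_ref(grad_output, self, threshold):
    # aten::threshold_backward: zero the gradient wherever the forward
    # input did not exceed `threshold`.
    return torch.where(self <= threshold,
                       torch.zeros_like(grad_output), grad_output)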

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
pull/516/head snapshot-20220110.197
Vivek Khandelwal 2022-01-04 12:36:15 +00:00
parent 7cf7b91664
commit ca662dc9cc
6 changed files with 392 additions and 5 deletions


@@ -47,6 +47,7 @@ from . import nll_loss
from . import index_select
from . import arange
from . import constant_alloc
from . import threshold

def _get_argparse():
    config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
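Note: the `from . import threshold` line is what makes the new tests discoverable. Registration happens at import time, when each `@register_test_case` decorator in the new file runs. A rough sketch of that mechanism (a hypothetical simplification, not the actual registry code):

GLOBAL_TEST_REGISTRY = []

def register_test_case(module_factory):
    def decorator(test_func):
        # Record (name, factory, body) so the harness can later build the
        # module, compile it, and invoke the test body on the result.
        GLOBAL_TEST_REGISTRY.append((test_func.__name__, module_factory, test_func))
        return test_func
    return decorator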


@@ -0,0 +1,291 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class Threshold1dIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int64, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 1, 2)

@register_test_case(module_factory=lambda: Threshold1dIntModule())
def Threshold1dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4,)))
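In the `@annotate_args` lists used throughout this file, the leading `None` stands for the `self` argument of `forward`, each subsequent entry annotates one tensor argument, and `-1` marks a dimension of unknown (dynamic) size. A hypothetical variant of the annotation above that pins the shape statically instead would look like:

    @export
    @annotate_args([
        None,
        ([4], torch.int64, True),   # static 1-D shape instead of the dynamic [-1]
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 1, 2)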
class Threshold2dIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 0.5, 2)

@register_test_case(module_factory=lambda: Threshold2dIntModule())
def Threshold2dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4, 5)))

class Threshold3dIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 1, 2.2)

@register_test_case(module_factory=lambda: Threshold3dIntModule())
def Threshold3dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4, 5, 6)))

class Threshold1dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 1, 2)

@register_test_case(module_factory=lambda: Threshold1dFloatModule())
def Threshold1dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4))

class Threshold2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 0.5, 2)

@register_test_case(module_factory=lambda: Threshold2dFloatModule())
def Threshold2dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5))

class Threshold3dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, input):
        return torch.ops.aten.threshold(input, 1.4, 2.0)

@register_test_case(module_factory=lambda: Threshold3dFloatModule())
def Threshold3dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6))

class ThresholdBackward1dIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int64, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1)

@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4,)), torch.randint(8, (4,)))

class ThresholdBackward2dIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 0.5)

@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4, 5)), torch.randint(8, (4, 5)))

class ThresholdBackward3dIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int64, True),
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1)

@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (4, 5, 6)), torch.randint(8, (4, 5, 6)))

class ThresholdBackward1dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1)

@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4), torch.randn(4))

class ThresholdBackward2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 0.5)

@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5), torch.randn(4, 5))

class ThresholdBackward3dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1.4)

@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.randn(4, 5, 6))

class ThresholdBackward1dMixedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1)

@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4), torch.randint(10, (4,)))

class ThresholdBackward2dMixedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 0.5)

@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(20, (4, 5)), torch.randn(4, 5))

class ThresholdBackward3dMixedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, grad, input):
        return torch.ops.aten.threshold_backward(grad, input, 1.4)

@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(4, 5, 6), torch.randint(10, (4, 5, 6)))
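These modules can also be exercised eagerly outside the e2e harness, which is a quick way to confirm the expected numerics (illustrative only; the real comparison against the golden trace is done by the test framework):

if __name__ == "__main__":
    m = Threshold2dFloatModule()
    x = torch.randn(4, 5)
    out = m.forward(x)
    # threshold(x, 0.5, 2): keep x where x > 0.5, otherwise use 2.
    ref = torch.where(x <= 0.5, torch.full_like(x, 2.0), x)
    assert torch.allclose(out, ref)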


@@ -1140,6 +1140,38 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}
def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$threshold,
    AnyTorchScalarType:$value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
}

def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::threshold_ : (Tensor, Scalar, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$threshold,
    AnyTorchScalarType:$value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
}
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
    AllowsTypeRefinement,
    HasValueSemantics
@@ -1249,6 +1281,22 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
  let assemblyFormat = "$self `,` $exponent attr-dict `:` type($self) `,` type($exponent) `->` type($result)";
}
def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$grad_output,
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$threshold
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$grad_output `,` $self `,` $threshold attr-dict `:` type($grad_output) `,` type($self) `,` type($threshold) `->` type($result)";
}
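The summary strings above correspond to the registered ATen operators; their schemas can be inspected directly from Python, which is handy when checking that a new ODS definition matches the upstream signature (illustrative snippet, not part of this change):

import torch

for name in ("aten::threshold", "aten::threshold_backward"):
    for schema in torch._C._jit_get_schemas_for_operator(name):
        print(schema)
# e.g. aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor
#      aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor
# (overload variants such as `.out` forms may be listed as well)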
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
    AllowsTypeRefinement,
    HasValueSemantics


@@ -2061,6 +2061,49 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
    return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
  }
  if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
    // The approach used here is as follows:
    //     result = self <= threshold ? value : self
    AtenThresholdOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(thresholdOp.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value self = payloadArgs[0];
    Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
    Value value = convertScalarToDtype(b, loc, adaptor.value(), dtype);
    Value predicate;
    if (dtype.isa<mlir::FloatType>())
      predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
                                          threshold);
    else
      predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
                                          threshold);
    return b.create<SelectOp>(loc, predicate, value, self);
  }
  if (auto thresholdBackward = dyn_cast<AtenThresholdBackwardOp>(op)) {
    // The approach used here is as follows:
    //     result = self <= threshold ? 0 : grad
    AtenThresholdBackwardOp::Adaptor adaptor(operands);
    Type dtype = converter->convertType(thresholdBackward.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
    Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
    Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
    Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
    Value predicate;
    if (dtype.isa<mlir::FloatType>())
      predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
                                          threshold);
    else
      predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
                                          threshold);
    return b.create<SelectOp>(loc, predicate, constantZero, grad);
  }
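Note on the payload above: both ops reuse the generic elementwise lowering, so the body only has to produce one scalar result per element. The scalar `threshold`/`value` operands are first converted to the result element type, and the comparison uses a float compare for float element types and a signed integer compare otherwise. In Python terms, the per-element computation is roughly (a sketch under those assumptions, not generated code):

def threshold_payload(self_val, threshold, value):
    # result = self <= threshold ? value : self
    return value if self_val <= threshold else self_val

def threshold_backward_payload(grad_val, self_val, threshold):
    # result = self <= threshold ? 0 : grad
    return 0 if self_val <= threshold else grad_val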
  op->emitError("unimplemented lowering in "
                "createLinalgPayloadCalculationForElementwiseOp");
@@ -2280,8 +2323,8 @@ struct ConvertElementwiseOp : ConversionPattern
            AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
            AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
            AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
            AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
            AtenThresholdOp, AtenThresholdBackwardOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -4163,7 +4206,8 @@ public:
        AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
        AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
        AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
        AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
        AtenThresholdBackwardOp>();
    patterns.add<ConvertElementwiseOp>(typeConverter, context);
    target.addIllegalOp<AtenSqueezeOp>();
    patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);


@@ -242,7 +242,7 @@ public:
            AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
            AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
            AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
            AtenAbsOp, AtenThresholdOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }
@@ -318,7 +318,8 @@ public:
      return visitBinaryTensorScalarOp(op, operands);
    } else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
                   AtenDivTensorOp, Aten__And__TensorOp, AtenMinimumOp,
                   AtenMaximumOp, AtenBitwiseAndTensorOp,
                   AtenThresholdBackwardOp>(op)) {
      return visitBinaryBroadcastingOp(op, operands);
    } else if (isa<AtenEqTensorOp, AtenGtTensorOp, AtenLtTensorOp>(op)) {
      return visitBinaryBroadcastingComparisonOp(op, operands);
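For shape/dtype refinement, `aten.threshold` is treated like the other unary elementwise ops above (the result simply joins the lattice value of its first operand), while `aten.threshold_backward` takes two tensor operands and therefore goes through the binary broadcasting path. The forward behaviour is easy to confirm eagerly (illustrative check):

x = torch.randint(10, (4, 5))
y = torch.ops.aten.threshold(x, 1, 2)
print(y.shape, y.dtype)   # torch.Size([4, 5]) torch.int64 -- same shape and dtype as the input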


@@ -480,6 +480,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
            "aten::abs : (Tensor) -> (Tensor)",
            "aten::reciprocal : (Tensor) -> (Tensor)",
            "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
            "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
    ]:
        emit_with_mutating_variants(key)
@@ -492,6 +493,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
    emit("aten::gelu : (Tensor) -> (Tensor)")
    emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
    emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
    emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
    emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")