mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.threshold, aten.threshold_backward op
This commit adds lowering of `aten.threshold` op This commit adds lowering of `aten.threshold_backward` op Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/516/head snapshot-20220110.197
parent
7cf7b91664
commit
ca662dc9cc
|
@ -47,6 +47,7 @@ from . import nll_loss
|
|||
from . import index_select
|
||||
from . import arange
|
||||
from . import constant_alloc
|
||||
from . import threshold
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||
|
|
|
@ -0,0 +1,291 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Threshold1dIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold1dIntModule())
|
||||
def Threshold1dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4,)))
|
||||
|
||||
|
||||
class Threshold2dIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold2dIntModule())
|
||||
def Threshold2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5)))
|
||||
|
||||
|
||||
class Threshold3dIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2.2)
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold3dIntModule())
|
||||
def Threshold3dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5, 6)))
|
||||
|
||||
|
||||
class Threshold1dFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold1dFloatModule())
|
||||
def Threshold1dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4))
|
||||
|
||||
|
||||
class Threshold2dFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold2dFloatModule())
|
||||
def Threshold2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5))
|
||||
|
||||
|
||||
class Threshold3dFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1.4, 2.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold3dFloatModule())
|
||||
def Threshold3dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6))
|
||||
|
||||
|
||||
class ThresholdBackward1dIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
|
||||
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4,)), torch.randint(8, (4,)))
|
||||
|
||||
|
||||
class ThresholdBackward2dIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
|
||||
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5)), torch.randint(8, (4, 5)))
|
||||
|
||||
|
||||
class ThresholdBackward3dIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
|
||||
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4, 5, 6)), torch.randint(8, (4, 5, 6)))
|
||||
|
||||
|
||||
class ThresholdBackward1dFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
|
||||
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4), torch.randn(4))
|
||||
|
||||
|
||||
class ThresholdBackward2dFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
|
||||
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5), torch.randn(4, 5))
|
||||
|
||||
|
||||
class ThresholdBackward3dFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
|
||||
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.randn(4, 5, 6))
|
||||
|
||||
|
||||
class ThresholdBackward1dMixedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
|
||||
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4), torch.randint(10, (4,)))
|
||||
|
||||
|
||||
class ThresholdBackward2dMixedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
|
||||
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(20, (4, 5)), torch.randn(4, 5))
|
||||
|
||||
|
||||
class ThresholdBackward3dMixedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||
|
||||
@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
|
||||
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(4, 5, 6), torch.randint(10, (4, 5, 6)))
|
|
@ -1140,6 +1140,38 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
|
|||
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$threshold,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::threshold_ : (Tensor, Scalar, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$threshold,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
@ -1249,6 +1281,22 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
|
|||
let assemblyFormat = "$self `,` $exponent attr-dict `:` type($self) `,` type($exponent) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad_output,
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$threshold
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$grad_output `,` $self `,` $threshold attr-dict `:` type($grad_output) `,` type($self) `,` type($threshold) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -2061,6 +2061,49 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
|
||||
return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
|
||||
}
|
||||
if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
|
||||
// The approach used here is as follows:
|
||||
// result = self <= threshold ? value : self
|
||||
AtenThresholdOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(thresholdOp.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
|
||||
Value self = payloadArgs[0];
|
||||
Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
|
||||
Value value = convertScalarToDtype(b, loc, adaptor.value(), dtype);
|
||||
|
||||
Value predicate;
|
||||
if (dtype.isa<mlir::FloatType>())
|
||||
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
|
||||
threshold);
|
||||
else
|
||||
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
|
||||
threshold);
|
||||
return b.create<SelectOp>(loc, predicate, value, self);
|
||||
}
|
||||
if (auto thresholdBackward = dyn_cast<AtenThresholdBackwardOp>(op)) {
|
||||
// The approach used here is as follows:
|
||||
// result = self <= threshold ? 0 : grad
|
||||
AtenThresholdBackwardOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(thresholdBackward.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
|
||||
Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
|
||||
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||
|
||||
Value predicate;
|
||||
if (dtype.isa<mlir::FloatType>())
|
||||
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
|
||||
threshold);
|
||||
else
|
||||
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
|
||||
threshold);
|
||||
return b.create<SelectOp>(loc, predicate, constantZero, grad);
|
||||
}
|
||||
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgPayloadCalculationForElementwiseOp");
|
||||
|
@ -2280,8 +2323,8 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
|
||||
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
|
||||
AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp>(
|
||||
op))
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -4163,7 +4206,8 @@ public:
|
|||
AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
|
||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
|
||||
AtenEqTensorOp, AtenLtTensorOp>();
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSqueezeOp>();
|
||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||
|
|
|
@ -242,7 +242,7 @@ public:
|
|||
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||
AtenAbsOp>(op)) {
|
||||
AtenAbsOp, AtenThresholdOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
@ -318,7 +318,8 @@ public:
|
|||
return visitBinaryTensorScalarOp(op, operands);
|
||||
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
|
||||
AtenDivTensorOp, Aten__And__TensorOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
|
||||
AtenMaximumOp, AtenBitwiseAndTensorOp,
|
||||
AtenThresholdBackwardOp>(op)) {
|
||||
return visitBinaryBroadcastingOp(op, operands);
|
||||
} else if (isa<AtenEqTensorOp, AtenGtTensorOp, AtenLtTensorOp>(op)) {
|
||||
return visitBinaryBroadcastingComparisonOp(op, operands);
|
||||
|
|
|
@ -480,6 +480,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
"aten::abs : (Tensor) -> (Tensor)",
|
||||
"aten::reciprocal : (Tensor) -> (Tensor)",
|
||||
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
|
||||
]:
|
||||
emit_with_mutating_variants(key)
|
||||
|
@ -492,6 +493,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
emit("aten::gelu : (Tensor) -> (Tensor)")
|
||||
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
|
||||
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue