mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.threshold, aten.threshold_backward op
This commit adds lowering of `aten.threshold` op This commit adds lowering of `aten.threshold_backward` op Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/516/head snapshot-20220110.197
parent
7cf7b91664
commit
ca662dc9cc
|
@ -47,6 +47,7 @@ from . import nll_loss
|
||||||
from . import index_select
|
from . import index_select
|
||||||
from . import arange
|
from . import arange
|
||||||
from . import constant_alloc
|
from . import constant_alloc
|
||||||
|
from . import threshold
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||||
|
|
|
@ -0,0 +1,291 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||||
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||||
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class Threshold1dIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.ops.aten.threshold(input, 1, 2)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Threshold1dIntModule())
|
||||||
|
def Threshold1dIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (4,)))
|
||||||
|
|
||||||
|
|
||||||
|
class Threshold2dIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Threshold2dIntModule())
|
||||||
|
def Threshold2dIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (4, 5)))
|
||||||
|
|
||||||
|
|
||||||
|
class Threshold3dIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.ops.aten.threshold(input, 1, 2.2)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Threshold3dIntModule())
|
||||||
|
def Threshold3dIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (4, 5, 6)))
|
||||||
|
|
||||||
|
|
||||||
|
class Threshold1dFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.ops.aten.threshold(input, 1, 2)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Threshold1dFloatModule())
|
||||||
|
def Threshold1dFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(4))
|
||||||
|
|
||||||
|
|
||||||
|
class Threshold2dFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.ops.aten.threshold(input, 0.5, 2)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Threshold2dFloatModule())
|
||||||
|
def Threshold2dFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class Threshold3dFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.ops.aten.threshold(input, 1.4, 2.0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: Threshold3dFloatModule())
|
||||||
|
def Threshold3dFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(4, 5, 6))
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdBackward1dIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, grad, input):
|
||||||
|
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ThresholdBackward1dIntModule())
|
||||||
|
def ThresholdBackward1dIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (4,)), torch.randint(8, (4,)))
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdBackward2dIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, grad, input):
|
||||||
|
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ThresholdBackward2dIntModule())
|
||||||
|
def ThresholdBackward2dIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (4, 5)), torch.randint(8, (4, 5)))
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdBackward3dIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, grad, input):
|
||||||
|
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ThresholdBackward3dIntModule())
|
||||||
|
def ThresholdBackward3dIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(10, (4, 5, 6)), torch.randint(8, (4, 5, 6)))
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdBackward1dFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, grad, input):
|
||||||
|
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ThresholdBackward1dFloatModule())
|
||||||
|
def ThresholdBackward1dFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(4), torch.randn(4))
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdBackward2dFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, grad, input):
|
||||||
|
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ThresholdBackward2dFloatModule())
|
||||||
|
def ThresholdBackward2dFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(4, 5), torch.randn(4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdBackward3dFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, grad, input):
|
||||||
|
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ThresholdBackward3dFloatModule())
|
||||||
|
def ThresholdBackward3dFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(4, 5, 6), torch.randn(4, 5, 6))
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdBackward1dMixedModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, grad, input):
|
||||||
|
return torch.ops.aten.threshold_backward(grad, input, 1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ThresholdBackward1dMixedModule())
|
||||||
|
def ThresholdBackward1dMixedModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(4), torch.randint(10, (4,)))
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdBackward2dMixedModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, grad, input):
|
||||||
|
return torch.ops.aten.threshold_backward(grad, input, 0.5)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ThresholdBackward2dMixedModule())
|
||||||
|
def ThresholdBackward2dMixedModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(20, (4, 5)), torch.randn(4, 5))
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdBackward3dMixedModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, grad, input):
|
||||||
|
return torch.ops.aten.threshold_backward(grad, input, 1.4)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ThresholdBackward3dMixedModule())
|
||||||
|
def ThresholdBackward3dMixedModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randn(4, 5, 6), torch.randint(10, (4, 5, 6)))
|
|
@ -1140,6 +1140,38 @@ def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [
|
||||||
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
|
let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchScalarType:$threshold,
|
||||||
|
AnyTorchScalarType:$value
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::threshold_ : (Tensor, Scalar, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchScalarType:$threshold,
|
||||||
|
AnyTorchScalarType:$value
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self `,` $threshold `,` $value attr-dict `:` type($self) `,` type($threshold) `,` type($value) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics
|
HasValueSemantics
|
||||||
|
@ -1249,6 +1281,22 @@ def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
|
||||||
let assemblyFormat = "$self `,` $exponent attr-dict `:` type($self) `,` type($exponent) `->` type($result)";
|
let assemblyFormat = "$self `,` $exponent attr-dict `:` type($self) `,` type($exponent) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$grad_output,
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchScalarType:$threshold
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$grad_output `,` $self `,` $threshold attr-dict `:` type($grad_output) `,` type($self) `,` type($threshold) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
|
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics
|
HasValueSemantics
|
||||||
|
|
|
@ -2061,6 +2061,49 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
|
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 1.0));
|
||||||
return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
|
return b.create<arith::DivFOp>(loc, one, payloadArgs[0]);
|
||||||
}
|
}
|
||||||
|
if (auto thresholdOp = dyn_cast<AtenThresholdOp>(op)) {
|
||||||
|
// The approach used here is as follows:
|
||||||
|
// result = self <= threshold ? value : self
|
||||||
|
AtenThresholdOp::Adaptor adaptor(operands);
|
||||||
|
Type dtype = converter->convertType(thresholdOp.getType())
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
|
||||||
|
Value self = payloadArgs[0];
|
||||||
|
Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
|
||||||
|
Value value = convertScalarToDtype(b, loc, adaptor.value(), dtype);
|
||||||
|
|
||||||
|
Value predicate;
|
||||||
|
if (dtype.isa<mlir::FloatType>())
|
||||||
|
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
|
||||||
|
threshold);
|
||||||
|
else
|
||||||
|
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
|
||||||
|
threshold);
|
||||||
|
return b.create<SelectOp>(loc, predicate, value, self);
|
||||||
|
}
|
||||||
|
if (auto thresholdBackward = dyn_cast<AtenThresholdBackwardOp>(op)) {
|
||||||
|
// The approach used here is as follows:
|
||||||
|
// result = self <= threshold ? 0 : grad
|
||||||
|
AtenThresholdBackwardOp::Adaptor adaptor(operands);
|
||||||
|
Type dtype = converter->convertType(thresholdBackward.getType())
|
||||||
|
.cast<RankedTensorType>()
|
||||||
|
.getElementType();
|
||||||
|
|
||||||
|
Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
|
Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
Value threshold = convertScalarToDtype(b, loc, adaptor.threshold(), dtype);
|
||||||
|
Value constantZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||||
|
|
||||||
|
Value predicate;
|
||||||
|
if (dtype.isa<mlir::FloatType>())
|
||||||
|
predicate = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE, self,
|
||||||
|
threshold);
|
||||||
|
else
|
||||||
|
predicate = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, self,
|
||||||
|
threshold);
|
||||||
|
return b.create<SelectOp>(loc, predicate, constantZero, grad);
|
||||||
|
}
|
||||||
|
|
||||||
op->emitError("unimplemented lowering in "
|
op->emitError("unimplemented lowering in "
|
||||||
"createLinalgPayloadCalculationForElementwiseOp");
|
"createLinalgPayloadCalculationForElementwiseOp");
|
||||||
|
@ -2280,8 +2323,8 @@ struct ConvertElementwiseOp : ConversionPattern {
|
||||||
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
|
AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
|
||||||
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
|
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
|
||||||
AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
|
AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
|
||||||
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp>(
|
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
|
||||||
op))
|
AtenThresholdOp, AtenThresholdBackwardOp>(op))
|
||||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
@ -4163,7 +4206,8 @@ public:
|
||||||
AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
|
AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
|
||||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
|
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
|
||||||
AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
|
AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
|
||||||
AtenEqTensorOp, AtenLtTensorOp>();
|
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
|
||||||
|
AtenThresholdBackwardOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenSqueezeOp>();
|
target.addIllegalOp<AtenSqueezeOp>();
|
||||||
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);
|
||||||
|
|
|
@ -242,7 +242,7 @@ public:
|
||||||
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
|
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
|
||||||
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
|
||||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||||
AtenAbsOp>(op)) {
|
AtenAbsOp, AtenThresholdOp>(op)) {
|
||||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -318,7 +318,8 @@ public:
|
||||||
return visitBinaryTensorScalarOp(op, operands);
|
return visitBinaryTensorScalarOp(op, operands);
|
||||||
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
|
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
|
||||||
AtenDivTensorOp, Aten__And__TensorOp, AtenMinimumOp,
|
AtenDivTensorOp, Aten__And__TensorOp, AtenMinimumOp,
|
||||||
AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
|
AtenMaximumOp, AtenBitwiseAndTensorOp,
|
||||||
|
AtenThresholdBackwardOp>(op)) {
|
||||||
return visitBinaryBroadcastingOp(op, operands);
|
return visitBinaryBroadcastingOp(op, operands);
|
||||||
} else if (isa<AtenEqTensorOp, AtenGtTensorOp, AtenLtTensorOp>(op)) {
|
} else if (isa<AtenEqTensorOp, AtenGtTensorOp, AtenLtTensorOp>(op)) {
|
||||||
return visitBinaryBroadcastingComparisonOp(op, operands);
|
return visitBinaryBroadcastingComparisonOp(op, operands);
|
||||||
|
|
|
@ -480,6 +480,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
"aten::abs : (Tensor) -> (Tensor)",
|
"aten::abs : (Tensor) -> (Tensor)",
|
||||||
"aten::reciprocal : (Tensor) -> (Tensor)",
|
"aten::reciprocal : (Tensor) -> (Tensor)",
|
||||||
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
|
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||||
|
|
||||||
]:
|
]:
|
||||||
emit_with_mutating_variants(key)
|
emit_with_mutating_variants(key)
|
||||||
|
@ -492,6 +493,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
|
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||||
emit("aten::gelu : (Tensor) -> (Tensor)")
|
emit("aten::gelu : (Tensor) -> (Tensor)")
|
||||||
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||||
|
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
|
||||||
|
|
||||||
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
|
||||||
emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
|
emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue