mirror of https://github.com/llvm/torch-mlir
[LINALG] Add E2E support for `aten.Hardsigmoid` op
This commit adds lowering of `aten.Hardsigmoid` op. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/603/head
parent
00a6e9c1bb
commit
cd21dda867
|
@ -666,6 +666,45 @@ def _LogSoftmaxModule_basic(module, tu: TestUtils):
|
|||
module.forward(torch.randn(3, 2, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class HardsigmoidModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardsigmoid(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: HardsigmoidModule())
|
||||
def HardsigmoidModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]]))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class HardsigmoidRandomModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardsigmoid(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: HardsigmoidRandomModule())
|
||||
def HardsigmoidRandomModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, low=-10, high=10))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BroadcastToModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -158,6 +158,34 @@ def Torch_AtenSigmoid_Op : Torch_Op<"aten.sigmoid_", [
|
|||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenHardsigmoidOp : Torch_Op<"aten.hardsigmoid", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::hardsigmoid : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenHardsigmoid_Op : Torch_Op<"aten.hardsigmoid_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::hardsigmoid_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenSinOp : Torch_Op<"aten.sin", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -716,6 +716,46 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
|
||||
namespace {
|
||||
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenHardsigmoidOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.self();
|
||||
Type inputType = input.getType();
|
||||
|
||||
// outputTensor = (input + 3) / 6.
|
||||
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value constantThree = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(3));
|
||||
Value constantSix = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(6));
|
||||
Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
|
||||
loc, inputType, input, constantThree, /*alpha=*/constantOne);
|
||||
Value outputTensor = rewriter.create<AtenDivScalarOp>(
|
||||
loc, inputType, inputPlusThree, constantSix);
|
||||
|
||||
// result = max(0, min(1, (input+3)/6))
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
Value zeroTensor = rewriter.create<AtenZerosLikeOp>(
|
||||
loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
|
||||
/*pin_memory=*/none, /*memory_format=*/none);
|
||||
Value oneTensor = rewriter.create<AtenOnesLikeOp>(
|
||||
loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
|
||||
/*pin_memory=*/none, /*memory_format=*/none);
|
||||
Value minResult =
|
||||
rewriter.create<AtenMinimumOp>(loc, inputType, oneTensor, outputTensor);
|
||||
rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), zeroTensor,
|
||||
minResult);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Returns a tensor with bernoulli(p) distribution.
|
||||
// Decompose aten.bernoulli(x, p) to aten.gtTensor(aten.uniform(x), p).
|
||||
static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
|
||||
|
@ -1157,6 +1197,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenBernoulliOp>();
|
||||
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
|
||||
target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
|
||||
patterns.add<DecomposeAtenHardsigmoidOp>(context);
|
||||
target.addIllegalOp<AtenHardsigmoidOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
|
|
@ -229,7 +229,8 @@ public:
|
|||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
|
||||
AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
|
||||
PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp>(op)) {
|
||||
PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
|
||||
AtenHardsigmoidOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
|
|
@ -450,6 +450,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::log : (Tensor) -> (Tensor)",
|
||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||
"aten::hardsigmoid : (Tensor) -> (Tensor)",
|
||||
"aten::sin : (Tensor) -> (Tensor)",
|
||||
"aten::exp : (Tensor) -> (Tensor)",
|
||||
"aten::cos : (Tensor) -> (Tensor)",
|
||||
|
|
|
@ -414,3 +414,36 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
|
|||
%0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
|
||||
return %0 : !torch.vtensor<[?],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.hardsigmoid(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[ADD3:.*]] = torch.aten.add.Scalar %[[ARG]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[OUT:.*]] = torch.aten.div.Scalar %[[ADD3]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[IND0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[ARG]], %[[IND0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[IND1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[ARG]], %[[IND1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[FILL0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[ZERO:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[FILL0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[IND0_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[DIM0_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND0_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[IND1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DIM1_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND1_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SIZES_2:.*]] = torch.prim.ListConstruct %[[DIM0_2]], %[[DIM1_2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: %[[EMPTY_2:.*]] = torch.aten.empty.memory_format %[[SIZES_2]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[FILL1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ONE:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_2]], %[[FILL1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ONE]], %[[OUT]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[RES:.*]] = torch.aten.maximum %[[ZERO]], %[[MIN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
|
||||
func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue