mirror of https://github.com/llvm/torch-mlir
[LINALG] Add E2E support for `aten.Hardsigmoid` op
This commit adds lowering of `aten.Hardsigmoid` op. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/603/head
parent
00a6e9c1bb
commit
cd21dda867
|
@ -666,6 +666,45 @@ def _LogSoftmaxModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randn(3, 2, 4))
|
module.forward(torch.randn(3, 2, 4))
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class HardsigmoidModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.hardsigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: HardsigmoidModule())
|
||||||
|
def HardsigmoidModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]]))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class HardsigmoidRandomModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.hardsigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: HardsigmoidRandomModule())
|
||||||
|
def HardsigmoidRandomModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, low=-10, high=10))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class BroadcastToModule(torch.nn.Module):
|
class BroadcastToModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -158,6 +158,34 @@ def Torch_AtenSigmoid_Op : Torch_Op<"aten.sigmoid_", [
|
||||||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenHardsigmoidOp : Torch_Op<"aten.hardsigmoid", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::hardsigmoid : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenHardsigmoid_Op : Torch_Op<"aten.hardsigmoid_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::hardsigmoid_ : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenSinOp : Torch_Op<"aten.sin", [
|
def Torch_AtenSinOp : Torch_Op<"aten.sin", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics
|
HasValueSemantics
|
||||||
|
|
|
@ -716,6 +716,46 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenHardsigmoidOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value input = op.self();
|
||||||
|
Type inputType = input.getType();
|
||||||
|
|
||||||
|
// outputTensor = (input + 3) / 6.
|
||||||
|
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
Value constantThree = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(3));
|
||||||
|
Value constantSix = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(6));
|
||||||
|
Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
|
||||||
|
loc, inputType, input, constantThree, /*alpha=*/constantOne);
|
||||||
|
Value outputTensor = rewriter.create<AtenDivScalarOp>(
|
||||||
|
loc, inputType, inputPlusThree, constantSix);
|
||||||
|
|
||||||
|
// result = max(0, min(1, (input+3)/6))
|
||||||
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
|
Value zeroTensor = rewriter.create<AtenZerosLikeOp>(
|
||||||
|
loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
|
||||||
|
/*pin_memory=*/none, /*memory_format=*/none);
|
||||||
|
Value oneTensor = rewriter.create<AtenOnesLikeOp>(
|
||||||
|
loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
|
||||||
|
/*pin_memory=*/none, /*memory_format=*/none);
|
||||||
|
Value minResult =
|
||||||
|
rewriter.create<AtenMinimumOp>(loc, inputType, oneTensor, outputTensor);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), zeroTensor,
|
||||||
|
minResult);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Returns a tensor with bernoulli(p) distribution.
|
// Returns a tensor with bernoulli(p) distribution.
|
||||||
// Decompose aten.bernoulli(x, p) to aten.gtTensor(aten.uniform(x), p).
|
// Decompose aten.bernoulli(x, p) to aten.gtTensor(aten.uniform(x), p).
|
||||||
static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
|
static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
|
||||||
|
@ -1157,6 +1197,8 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<AtenBernoulliOp>();
|
target.addIllegalOp<AtenBernoulliOp>();
|
||||||
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
|
patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
|
||||||
target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
|
target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
|
||||||
|
patterns.add<DecomposeAtenHardsigmoidOp>(context);
|
||||||
|
target.addIllegalOp<AtenHardsigmoidOp>();
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns)))) {
|
std::move(patterns)))) {
|
||||||
|
|
|
@ -229,7 +229,8 @@ public:
|
||||||
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
|
||||||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
|
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
|
||||||
AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
|
AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
|
||||||
PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp>(op)) {
|
PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
|
||||||
|
AtenHardsigmoidOp>(op)) {
|
||||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -450,6 +450,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
||||||
"aten::log : (Tensor) -> (Tensor)",
|
"aten::log : (Tensor) -> (Tensor)",
|
||||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||||
|
"aten::hardsigmoid : (Tensor) -> (Tensor)",
|
||||||
"aten::sin : (Tensor) -> (Tensor)",
|
"aten::sin : (Tensor) -> (Tensor)",
|
||||||
"aten::exp : (Tensor) -> (Tensor)",
|
"aten::exp : (Tensor) -> (Tensor)",
|
||||||
"aten::cos : (Tensor) -> (Tensor)",
|
"aten::cos : (Tensor) -> (Tensor)",
|
||||||
|
|
|
@ -414,3 +414,36 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
|
||||||
%0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
|
%0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
|
||||||
return %0 : !torch.vtensor<[?],si64>
|
return %0 : !torch.vtensor<[?],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @torch.aten.hardsigmoid(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[ADD3:.*]] = torch.aten.add.Scalar %[[ARG]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[OUT:.*]] = torch.aten.div.Scalar %[[ADD3]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[IND0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[ARG]], %[[IND0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[IND1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[ARG]], %[[IND1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||||
|
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[FILL0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[ZERO:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[FILL0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[IND0_2:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIM0_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND0_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[IND1_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[DIM1_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND1_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[SIZES_2:.*]] = torch.prim.ListConstruct %[[DIM0_2]], %[[DIM1_2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||||
|
// CHECK: %[[EMPTY_2:.*]] = torch.aten.empty.memory_format %[[SIZES_2]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[FILL1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[ONE:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_2]], %[[FILL1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ONE]], %[[OUT]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[RES:.*]] = torch.aten.maximum %[[ZERO]], %[[MIN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
|
||||||
|
func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue