[LINALG] Add E2E support for `aten.hardsigmoid` op

This commit adds lowering of the `aten.hardsigmoid` op by decomposing it
into existing elementwise aten ops as Hardsigmoid(x) = max(0, min(1, (x+3)/6)).
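As a quick reference (not part of the diff), a minimal PyTorch sketch of the same
formula, checked against `torch.ops.aten.hardsigmoid`; `torch.clamp` stands in here
for the max/min-over-tensors form that the decomposition actually emits:

import torch

def hardsigmoid_reference(x: torch.Tensor) -> torch.Tensor:
    # Same formula as the decomposition: Hardsigmoid(x) = max(0, min(1, (x + 3) / 6)).
    # clamp is the scalar equivalent of max(zeros_like, min(ones_like, ...)).
    return torch.clamp((x + 3.0) / 6.0, min=0.0, max=1.0)

x = torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]])
assert torch.allclose(hardsigmoid_reference(x), torch.ops.aten.hardsigmoid(x))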

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/603/head
Gaurav Shukla 2022-02-14 20:16:44 +05:30
parent 00a6e9c1bb
commit cd21dda867
6 changed files with 145 additions and 1 deletion

@@ -666,6 +666,45 @@ def _LogSoftmaxModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4))
# ==============================================================================
class HardsigmoidModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.hardsigmoid(x)


@register_test_case(module_factory=lambda: HardsigmoidModule())
def HardsigmoidModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]]))

# ==============================================================================

class HardsigmoidRandomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.hardsigmoid(x)


@register_test_case(module_factory=lambda: HardsigmoidRandomModule())
def HardsigmoidRandomModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, low=-10, high=10))

# ==============================================================================
class BroadcastToModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -158,6 +158,34 @@ def Torch_AtenSigmoid_Op : Torch_Op<"aten.sigmoid_", [
  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}
def Torch_AtenHardsigmoidOp : Torch_Op<"aten.hardsigmoid", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::hardsigmoid : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}

def Torch_AtenHardsigmoid_Op : Torch_Op<"aten.hardsigmoid_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::hardsigmoid_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}
def Torch_AtenSinOp : Torch_Op<"aten.sin", [
    AllowsTypeRefinement,
    HasValueSemantics

@@ -716,6 +716,46 @@ public:
};
} // namespace
// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
namespace {
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenHardsigmoidOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    Type inputType = input.getType();

    // outputTensor = (input + 3) / 6.
    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value constantThree = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(3));
    Value constantSix = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(6));
    Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
        loc, inputType, input, constantThree, /*alpha=*/constantOne);
    Value outputTensor = rewriter.create<AtenDivScalarOp>(
        loc, inputType, inputPlusThree, constantSix);

    // result = max(0, min(1, (input+3)/6))
    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
    Value zeroTensor = rewriter.create<AtenZerosLikeOp>(
        loc, inputType, input, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
    Value oneTensor = rewriter.create<AtenOnesLikeOp>(
        loc, inputType, input, /*dtype=*/none, /*layout=*/none,
        /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
    Value minResult = rewriter.create<AtenMinimumOp>(loc, inputType, oneTensor,
                                                     outputTensor);
    rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), zeroTensor,
                                               minResult);
    return success();
  }
};
} // namespace
// Returns a tensor with bernoulli(p) distribution.
// Decompose aten.bernoulli(x, p) to aten.gtTensor(aten.uniform(x), p).
static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
@@ -1157,6 +1197,8 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenBernoulliOp>();
    patterns.add<DecomposePseudoAtenBernoulliFloatOp>(context);
    target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
    patterns.add<DecomposeAtenHardsigmoidOp>(context);
    target.addIllegalOp<AtenHardsigmoidOp>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {

@@ -229,7 +229,8 @@ public:
            AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
            AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
            AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
            PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
            AtenHardsigmoidOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }

@@ -450,6 +450,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
        "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
        "aten::log : (Tensor) -> (Tensor)",
        "aten::sigmoid : (Tensor) -> (Tensor)",
        "aten::hardsigmoid : (Tensor) -> (Tensor)",
        "aten::sin : (Tensor) -> (Tensor)",
        "aten::exp : (Tensor) -> (Tensor)",
        "aten::cos : (Tensor) -> (Tensor)",

@@ -414,3 +414,36 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
  %0 = torch.aten.select.int %arg0, %int0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
  return %0 : !torch.vtensor<[?],si64>
}
// -----
// CHECK-LABEL: func @torch.aten.hardsigmoid(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[ADD3:.*]] = torch.aten.add.Scalar %[[ARG]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[OUT:.*]] = torch.aten.div.Scalar %[[ADD3]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[IND0:.*]] = torch.constant.int 0
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[ARG]], %[[IND0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[IND1:.*]] = torch.constant.int 1
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[ARG]], %[[IND1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: %[[FILL0:.*]] = torch.constant.int 0
// CHECK: %[[ZERO:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[FILL0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[IND0_2:.*]] = torch.constant.int 0
// CHECK: %[[DIM0_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND0_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[IND1_2:.*]] = torch.constant.int 1
// CHECK: %[[DIM1_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND1_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[SIZES_2:.*]] = torch.prim.ListConstruct %[[DIM0_2]], %[[DIM1_2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK: %[[EMPTY_2:.*]] = torch.aten.empty.memory_format %[[SIZES_2]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
// CHECK: %[[FILL1:.*]] = torch.constant.int 1
// CHECK: %[[ONE:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_2]], %[[FILL1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ONE]], %[[OUT]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[RES:.*]] = torch.aten.maximum %[[ZERO]], %[[MIN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}