[LINALG] Decompose aten_hardswish op.

The `aten.hardswish` op is decomposed into x * Relu6(x+3) / 6.
Prashant Kumar 2022-02-15 13:14:32 +00:00
parent 056cd2078d
commit 7c637eebc3
6 changed files with 168 additions and 14 deletions
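
The identity used here can be sanity-checked against eager PyTorch. A minimal sketch, assuming a PyTorch build that provides `torch.nn.functional.hardswish`:

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-10, 10, steps=101)
    # Relu6(y) = min(max(y, 0), 6), so Hardswish(x) = x * clamp(x + 3, 0, 6) / 6.
    decomposed = x * torch.clamp(x + 3, min=0.0, max=6.0) / 6
    assert torch.allclose(decomposed, F.hardswish(x))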


@@ -1307,3 +1307,41 @@ class StdBiasedModule(torch.nn.Module):
@register_test_case(module_factory=lambda: StdBiasedModule())
def StdBiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))

# ==============================================================================

class HardswishModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.hardswish(x)


@register_test_case(module_factory=lambda: HardswishModule())
def HardswishModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]]))


class HardswishRandomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.hardswish(x)


@register_test_case(module_factory=lambda: HardswishRandomModule())
def HardswishRandomModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(128, 128, low=-10, high=10))
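
The two e2e tests above exercise every piece of hardswish: it is 0 for x <= -3, x for x >= 3, and x * (x + 3) / 6 in between, so the fixed inputs sit on and around the kinks while the random inputs in [-10, 10] cover all three regions. A quick eager-mode reference of that piecewise form, independent of this patch:

    import torch

    def hardswish_ref(x):
        # 0 for x <= -3, x for x >= 3, x * (x + 3) / 6 otherwise.
        return torch.where(x <= -3, torch.zeros_like(x),
                           torch.where(x >= 3, x, x * (x + 3) / 6))

    x = torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]])
    assert torch.allclose(hardswish_ref(x), torch.ops.aten.hardswish(x))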


@@ -186,6 +186,34 @@ def Torch_AtenHardsigmoid_Op : Torch_Op<"aten.hardsigmoid_", [
  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}

def Torch_AtenHardswishOp : Torch_Op<"aten.hardswish", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::hardswish : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}

def Torch_AtenHardswish_Op : Torch_Op<"aten.hardswish_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::hardswish_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
}

def Torch_AtenSinOp : Torch_Op<"aten.sin", [
    AllowsTypeRefinement,
    HasValueSemantics


@@ -110,7 +110,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
// Helper for creating `aten::sub_tensor_op`.
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
                             Type tensorType, Value lhs, Value rhs) {
  Value alpha =
      rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
  Value sub =
@@ -124,7 +124,8 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
                                               Location loc, Operation *op,
                                               Type tensorType, Value x,
                                               Value y, Value z, Value dim) {
  Value sum =
      createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
  if (!sum)
    return nullptr;
  auto broadcastSizeType =
@@ -361,7 +362,7 @@ public:
        loc, tensorType, tanhSquare, gradOutput);
    Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
                                    gradMulTanhSquare);
    rewriter.replaceOp(op, newGrad);
    return success();
  }
@@ -558,6 +559,63 @@ public:
};
} // namespace

// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                             Value input) {
  BaseTensorType inputType = input.getType().cast<BaseTensorType>();
  Value constantOne =
      rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
  Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
  Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
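  // `aten.minimum` takes two tensor operands rather than a tensor and a
  // scalar, so the bound 6 is materialized below as a rank-1 tensor
  // (aten.empty.memory_format + fill.Scalar) that broadcasts against `relu`.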
  Value dimList = rewriter.create<PrimListConstructOp>(
      loc, Torch::ListType::get(constantOne.getType()), constantOne);
  BaseTensorType oneDTensorType =
      inputType.getWithSizesAndDtype({1}, inputType.getDtype())
          .cast<BaseTensorType>();
  Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
      loc, oneDTensorType, dimList, /*dtype=*/none, /*layout=*/none,
      /*device=*/none,
      /*pin_memory=*/none, /*memory_format=*/none);
  Value constantSix =
      rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(6));
  Value sixTensor = rewriter.create<PseudoAtenFillScalarOp>(
      loc, oneDTensorType, emptyTensor, constantSix);
  Value relu6Out =
      rewriter.create<AtenMinimumOp>(loc, inputType, relu, sixTensor);
  return relu6Out;
}
// Hardswish(x) = x * Relu6(x+3) / 6
namespace {
class DecomposeAtenHardswishOp : public OpRewritePattern<AtenHardswishOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenHardswishOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    Type inputType = input.getType();

    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value constantThree = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(3));
    Value constantSix = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(6));
    Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
        loc, inputType, input, constantThree, /*alpha=*/constantOne);
    Value relu6 = getRelu6Results(rewriter, loc, inputPlusThree);
    Value divTensor =
        rewriter.create<AtenDivScalarOp>(loc, inputType, relu6, constantSix);
    Value mulTensor =
        rewriter.create<AtenMulTensorOp>(loc, inputType, divTensor, input);
    rewriter.replaceOp(op, mulTensor);
    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
public:
@@ -659,7 +717,8 @@ public:
    Value input = op.self();
    Value output = op.result();
    BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
    Value sum =
        rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.dtype());
    Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
    rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputTensorType, sum,
                                                 numTensorElements);
@@ -828,8 +887,8 @@ static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
  // Create a uniform random op with low and high set to lb and ub respectively.
  Value uniformRandom = rewriter.create<PseudoAtenUniformOp>(
      loc, inputType, input, lb, ub, noneVal);
  Value gtValue =
      rewriter.create<AtenLtScalarOp>(loc, boolType, uniformRandom, prob);
  // Since `gtValue` will be a boolean tensor convert it back to the original
  // type.
  Value convertBack = rewriter.create<AtenToDtypeOp>(
@@ -883,20 +942,21 @@ public:
} // namespace

namespace {
template <typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    Value tensor1 = op.tensor1();
    Value tensor2 = op.tensor2();
    Value value = op.value();

    Value product =
        rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input,
                                                 product, value);
    return success();
  }
};
@@ -1211,9 +1271,11 @@ class DecomposeComplexOpsPass
      // Make aten.matmul legal if the following condition is satisfied.
      return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
    });
    patterns.add<DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(
        context);
    target.addIllegalOp<AtenAddcmulOp>();
    patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(
        context);
    target.addIllegalOp<AtenAddcdivOp>();
    target.addIllegalOp<AtenLayerNormOp>();
    patterns.add<DecomposeAtenLayerNormOp>(context);
@@ -1239,6 +1301,8 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
    patterns.add<DecomposeAtenHardsigmoidOp>(context);
    target.addIllegalOp<AtenHardsigmoidOp>();
    patterns.add<DecomposeAtenHardswishOp>(context);
    target.addIllegalOp<AtenHardswishOp>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {


@@ -230,7 +230,7 @@ public:
            AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
            AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
            PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
            AtenHardsigmoidOp, AtenHardswishOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }


@@ -451,6 +451,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
            "aten::log : (Tensor) -> (Tensor)",
            "aten::sigmoid : (Tensor) -> (Tensor)",
            "aten::hardsigmoid : (Tensor) -> (Tensor)",
            "aten::hardswish : (Tensor) -> (Tensor)",
            "aten::sin : (Tensor) -> (Tensor)",
            "aten::exp : (Tensor) -> (Tensor)",
            "aten::cos : (Tensor) -> (Tensor)",


@@ -448,3 +448,26 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.hardswish(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %[[INP]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[INT1_:.*]] = torch.constant.int 1
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[RELU:.*]] = torch.aten.relu %[[ADD]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT1_]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],f32>
// CHECK: %[[INT6_:.*]] = torch.constant.int 6
// CHECK: %[[FILL:.*]] = torch.pseudo.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[MUL]] : !torch.vtensor<[?,?],f32>
func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.hardswish %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}
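
For reference, decomposition tests like the one above are normally driven by a RUN line at the top of the file; the exact flags may differ, but it takes roughly this shape:

    // RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s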