mirror of https://github.com/llvm/torch-mlir

[LINALG] Decompose aten_hardswish op.

`aten.hardswish` op is decomposed into (x/6) * Relu6(x+3).

parent 056cd2078d
commit 7c637eebc3
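For reference, Hardswish(x) = x * ReLU6(x + 3) / 6, so the (x/6) * Relu6(x+3) form above is the same function with the division folded onto x. A minimal PyTorch sketch (not part of this commit) that checks the decomposed form against the builtin op:

    import torch

    def decomposed_hardswish(x):
        # ReLU6(x + 3) built as min(Relu(x + 3), 6), mirroring the decomposition below.
        relu6 = torch.minimum(torch.relu(x + 3), torch.tensor(6.0))
        return relu6 / 6 * x

    x = torch.randn(128, 128) * 10
    assert torch.allclose(decomposed_hardswish(x), torch.nn.functional.hardswish(x))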
@@ -1307,3 +1307,41 @@ class StdBiasedModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: StdBiasedModule())
 def StdBiasedModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
+
+# ==============================================================================
+
+class HardswishModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.hardswish(x)
+
+
+@register_test_case(module_factory=lambda: HardswishModule())
+def HardswishModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]]))
+
+
+class HardswishRandomModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.hardswish(x)
+
+
+@register_test_case(module_factory=lambda: HardswishRandomModule())
+def HardswishRandomModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128, 128, low=-10, high=10))
+
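The fixed tensor in HardswishModule_basic covers all three regions of the function: hardswish is the identity for x >= 3, zero for x <= -3, and quadratic in between. Expected outputs, hand-checked with plain PyTorch (not part of the commit):

    import torch

    x = torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]])
    print(x * torch.clamp(x + 3, 0, 6) / 6)
    # tensor([[ 4.0000, -0.0000,  3.0000],
    #         [ 2.8517, -0.3750, -0.0000]])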
@@ -186,6 +186,34 @@ def Torch_AtenHardsigmoid_Op : Torch_Op<"aten.hardsigmoid_", [
   let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
 }

+def Torch_AtenHardswishOp : Torch_Op<"aten.hardswish", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::hardswish : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
+}
+
+def Torch_AtenHardswish_Op : Torch_Op<"aten.hardswish_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::hardswish_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
+}
+
 def Torch_AtenSinOp : Torch_Op<"aten.sin", [
     AllowsTypeRefinement,
     HasValueSemantics
@@ -110,7 +110,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,

 // Helper for creating `aten::sub_tensor_op`.
 static Value createTensorSub(PatternRewriter &rewriter, Location loc,
-                              Type tensorType, Value lhs, Value rhs) {
+                             Type tensorType, Value lhs, Value rhs) {
   Value alpha =
       rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
   Value sub =
@@ -124,7 +124,8 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
                                                Location loc, Operation *op,
                                                Type tensorType, Value x,
                                                Value y, Value z, Value dim) {
-  Value sum = createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
+  Value sum =
+      createSumAlongDimension(rewriter, loc, op, z, dim, /*keepDim=*/true);
   if (!sum)
     return nullptr;
   auto broadcastSizeType =
@@ -361,7 +362,7 @@ public:
         loc, tensorType, tanhSquare, gradOutput);

     Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput,
-                                     gradMulTanhSquare);
+                                    gradMulTanhSquare);
     rewriter.replaceOp(op, newGrad);
     return success();
   }
@@ -558,6 +559,63 @@ public:
 };
 } // namespace

+// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
+static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
+                             Value input) {
+  BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+  Value constantOne =
+      rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+  Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+  Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
+  Value dimList = rewriter.create<PrimListConstructOp>(
+      loc, Torch::ListType::get(constantOne.getType()), constantOne);
+  BaseTensorType oneDTensorType =
+      inputType.getWithSizesAndDtype({1}, inputType.getDtype())
+          .cast<BaseTensorType>();
+  Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
+      loc, oneDTensorType, dimList, /*dtype=*/none, /*layout=*/none,
+      /*device=*/none,
+      /*pin_memory=*/none, /*memory_format=*/none);
+  Value constantSix =
+      rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(6));
+  Value sixTensor = rewriter.create<PseudoAtenFillScalarOp>(
+      loc, oneDTensorType, emptyTensor, constantSix);
+  Value relu6Out =
+      rewriter.create<AtenMinimumOp>(loc, inputType, relu, sixTensor);
+  return relu6Out;
+}
+
+// Hardswish(x) = x * Relu6(x+3)/6
+namespace {
+class DecomposeAtenHardswishOp : public OpRewritePattern<AtenHardswishOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHardswishOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.self();
+    Type inputType = input.getType();
+
+    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value constantThree = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(3));
+    Value constantSix = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(6));
+    Value inputPlusThree = rewriter.create<AtenAddScalarOp>(
+        loc, inputType, input, constantThree, /*alpha=*/constantOne);
+    Value relu6 = getRelu6Results(rewriter, loc, inputPlusThree);
+    Value divTensor =
+        rewriter.create<AtenDivScalarOp>(loc, inputType, relu6, constantSix);
+    Value mulTensor =
+        rewriter.create<AtenMulTensorOp>(loc, inputType, divTensor, input);
+
+    rewriter.replaceOp(op, mulTensor);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
 public:
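A note on getRelu6Results: `aten.minimum` takes two tensor operands, so the constant 6 is materialized as a rank-1 tensor (`aten.empty.memory_format` followed by `pseudo.aten.fill.Scalar`) and broadcast against the relu result. In plain PyTorch terms the helper amounts to the following sketch (not part of the commit):

    import torch

    def relu6_via_minimum(x):
        # min(Relu(x), 6), with 6 materialized as a 1-element tensor that broadcasts.
        six = torch.empty(1).fill_(6.0)
        return torch.minimum(torch.relu(x), six)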
@@ -659,7 +717,8 @@ public:
     Value input = op.self();
     Value output = op.result();
     BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
-    Value sum = rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.dtype());
+    Value sum =
+        rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.dtype());
     Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
     rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputTensorType, sum,
                                                  numTensorElements);
@@ -828,8 +887,8 @@ static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
   // Create a uniform random op with low and high set to lb and ub respectively.
   Value uniformRandom = rewriter.create<PseudoAtenUniformOp>(
       loc, inputType, input, lb, ub, noneVal);
-  Value gtValue = rewriter.create<AtenLtScalarOp>(loc, boolType, uniformRandom,
-                                                  prob);
+  Value gtValue =
+      rewriter.create<AtenLtScalarOp>(loc, boolType, uniformRandom, prob);
   // Since `gtValue` will be a boolean tensor convert it back to the original
   // type.
   Value convertBack = rewriter.create<AtenToDtypeOp>(
@@ -883,20 +942,21 @@ public:
 } // namespace

 namespace {
-template<typename OpTy, typename T1T2Op>
+template <typename OpTy, typename T1T2Op>
 class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(OpTy op,
-                               PatternRewriter &rewriter) const override {
+                                PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.self();
     Value tensor1 = op.tensor1();
     Value tensor2 = op.tensor2();
     Value value = op.value();

-    Value product = rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
-    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input, product,
-                                                 value);
+    Value product =
+        rewriter.create<T1T2Op>(loc, op.getType(), tensor1, tensor2);
+    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), input,
+                                                 product, value);
     return success();
   }
 };
@@ -1211,9 +1271,11 @@ class DecomposeComplexOpsPass
       // Make aten.matmul legal if the following condition is satisfied.
       return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
     });
-    patterns.add<DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(context);
+    patterns.add<DecomposeAtenAddCLikeOp<AtenAddcmulOp, AtenMulTensorOp>>(
+        context);
     target.addIllegalOp<AtenAddcmulOp>();
-    patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(context);
+    patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(
+        context);
     target.addIllegalOp<AtenAddcdivOp>();
     target.addIllegalOp<AtenLayerNormOp>();
     patterns.add<DecomposeAtenLayerNormOp>(context);
@@ -1239,6 +1301,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<PseudoAtenBernoulliFloatOp>();
     patterns.add<DecomposeAtenHardsigmoidOp>(context);
     target.addIllegalOp<AtenHardsigmoidOp>();
+    patterns.add<DecomposeAtenHardswishOp>(context);
+    target.addIllegalOp<AtenHardswishOp>();

     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
@@ -230,7 +230,7 @@ public:
             AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
             AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
             PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
-            AtenHardsigmoidOp>(op)) {
+            AtenHardsigmoidOp, AtenHardswishOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }

@@ -451,6 +451,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
         "aten::log : (Tensor) -> (Tensor)",
         "aten::sigmoid : (Tensor) -> (Tensor)",
         "aten::hardsigmoid : (Tensor) -> (Tensor)",
+        "aten::hardswish : (Tensor) -> (Tensor)",
         "aten::sin : (Tensor) -> (Tensor)",
         "aten::exp : (Tensor) -> (Tensor)",
         "aten::cos : (Tensor) -> (Tensor)",
@@ -448,3 +448,26 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
   %0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+// CHECK-LABEL: func @torch.aten.hardswish(
+// CHECK-SAME:    %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[INT3:.*]] = torch.constant.int 3
+// CHECK:         %[[INT6:.*]] = torch.constant.int 6
+// CHECK:         %[[ADD:.*]] = torch.aten.add.Scalar %[[INP]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK:         %[[INT1_:.*]] = torch.constant.int 1
+// CHECK:         %[[NONE:.*]] = torch.constant.none
+// CHECK:         %[[RELU:.*]] = torch.aten.relu %[[ADD]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         %[[LIST:.*]] = torch.prim.ListConstruct %[[INT1_]] : (!torch.int) -> !torch.list<!torch.int>
+// CHECK:         %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],f32>
+// CHECK:         %[[INT6_:.*]] = torch.constant.int 6
+// CHECK:         %[[FILL:.*]] = torch.pseudo.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
+// CHECK:         %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK:         %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[MUL]] : !torch.vtensor<[?,?],f32>
+func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.hardswish %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
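The `// -----` separator indicates the test file is run with `-split-input-file` and checked with FileCheck; presumably the file's existing RUN line invokes the pass registered by DecomposeComplexOpsPass above (the exact RUN line is not shown in this diff).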