[Torch] decompose AtenLerpTensorOp (#3251)

as title
pull/3413/head
Xinyu Yang 2024-06-03 15:25:09 +08:00 committed by GitHub
parent 23b53050de
commit 267052df2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 1 deletions

View File

@ -2585,7 +2585,36 @@ public:
auto weightedDelta = auto weightedDelta =
rewriter.create<AtenMulScalarOp>(loc, inputType, delta, op.getWeight()); rewriter.create<AtenMulScalarOp>(loc, inputType, delta, op.getWeight());
auto lerp = rewriter.create<AtenAddTensorOp>(loc, inputType, start, auto lerp = rewriter.create<AtenAddTensorOp>(loc, resType, start,
weightedDelta, cstOne);
rewriter.replaceOp(op, lerp);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenLerpTensorOp : public OpRewritePattern<AtenLerpTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLerpTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto start = op.getSelf();
auto inputType = cast<BaseTensorType>(start.getType());
auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
start, cstOne);
auto weightedDelta =
rewriter.create<AtenMulTensorOp>(loc, inputType, delta, op.getWeight());
auto lerp = rewriter.create<AtenAddTensorOp>(loc, resType, start,
weightedDelta, cstOne); weightedDelta, cstOne);
rewriter.replaceOp(op, lerp); rewriter.replaceOp(op, lerp);
return success(); return success();
@ -8114,6 +8143,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLerpScalarOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenLerpScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLerpTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);

View File

@ -507,6 +507,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_EmbeddingBagOp>(); target.addIllegalOp<Aten_EmbeddingBagOp>();
target.addIllegalOp<AtenLiftFreshCopyOp>(); target.addIllegalOp<AtenLiftFreshCopyOp>();
target.addIllegalOp<AtenLerpScalarOp>(); target.addIllegalOp<AtenLerpScalarOp>();
target.addIllegalOp<AtenLerpTensorOp>();
target.addIllegalOp<AtenMseLossOp>(); target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenRandintLowOp>(); target.addIllegalOp<AtenRandintLowOp>();
target.addIllegalOp<AtenRandintOp>(); target.addIllegalOp<AtenRandintOp>();

View File

@ -1020,6 +1020,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseSqrtModule_basic", "ElementwiseSqrtModule_basic",
"ElementwiseTanIntModule_basic", "ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic", "ElementwiseTanModule_basic",
"ElementwiseTernaryStaticShapeModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeI64ToI8Module_basic",
"ElementwiseToDtypeIdentityModule_basic", "ElementwiseToDtypeIdentityModule_basic",
@ -1475,6 +1476,7 @@ TOSA_PASS_SET = {
"AtenDotModule_basic", "AtenDotModule_basic",
"ElementwiseFloatTensorGtIntScalarModule_basic", "ElementwiseFloatTensorGtIntScalarModule_basic",
"ElementwiseLogSigmoidModule_basic", "ElementwiseLogSigmoidModule_basic",
"ElementwiseTernaryStaticShapeModule_basic",
"ElementwiseTruncModule_basic", "ElementwiseTruncModule_basic",
"ElementwiseTruncIntModule_basic", "ElementwiseTruncIntModule_basic",
"ElementwiseSgnModule_basic", "ElementwiseSgnModule_basic",

View File

@ -414,6 +414,31 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseTernaryStaticShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([5, 4, 3], torch.float32, True),
([4, 3], torch.float32, True),
([3], torch.float32, True),
]
)
def forward(self, a, b, c):
return torch.lerp(a, b, c)
@register_test_case(module_factory=lambda: ElementwiseTernaryStaticShapeModule())
def ElementwiseTernaryStaticShapeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), tu.rand(4, 3), tu.rand(3))
# ==============================================================================
class ElementwiseAtenWhereSelfModule(torch.nn.Module): class ElementwiseAtenWhereSelfModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()