mirror of https://github.com/llvm/torch-mlir
parent
23b53050de
commit
267052df2a
|
@ -2585,7 +2585,36 @@ public:
|
||||||
|
|
||||||
auto weightedDelta =
|
auto weightedDelta =
|
||||||
rewriter.create<AtenMulScalarOp>(loc, inputType, delta, op.getWeight());
|
rewriter.create<AtenMulScalarOp>(loc, inputType, delta, op.getWeight());
|
||||||
auto lerp = rewriter.create<AtenAddTensorOp>(loc, inputType, start,
|
auto lerp = rewriter.create<AtenAddTensorOp>(loc, resType, start,
|
||||||
|
weightedDelta, cstOne);
|
||||||
|
rewriter.replaceOp(op, lerp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenLerpTensorOp : public OpRewritePattern<AtenLerpTensorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenLerpTensorOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!resType.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
|
}
|
||||||
|
Value cstOne =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
auto start = op.getSelf();
|
||||||
|
auto inputType = cast<BaseTensorType>(start.getType());
|
||||||
|
|
||||||
|
auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
|
||||||
|
start, cstOne);
|
||||||
|
|
||||||
|
auto weightedDelta =
|
||||||
|
rewriter.create<AtenMulTensorOp>(loc, inputType, delta, op.getWeight());
|
||||||
|
auto lerp = rewriter.create<AtenAddTensorOp>(loc, resType, start,
|
||||||
weightedDelta, cstOne);
|
weightedDelta, cstOne);
|
||||||
rewriter.replaceOp(op, lerp);
|
rewriter.replaceOp(op, lerp);
|
||||||
return success();
|
return success();
|
||||||
|
@ -8114,6 +8143,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLerpScalarOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLerpScalarOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLerpTensorOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
|
||||||
|
|
|
@ -507,6 +507,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<Aten_EmbeddingBagOp>();
|
target.addIllegalOp<Aten_EmbeddingBagOp>();
|
||||||
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
target.addIllegalOp<AtenLiftFreshCopyOp>();
|
||||||
target.addIllegalOp<AtenLerpScalarOp>();
|
target.addIllegalOp<AtenLerpScalarOp>();
|
||||||
|
target.addIllegalOp<AtenLerpTensorOp>();
|
||||||
target.addIllegalOp<AtenMseLossOp>();
|
target.addIllegalOp<AtenMseLossOp>();
|
||||||
target.addIllegalOp<AtenRandintLowOp>();
|
target.addIllegalOp<AtenRandintLowOp>();
|
||||||
target.addIllegalOp<AtenRandintOp>();
|
target.addIllegalOp<AtenRandintOp>();
|
||||||
|
|
|
@ -1020,6 +1020,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseSqrtModule_basic",
|
"ElementwiseSqrtModule_basic",
|
||||||
"ElementwiseTanIntModule_basic",
|
"ElementwiseTanIntModule_basic",
|
||||||
"ElementwiseTanModule_basic",
|
"ElementwiseTanModule_basic",
|
||||||
|
"ElementwiseTernaryStaticShapeModule_basic",
|
||||||
"ElementwiseToDtypeF32ToI64Module_basic",
|
"ElementwiseToDtypeF32ToI64Module_basic",
|
||||||
"ElementwiseToDtypeI64ToI8Module_basic",
|
"ElementwiseToDtypeI64ToI8Module_basic",
|
||||||
"ElementwiseToDtypeIdentityModule_basic",
|
"ElementwiseToDtypeIdentityModule_basic",
|
||||||
|
@ -1475,6 +1476,7 @@ TOSA_PASS_SET = {
|
||||||
"AtenDotModule_basic",
|
"AtenDotModule_basic",
|
||||||
"ElementwiseFloatTensorGtIntScalarModule_basic",
|
"ElementwiseFloatTensorGtIntScalarModule_basic",
|
||||||
"ElementwiseLogSigmoidModule_basic",
|
"ElementwiseLogSigmoidModule_basic",
|
||||||
|
"ElementwiseTernaryStaticShapeModule_basic",
|
||||||
"ElementwiseTruncModule_basic",
|
"ElementwiseTruncModule_basic",
|
||||||
"ElementwiseTruncIntModule_basic",
|
"ElementwiseTruncIntModule_basic",
|
||||||
"ElementwiseSgnModule_basic",
|
"ElementwiseSgnModule_basic",
|
||||||
|
|
|
@ -414,6 +414,31 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseTernaryStaticShapeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([5, 4, 3], torch.float32, True),
|
||||||
|
([4, 3], torch.float32, True),
|
||||||
|
([3], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b, c):
|
||||||
|
return torch.lerp(a, b, c)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseTernaryStaticShapeModule())
|
||||||
|
def ElementwiseTernaryStaticShapeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 4, 3), tu.rand(4, 3), tu.rand(3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseAtenWhereSelfModule(torch.nn.Module):
|
class ElementwiseAtenWhereSelfModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue