From 33ad5ff15582e68a8b7dd0049f79bf24deafa408 Mon Sep 17 00:00:00 2001 From: yyp0 Date: Sat, 12 Oct 2024 17:51:15 +0800 Subject: [PATCH] [Torch] support 1d aten tensor shape and dtype infer (#3776) --- .../Transforms/SimplifyShapeCalculations.cpp | 57 +++++++++++++++++++ .../torch_mlir_e2e_test/test_suite/basic.py | 24 ++++++++ 2 files changed, 81 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 6d2008a28..f63fb4eb9 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -46,6 +46,62 @@ public: }; } // namespace +namespace { +class InferTensorOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTensorOp op, + PatternRewriter &rewriter) const override { + auto context = op.getContext(); + auto loc = op.getLoc(); + auto result = op.getResult(); + auto resultType = cast(result.getType()); + if (resultType.hasSizes() && resultType.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "The result of aten.tensor is already a BaseTensorType."); + } + + auto inputList = op.getOperand(0); + auto listConstruct = inputList.getDefiningOp(); + if (!listConstruct) { + return rewriter.notifyMatchFailure( + op, "The operand 0 of aten.tensor is not PrimListConstructOp."); + } + + // Currently only support the 1d input list. + SmallVector sizes; + sizes.push_back(listConstruct->getOperands().size()); + FailureOr torchType; + auto eleType = listConstruct->getOperands()[0].getType(); + if (isa(eleType)) { + torchType = getTypeForScalarType(op->getContext(), + torch_upstream::ScalarType::Long); + } else if (isa(eleType)) { + torchType = getTypeForScalarType(op->getContext(), + torch_upstream::ScalarType::Float); + } else { + return rewriter.notifyMatchFailure( + op, "Currently only support Int and Float Type."); + } + auto newResultType = ValueTensorType::get(context, sizes, *torchType); + + Value originalTypedValue; + for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) { + if (!originalTypedValue) { + rewriter.setInsertionPointAfter(op); + originalTypedValue = + rewriter.create(loc, resultType, result); + } + use.set(originalTypedValue); + } + + result.setType(newResultType); + + return success(); + } +}; +} // namespace + static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op, int resultNum, PatternRewriter &rewriter) { @@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass populateFoldPrimUncheckedCastOpPattern(patterns, context); patterns.insert(context); patterns.insert(context); + patterns.insert(context); PrimIfOp::getCanonicalizationPatterns(patterns, context); Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 2bda11410..778039aff 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -5226,6 +5226,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils): # ============================================================================== +class TensorAlloc1dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4, 6], torch.int, True), + ] + ) + def forward(self, x): + res = torch.tensor([x.shape[0]]) + return res + + +@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule()) +def TensorAlloc1dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 6)) + + +# ============================================================================== + + class ScalarTensorFloat32Module(torch.nn.Module): def __init__(self): super().__init__()