From 4d7cdba4bf29b3665094b843550917430a845a10 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Wed, 22 May 2024 23:16:57 +0800 Subject: [PATCH] [Torch] eliminate "getWithLeastStaticInformation" in DecomposeAtenTriuOp (#3330) I am trying to eliminate 'getWithLeastStaticInformation' in DecomposeAtenTriuOp. Could you provide me with some suggestions? @qingyunqu @zjgarvey See issue https://github.com/llvm/torch-mlir/issues/3312 --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 4 ++ .../Torch/Transforms/DecomposeComplexOps.cpp | 52 ++++++++++++------- lib/Dialect/Torch/Utils/Utils.cpp | 26 ++++++++++ 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 1aaf546c2..24db6f14f 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -87,6 +87,10 @@ int64_t getNumberOfElements(RankedTensorType inputType); SmallVector makeShapeLLVMCompatible(ArrayRef shape); SmallVector makeShapeTorchCompatible(ArrayRef shape); +ValueTensorType getTensorTypeFromShapeValues(ArrayRef shapes, + Type dtype); +Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim); + // Helper function to squeeze the input tensor at given dim. // Return the squeezed tensor or failure. 
FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 5ec22233b..ce88854f1 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -674,7 +674,6 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTriuOp op, PatternRewriter &rewriter) const override { - MLIRContext *context = op.getContext(); Location loc = op.getLoc(); Value input = op.getSelf(); auto inputType = cast(input.getType()); @@ -685,37 +684,50 @@ public: return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2"); } - auto baseType = ValueTensorType::getWithLeastStaticInformation(context); Value cstZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value none = rewriter.create(loc); - Value rowDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(-2)); - Value colDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(-1)); - Value rowSize = rewriter.create(loc, input, rowDim); - Value colSize = rewriter.create(loc, input, colDim); + Value rowSize = getTensorDimSize(rewriter, input, -2); + Value colSize = getTensorDimSize(rewriter, input, -1); - Value rowArange = rewriter.create( - loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); - Value colArange = rewriter.create( - loc, baseType, colSize, /*dtype=*/none, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); + auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true); + auto int64DtypeInt = getDtypeIntValueForType(rewriter, loc, si64Type); + auto rowArrangeType = getTensorTypeFromShapeValues({rowSize}, si64Type); + auto colArrangeType = getTensorTypeFromShapeValues({colSize}, si64Type); - Value unsqueezeRowArange = - rewriter.create(loc, baseType, 
rowArange, cstOne); - Value unsqueezeColArange = - rewriter.create(loc, baseType, colArange, cstZero); + Value rowArange = + rewriter.create(loc, rowArrangeType, rowSize, + /*dtype=*/int64DtypeInt, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + Value colArange = + rewriter.create(loc, colArrangeType, colSize, + /*dtype=*/int64DtypeInt, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + + auto unsqueezeRowArangeInfo = + unsqueezeTensor(rewriter, op, rowArange, cstOne); + auto unsqueezeColArangeInfo = + unsqueezeTensor(rewriter, op, colArange, cstZero); + + if (failed(unsqueezeRowArangeInfo) || failed(unsqueezeColArangeInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + + Value unsqueezeRowArange = unsqueezeRowArangeInfo.value(); + Value unsqueezeColArange = unsqueezeColArangeInfo.value(); Value unsqueezeRowArangePlusDiagonal = rewriter.create( - loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne); + loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(), + cstOne); + auto boolType = rewriter.getI1Type(); + auto condType = getTensorTypeFromShapeValues({rowSize, colSize}, boolType); Value condTensor = rewriter.create( - loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); + loc, condType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), condTensor, input, cstZero); diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 8101a2a5b..197f09c66 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -289,6 +289,32 @@ SmallVector Torch::makeShapeTorchCompatible(ArrayRef shape) { return updatedShape; } +ValueTensorType Torch::getTensorTypeFromShapeValues(ArrayRef shapes, + Type dtype) { + assert(!shapes.empty() && "shape vector cannot be empty"); + SmallVector shapeInts; + for (Value shape : shapes) { + int64_t dim; + if (matchPattern(shape, 
m_TorchConstantInt(&dim)) shapeInts.push_back(dim); else shapeInts.push_back(kUnknownSize); } return Torch::ValueTensorType::get(shapes[0].getContext(), shapeInts, dtype); } + +// Helper function to get the size of the tensor at the given dimension. +Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor, + int64_t dim) { + auto loc = tensor.getLoc(); + auto dimVal = + rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); + // Use 'createOrFold' instead of 'create': + // If the dimension is a constant, then the AtenSizeIntOp is folded to a + // ConstantIntOp. + return rewriter.createOrFold(loc, tensor, dimVal); +} + // Helper function to squeeze the input tensor at given dim. // Return the squeezed tensor or failure. FailureOr Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,