[Torch] eliminate "getWithLeastStaticInformation" in DecomposeAtenTriuOp (#3330)

This patch eliminates the use of 'getWithLeastStaticInformation' in
DecomposeAtenTriuOp by computing tensor types directly from shape values.
@qingyunqu @zjgarvey 
See issue https://github.com/llvm/torch-mlir/issues/3312
pull/3346/merge
Xinyu Yang 2024-05-22 23:16:57 +08:00 committed by GitHub
parent 972d47b586
commit 4d7cdba4bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 62 additions and 20 deletions

View File

@ -87,6 +87,10 @@ int64_t getNumberOfElements(RankedTensorType inputType);
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
ValueTensorType getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
Type dtype);
Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim);
// Helper function to squeeze the input tensor at given dim.
// Return the squeezed tensor or failure.
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,

View File

@ -674,7 +674,6 @@ public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTriuOp op,
PatternRewriter &rewriter) const override {
MLIRContext *context = op.getContext();
Location loc = op.getLoc();
Value input = op.getSelf();
auto inputType = cast<BaseTensorType>(input.getType());
@ -685,37 +684,50 @@ public:
return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
}
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
Value cstZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value none = rewriter.create<ConstantNoneOp>(loc);
Value rowDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-2));
Value colDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);
Value rowSize = getTensorDimSize(rewriter, input, -2);
Value colSize = getTensorDimSize(rewriter, input, -1);
Value rowArange = rewriter.create<AtenArangeOp>(
loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
Value colArange = rewriter.create<AtenArangeOp>(
loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true);
auto int64DtypeInt = getDtypeIntValueForType(rewriter, loc, si64Type);
auto rowArrangeType = getTensorTypeFromShapeValues({rowSize}, si64Type);
auto colArrangeType = getTensorTypeFromShapeValues({colSize}, si64Type);
Value unsqueezeRowArange =
rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
Value unsqueezeColArange =
rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);
Value rowArange =
rewriter.create<AtenArangeOp>(loc, rowArrangeType, rowSize,
/*dtype=*/int64DtypeInt, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
Value colArange =
rewriter.create<AtenArangeOp>(loc, colArrangeType, colSize,
/*dtype=*/int64DtypeInt, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
auto unsqueezeRowArangeInfo =
unsqueezeTensor(rewriter, op, rowArange, cstOne);
auto unsqueezeColArangeInfo =
unsqueezeTensor(rewriter, op, colArange, cstZero);
if (failed(unsqueezeRowArangeInfo) || failed(unsqueezeColArangeInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
Value unsqueezeRowArange = unsqueezeRowArangeInfo.value();
Value unsqueezeColArange = unsqueezeColArangeInfo.value();
Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);
loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(),
cstOne);
auto boolType = rewriter.getI1Type();
auto condType = getTensorTypeFromShapeValues({rowSize, colSize}, boolType);
Value condTensor = rewriter.create<AtenGeTensorOp>(
loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
loc, condType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
op, op.getResult().getType(), condTensor, input, cstZero);

View File

@ -289,6 +289,32 @@ SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
return updatedShape;
}
// Build a ValueTensorType whose shape is derived from a list of dimension-size
// Values: dimensions that are Torch constant ints become static extents, all
// others become dynamic (kUnknownSize). `shapes` must be non-empty.
ValueTensorType Torch::getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
                                                    Type dtype) {
  assert(!shapes.empty() && "shape vector cannot be empty");
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(shapes.size());
  for (Value dimSize : shapes) {
    int64_t staticDim;
    // A constant dimension yields a static size; anything else is dynamic.
    bool isConstant = matchPattern(dimSize, m_TorchConstantInt(&staticDim));
    staticSizes.push_back(isConstant ? staticDim : kUnknownSize);
  }
  return Torch::ValueTensorType::get(shapes[0].getContext(), staticSizes,
                                     dtype);
}
// Helper function to get the size of the tensor at the given dimension.
Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor,
                              int64_t dim) {
  Location loc = tensor.getLoc();
  Value dimIndex =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
  // Use 'createOrFold' instead of 'create': when the queried dimension has a
  // statically known size, the AtenSizeIntOp folds away into a ConstantIntOp
  // instead of remaining in the IR.
  return rewriter.createOrFold<AtenSizeIntOp>(loc, tensor, dimIndex);
}
// Helper function to squeeze the input tensor at given dim.
// Return the squeezed tensor or failure.
FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,