mirror of https://github.com/llvm/torch-mlir
[Torch] eliminate "getWithLeastStaticInformation" in DecomposeAtenTriuOp (#3330)
I am trying to eliminate 'getWithLeastStaticInformation' in DecomposeAtenTriuOp. Could you provide me with some suggestions? @qingyunqu @zjgarvey See issue https://github.com/llvm/torch-mlir/issues/3312pull/3346/merge
parent
972d47b586
commit
4d7cdba4bf
|
@ -87,6 +87,10 @@ int64_t getNumberOfElements(RankedTensorType inputType);
|
|||
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
|
||||
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
|
||||
|
||||
ValueTensorType getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
|
||||
Type dtype);
|
||||
Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim);
|
||||
|
||||
// Helper function to squeeze the input tensor at given dim.
|
||||
// Return the squeezed tensor or failure.
|
||||
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||
|
|
|
@ -674,7 +674,6 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenTriuOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
MLIRContext *context = op.getContext();
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
|
@ -685,37 +684,50 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
|
||||
}
|
||||
|
||||
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||
Value cstZero =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
Value cstOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
||||
Value rowDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(-2));
|
||||
Value colDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(-1));
|
||||
Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
|
||||
Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);
|
||||
Value rowSize = getTensorDimSize(rewriter, input, -2);
|
||||
Value colSize = getTensorDimSize(rewriter, input, -1);
|
||||
|
||||
Value rowArange = rewriter.create<AtenArangeOp>(
|
||||
loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
Value colArange = rewriter.create<AtenArangeOp>(
|
||||
loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true);
|
||||
auto int64DtypeInt = getDtypeIntValueForType(rewriter, loc, si64Type);
|
||||
auto rowArrangeType = getTensorTypeFromShapeValues({rowSize}, si64Type);
|
||||
auto colArrangeType = getTensorTypeFromShapeValues({colSize}, si64Type);
|
||||
|
||||
Value unsqueezeRowArange =
|
||||
rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
|
||||
Value unsqueezeColArange =
|
||||
rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);
|
||||
Value rowArange =
|
||||
rewriter.create<AtenArangeOp>(loc, rowArrangeType, rowSize,
|
||||
/*dtype=*/int64DtypeInt, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
Value colArange =
|
||||
rewriter.create<AtenArangeOp>(loc, colArrangeType, colSize,
|
||||
/*dtype=*/int64DtypeInt, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
|
||||
auto unsqueezeRowArangeInfo =
|
||||
unsqueezeTensor(rewriter, op, rowArange, cstOne);
|
||||
auto unsqueezeColArangeInfo =
|
||||
unsqueezeTensor(rewriter, op, colArange, cstZero);
|
||||
|
||||
if (failed(unsqueezeRowArangeInfo) || failed(unsqueezeColArangeInfo)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"cannot generate unsqueeze tensor");
|
||||
}
|
||||
|
||||
Value unsqueezeRowArange = unsqueezeRowArangeInfo.value();
|
||||
Value unsqueezeColArange = unsqueezeColArangeInfo.value();
|
||||
|
||||
Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
|
||||
loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);
|
||||
loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(),
|
||||
cstOne);
|
||||
|
||||
auto boolType = rewriter.getI1Type();
|
||||
auto condType = getTensorTypeFromShapeValues({rowSize, colSize}, boolType);
|
||||
Value condTensor = rewriter.create<AtenGeTensorOp>(
|
||||
loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
|
||||
loc, condType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
|
||||
op, op.getResult().getType(), condTensor, input, cstZero);
|
||||
|
|
|
@ -289,6 +289,32 @@ SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
|
|||
return updatedShape;
|
||||
}
|
||||
|
||||
ValueTensorType Torch::getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
|
||||
Type dtype) {
|
||||
assert(!shapes.empty() && "shape vector cannot be empty");
|
||||
SmallVector<int64_t> shapeInts;
|
||||
for (Value shape : shapes) {
|
||||
int64_t dim;
|
||||
if (matchPattern(shape, m_TorchConstantInt(&dim)))
|
||||
shapeInts.push_back(dim);
|
||||
else
|
||||
shapeInts.push_back(kUnknownSize);
|
||||
}
|
||||
return Torch::ValueTensorType::get(shapes[0].getContext(), shapeInts, dtype);
|
||||
}
|
||||
|
||||
// Helper function to get the size of the tensor at the given dimension.
|
||||
Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor,
|
||||
int64_t dim) {
|
||||
auto loc = tensor.getLoc();
|
||||
auto dimVal =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
|
||||
// Use 'createOrFold' instead of 'create':
|
||||
// If the dimension is a constant, then the AtenSizeIntOp is folded to a
|
||||
// ContantIntOp.
|
||||
return rewriter.createOrFold<AtenSizeIntOp>(loc, tensor, dimVal);
|
||||
}
|
||||
|
||||
// Helper function to squeeze the input tensor at given dim.
|
||||
// Return the squeezed tensor or failure.
|
||||
FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||
|
|
Loading…
Reference in New Issue