[torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055)

Reshaping tensors depend on directly matching individual dimensions to
their corresponding dim in the `torch.view` reshape dimensions. This
involves decoupling dynamic dimensions from their static counterparts
and support cleanup / canonicalization.
pull/3062/head
Rob Suderman 2024-03-26 12:41:40 -07:00 committed by GitHub
parent 17eeac880a
commit 14b548f968
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 673 additions and 78 deletions

View File

@ -3206,52 +3206,6 @@ def Torch_AtenSquare_Op : Torch_Op<"aten.square_", [
}]; }];
} }
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::unsqueeze : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUnsqueezeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenUnsqueezeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::unsqueeze_ : (Tensor, int) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_IntType:$dim
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUnsqueeze_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenUnsqueeze_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenZeroOp : Torch_Op<"aten.zero", [ def Torch_AtenZeroOp : Torch_Op<"aten.zero", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -3391,6 +3345,53 @@ def Torch_AtenFill_TensorOp : Torch_Op<"aten.fill_.Tensor", [
}]; }];
} }
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::unsqueeze : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUnsqueezeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenUnsqueezeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::unsqueeze_ : (Tensor, int) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_IntType:$dim
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUnsqueeze_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenUnsqueeze_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -105,6 +105,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<func::FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps); createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
std::unique_ptr<OperationPass<func::FuncOp>> createScalarizeShapesPass();
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass(); std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createFuseQuantizedOpsPass(); std::unique_ptr<OperationPass<func::FuncOp>> createFuseQuantizedOpsPass();

View File

@ -235,6 +235,17 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
}]; }];
} }
def ScalarizeShapes : Pass<"torch-scalarize-shapes", "func::FuncOp"> {
let summary = "Takes common shape computation operations and scalarizes them.";
let constructor = "mlir::torch::Torch::createScalarizeShapesPass()";
let description = [{
Scalarizes shape computations to better propagate static shapes. As some
shape operation happen on tensors a single dynamic dimension can prevent
propagating static shapes. Scalarization prevents these dynamic
dimensions from blocking the statically computable operations.
}];
}
def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> { def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
let summary = "Recompose torch operations that have been decomposed by TorchScript"; let summary = "Recompose torch operations that have been decomposed by TorchScript";
let constructor = "mlir::torch::Torch::createRecomposeComplexOpsPass()"; let constructor = "mlir::torch::Torch::createRecomposeComplexOpsPass()";

View File

@ -1,7 +1,7 @@
add_subdirectory(TorchOnnxToTorch) add_subdirectory(TorchOnnxToTorch)
add_subdirectory(TorchToArith)
add_subdirectory(TorchToLinalg) add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF) add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith)
add_subdirectory(TorchToTensor) add_subdirectory(TorchToTensor)
add_subdirectory(TorchToTosa) add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_STABLEHLO) if(TORCH_MLIR_ENABLE_STABLEHLO)
@ -12,9 +12,9 @@ add_subdirectory(TorchConversionToMLProgram)
add_subdirectory(Utils) add_subdirectory(Utils)
# TODO: Automate this with add_torch_mlir_conversion_library. # TODO: Automate this with add_torch_mlir_conversion_library.
set(linked_libs TorchMLIRTorchToLinalg set(linked_libs TorchMLIRTorchToArith
TorchMLIRTorchToLinalg
TorchMLIRTorchToSCF TorchMLIRTorchToSCF
TorchMLIRTorchToArith
TorchMLIRTorchToTensor TorchMLIRTorchToTensor
TorchMLIRTorchToTosa TorchMLIRTorchToTosa
TorchMLIRTorchToTMTensor TorchMLIRTorchToTMTensor

View File

@ -1720,7 +1720,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.create<Torch::AtenEqIntOp>(binder.getLoc(), dim, zero); rewriter.create<Torch::AtenEqIntOp>(binder.getLoc(), dim, zero);
isZero = isZero =
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isZero); rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isZero);
Value adjustment; Value adjustment = zero;
int64_t inputDimsSize = dataSizes.size(); int64_t inputDimsSize = dataSizes.size();
if (i < inputDimsSize) { if (i < inputDimsSize) {
adjustment = rewriter.create<Torch::ConstantIntOp>( adjustment = rewriter.create<Torch::ConstantIntOp>(
@ -1728,11 +1728,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.getIntegerAttr(rewriter.getIntegerType(64), rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dataSizes[i])); dataSizes[i]));
} }
// Will never have a 0 in the shape tensor input at an index out of
// bounds of original input dims Therefore, no need to adjust
else {
adjustment = zero;
}
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>( Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isZero, adjustment); binder.getLoc(), isZero, adjustment);
Value finalDim = rewriter.create<Torch::AtenAddIntOp>( Value finalDim = rewriter.create<Torch::AtenAddIntOp>(

View File

@ -239,6 +239,23 @@ public:
}; };
} // namespace } // namespace
namespace {
class ConvertAtenIntBoolOp : public OpConversionPattern<AtenIntBoolOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenIntBoolOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
Value result =
convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(), resultType);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace { namespace {
class ConvertAtenFloatScalarOp : public OpConversionPattern<AtenFloatScalarOp> { class ConvertAtenFloatScalarOp : public OpConversionPattern<AtenFloatScalarOp> {
public: public:
@ -433,6 +450,9 @@ public:
target.addIllegalOp<AtenFloatScalarOp>(); target.addIllegalOp<AtenFloatScalarOp>();
patterns.add<ConvertAtenFloatScalarOp>(typeConverter, context); patterns.add<ConvertAtenFloatScalarOp>(typeConverter, context);
target.addIllegalOp<AtenIntBoolOp>();
patterns.add<ConvertAtenIntBoolOp>(typeConverter, context);
target.addIllegalOp<AtenAddOp>(); target.addIllegalOp<AtenAddOp>();
patterns.add<ConvertAtenAddOp>(typeConverter, context); patterns.add<ConvertAtenAddOp>(typeConverter, context);

View File

@ -190,6 +190,61 @@ public:
}; };
} // namespace } // namespace
namespace {
class ConvertAtenFullOp : public OpConversionPattern<AtenFullOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenFullOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
SmallVector<Value> inShape;
if (!getListConstructElements(adaptor.getSize(), inShape)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: the size list is not from list construct");
}
auto resultTy = cast<RankedTensorType>(
this->getTypeConverter()->convertType(op.getResult().getType()));
if (resultTy.getRank() != static_cast<int64_t>(inShape.size()))
return rewriter.notifyMatchFailure(
op, "rank of shape and result shape do not match");
SmallVector<OpFoldResult> filteredShape;
for (int i = 0, s = resultTy.getRank(); i < s; ++i) {
if (resultTy.isDynamicDim(i)) {
filteredShape.push_back(inShape[i]);
continue;
}
filteredShape.push_back(rewriter.getIndexAttr(resultTy.getDimSize(i)));
}
Value full = adaptor.getFillValue();
if (full.getType() != resultTy.getElementType()) {
if (isa<mlir::FloatType>(full.getType())) {
full = rewriter.create<arith::TruncFOp>(loc, resultTy.getElementType(),
full);
} else if (isa<mlir::IntegerType>(full.getType())) {
full = rewriter.create<arith::TruncIOp>(loc, resultTy.getElementType(),
full);
}
}
Value outTensor = rewriter.create<tensor::EmptyOp>(
loc, filteredShape, resultTy.getElementType());
rewriter.replaceOpWithNewOp<linalg::FillOp>(op, full, outTensor);
return success();
}
};
} // namespace
namespace { namespace {
// Converts a tensor with one element to a scalar value. // Converts a tensor with one element to a scalar value.
template <typename OpTy> template <typename OpTy>
@ -226,6 +281,9 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context); patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
target.addIllegalOp<PrimNumToTensorScalarOp>(); target.addIllegalOp<PrimNumToTensorScalarOp>();
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context); patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
target.addIllegalOp<AtenFullOp>();
patterns.add<ConvertAtenFullOp>(typeConverter, context);
patterns.add<ConvertAtenImplicitLikeOp<AtenScalarImplicitOp>>(typeConverter, patterns.add<ConvertAtenImplicitLikeOp<AtenScalarImplicitOp>>(typeConverter,
context); context);
patterns.add<ConvertAtenImplicitLikeOp<AtenFloatImplicitOp>>(typeConverter, patterns.add<ConvertAtenImplicitLikeOp<AtenFloatImplicitOp>>(typeConverter,

View File

@ -105,6 +105,43 @@ public:
} }
}; };
class ConvertAtenTensorOpPattern : public OpConversionPattern<AtenTensorOp> {
public:
using OpConversionPattern<AtenTensorOp>::OpConversionPattern;
using OpAdaptor = typename AtenTensorOp::Adaptor;
LogicalResult
matchAndRewrite(AtenTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto list = op.getData().getDefiningOp<Torch::PrimListConstructOp>();
if (!list)
return failure();
auto typeConverter = getTypeConverter();
auto resultTy = cast<ShapedType>(typeConverter->convertType(op.getType()));
auto resultETy = resultTy.getElementType();
SmallVector<Value> values;
for (Value operand : list.getOperands()) {
Value value = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(operand.getType()),
operand);
if (isa<mlir::IntegerType>(resultETy) && value.getType() != resultETy)
value = rewriter.create<arith::TruncIOp>(loc, resultETy, value);
if (isa<mlir::FloatType>(resultETy) && value.getType() != resultETy)
value = rewriter.create<arith::TruncFOp>(loc, resultETy, value);
values.push_back(value);
}
rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, resultTy, values);
return success();
}
};
class ConvertTorchToTensor class ConvertTorchToTensor
: public ConvertTorchToTensorBase<ConvertTorchToTensor> { : public ConvertTorchToTensorBase<ConvertTorchToTensor> {
public: public:
@ -118,6 +155,7 @@ public:
target.addLegalDialect<arith::ArithDialect>(); target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<tensor::TensorDialect>(); target.addLegalDialect<tensor::TensorDialect>();
target.addIllegalOp<Torch::AtenItemOp>(); target.addIllegalOp<Torch::AtenItemOp>();
target.addIllegalOp<Torch::AtenTensorOp>();
target.addIllegalOp<Torch::Aten_ShapeAsTensorOp>(); target.addIllegalOp<Torch::Aten_ShapeAsTensorOp>();
TypeConverter typeConverter; TypeConverter typeConverter;
@ -125,8 +163,8 @@ public:
TorchConversion::setupBackendTypeConversion(target, typeConverter); TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
patterns.add<ConvertAtenShapeToTensorPatternOp, ConvertAtenItemOp>( patterns.add<ConvertAtenShapeToTensorPatternOp, ConvertAtenItemOp,
typeConverter, context); ConvertAtenTensorOpPattern>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) std::move(patterns))))

View File

@ -715,16 +715,58 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b); return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
} }
//===----------------------------------------------------------------------===//
// AtenUnsqueezeOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenUnsqueezeOp::fold(FoldAdaptor adaptor) {
auto selfTy = dyn_cast<BaseTensorType>(getSelf().getType());
auto rty = dyn_cast<BaseTensorType>(getType());
if (!rty.hasDtype())
return {};
if (auto attr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf())) {
auto aty = dyn_cast<RankedTensorType>(attr.getType());
if (rty.hasSizes() && rty.areAllSizesKnown() && attr.isSplat()) {
auto naty = RankedTensorType::get(rty.getSizes(), aty.getElementType());
return DenseElementsAttr::get(naty, attr.getSplatValue<Attribute>());
}
}
if (getSelf().getType() != getResult().getType())
return nullptr;
if (selfTy && rty) {
if (selfTy.hasSizes() && rty.hasSizes() &&
selfTy.getSizes().size() == rty.getSizes().size())
return getSelf();
}
return nullptr;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSqueezeOp // AtenSqueezeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
if (getOperand().getType() != getResult().getType()) auto selfTy = dyn_cast<BaseTensorType>(getSelf().getType());
auto rty = dyn_cast<BaseTensorType>(getType());
if (!rty.hasDtype())
return {};
if (auto attr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf())) {
auto aty = dyn_cast<RankedTensorType>(attr.getType());
if (rty.hasSizes() && rty.areAllSizesKnown() && attr.isSplat()) {
auto naty = RankedTensorType::get(rty.getSizes(), aty.getElementType());
return DenseElementsAttr::get(naty, attr.getSplatValue<Attribute>());
}
}
if (getSelf().getType() != getResult().getType())
return nullptr; return nullptr;
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) { if (selfTy && rty) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) if (selfTy.hasSizes() && rty.hasSizes() &&
return getOperand(); selfTy.getSizes().size() == rty.getSizes().size())
return getSelf();
} }
return nullptr; return nullptr;
} }

View File

@ -17,6 +17,7 @@ add_mlir_library(TorchMLIRTorchPasses
ReifyShapeCalculations.cpp ReifyShapeCalculations.cpp
ReifyDtypeCalculations.cpp ReifyDtypeCalculations.cpp
ReifyAbstractInterpCalculationsUtils.cpp ReifyAbstractInterpCalculationsUtils.cpp
ScalarizeShapes.cpp
AbstractInterpLibrary.cpp AbstractInterpLibrary.cpp
SimplifyShapeCalculations.cpp SimplifyShapeCalculations.cpp
SimplifyDtypeCalculations.cpp SimplifyDtypeCalculations.cpp

View File

@ -684,6 +684,13 @@ public:
/*keepDim=*/true), /*keepDim=*/true),
op.getSelf(), dim, start, startPlusOne, /*step=*/one); op.getSelf(), dim, start, startPlusOne, /*step=*/one);
auto sliceTy = cast<BaseTensorType>(slice.getType());
auto resultTy = cast<BaseTensorType>(op.getResult().getType());
if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
rewriter.replaceOp(op, slice);
return success();
}
// `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after // `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after
// slicing, while `aten.select.int` does. // slicing, while `aten.select.int` does.
rewriter.replaceOpWithNewOp<AtenSqueezeDimOp>(op, op.getResult().getType(), rewriter.replaceOpWithNewOp<AtenSqueezeDimOp>(op, op.getResult().getType(),

View File

@ -0,0 +1,345 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
auto list = value.getDefiningOp<Torch::PrimListConstructOp>();
if (!list)
return failure();
for (auto operand : list.getOperands())
vals.push_back(operand);
return success();
}
LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
auto tensor = value.getDefiningOp<Torch::AtenTensorOp>();
if (!tensor)
return failure();
return getListOperands(tensor.getData(), vals);
}
} // namespace
namespace {
class PropagateAtenShapeToTensorPattern
: public OpRewritePattern<Aten_ShapeAsTensorOp> {
public:
using OpRewritePattern<Aten_ShapeAsTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_ShapeAsTensorOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto self = op.getSelf();
auto selfTy = cast<BaseTensorType>(self.getType());
if (!selfTy.hasSizes())
return rewriter.notifyMatchFailure(op, "self has unknown rank");
int64_t rank = selfTy.getSizes().size();
SmallVector<Value> dims;
for (int64_t i = 0; i < rank; ++i) {
auto iv = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
dims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
loc, rewriter.getType<Torch::IntType>(), self, iv));
}
auto dimList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
dims);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
op, op.getType(), dimList, cstNone, cstNone, cstFalse);
return success();
}
};
} // namespace
namespace {
class PropagateAtenIndexSelectPattern
: public OpRewritePattern<AtenIndexSelectOp> {
public:
using OpRewritePattern<AtenIndexSelectOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIndexSelectOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
SmallVector<Value> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "requires a constant dim");
DenseElementsAttr idx;
if (!matchPattern(op.getIndex(), m_Constant(&idx)))
return rewriter.notifyMatchFailure(op, "requires a constant index");
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
if (!selfTy.hasSizes())
return rewriter.notifyMatchFailure(op, "requires known rank");
auto selfShape = selfTy.getSizes();
int64_t selfRank = selfShape.size();
dim = dim < 0 ? dim + selfRank : dim;
int64_t dimLength = elements.size();
if (selfShape[dim] != dimLength)
return rewriter.notifyMatchFailure(
op, "dim length does not match number of elements");
for (int64_t i = 0; i < selfRank; ++i) {
if (i == dim)
continue;
if (selfShape[i] != 1)
return rewriter.notifyMatchFailure(op,
"expects unary non-dim dimension");
}
SmallVector<Value> selected;
if (idx.isSplat()) {
int64_t indexInt = idx.getSplatValue<APInt>().getSExtValue();
indexInt = indexInt < 0 ? indexInt + dimLength : indexInt;
selected.resize(idx.getNumElements(), elements[indexInt]);
} else {
for (APInt val : idx.getValues<APInt>()) {
int64_t indexInt = val.getSExtValue();
selected.push_back(elements[indexInt]);
}
}
auto eTy = elements.front().getType();
auto dimList = rewriter.create<Torch::PrimListConstructOp>(
loc, rewriter.getType<Torch::ListType>(eTy), selected);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
op, op.getType(), dimList, cstNone, cstNone, cstFalse);
return success();
}
};
} // namespace
namespace {
// Conversion attempts to handle some common propagatable slice cases, namely
// splatted values, no-op slices, known list of values, or any case where a
// new construction can be generated from a previous set of scalars allowing
// the parent tensor to be bypassed.
class PropagateAtenSliceTensorPattern
: public OpRewritePattern<AtenSliceTensorOp> {
public:
using OpRewritePattern<AtenSliceTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSliceTensorOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
SmallVector<Value> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();
int64_t dim, start, end, step;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "requires a constant dim");
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(op, "requires a constant start");
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
return rewriter.notifyMatchFailure(op, "requires a constant end");
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
return rewriter.notifyMatchFailure(op, "requires a constant step");
if (step < 0)
return rewriter.notifyMatchFailure(op, "requires a positive step value");
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
auto selfShape = selfTy.getSizes();
int64_t selfRank = selfShape.size();
// Correct for negative indexing:
dim = dim < 0 ? dim + selfRank : dim;
int64_t dimLength = elements.size();
start = start < 0 ? start + dimLength : start;
end = end < 0 ? end + dimLength : end;
start = start < 0 ? 0 : start;
end = end < 0 ? 0 : end;
end = end > dimLength ? dimLength : end;
if (selfShape[dim] != dimLength)
return rewriter.notifyMatchFailure(
op, "dim length does not match number of elements");
for (int64_t i = 0; i < selfRank; ++i) {
if (i == dim)
continue;
if (selfShape[i] != 1)
return rewriter.notifyMatchFailure(op,
"expects unary non-dim dimension");
}
SmallVector<Value> selected;
for (int i = start; i < end; i += step)
selected.push_back(elements[i]);
auto eTy = elements.front().getType();
auto dimList = rewriter.create<Torch::PrimListConstructOp>(
loc, rewriter.getType<Torch::ListType>(eTy), selected);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
op, op.getType(), dimList, cstNone, cstNone, cstFalse);
return success();
}
};
} // namespace
namespace {
class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
public:
using OpRewritePattern<AtenItemOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenItemOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();
if (elements.size() != 1)
return rewriter.notifyMatchFailure(op, "expected no elements");
rewriter.replaceOp(op, elements[0]);
return success();
}
};
} // namespace
namespace {
class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
public:
using OpRewritePattern<AtenTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTensorOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> elements;
if (failed(getListOperands(op.getData(), elements)))
return failure();
if (elements.size() < 1)
return rewriter.notifyMatchFailure(op, "no elements");
auto front = elements.front();
for (auto element : elements)
if (element != front)
return rewriter.notifyMatchFailure(op, "multiple elements found");
if (elements.size() != 1)
return rewriter.notifyMatchFailure(op, "expected no elements");
auto resultTy = cast<BaseTensorType>(op.getType());
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "dynamic output shape");
auto loc = op.getLoc();
llvm::SmallVector<Value> sizes;
for (auto size : resultTy.getSizes())
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(size)));
Value one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(), 1);
Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
one);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
rewriter.replaceOpWithNewOp<AtenFullOp>(op, resultTy, sizeList, front, none,
none, none, cstFalse);
return success();
}
};
} // namespace
namespace {
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
public:
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter) const override {
for (auto use : op->getResults())
if (!use.use_empty())
return failure();
rewriter.eraseOp(op);
return success();
}
};
} // namespace
namespace {
class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<PropagateAtenIndexSelectPattern, PropagateAtenItemPattern,
PropagateAtenShapeToTensorPattern,
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
RemoveUnusedPattern<Torch::ConstantBoolOp>,
RemoveUnusedPattern<Torch::ConstantIntOp>,
RemoveUnusedPattern<Torch::ConstantNoneOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(context);
context->getLoadedDialect<mlir::arith::ArithDialect>()
->getCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createScalarizeShapesPass() {
return std::make_unique<ScalarizeShapesPass>();
}

View File

@ -70,6 +70,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm) { OpPassManager &pm) {
// We want to fuse quantized operations together before lowering to linalg. // We want to fuse quantized operations together before lowering to linalg.
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass()); pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());
// Lower to linalg + guards which is the input to codegen backends. // Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants, // We do this first as it tends to involve pattern-matching against constants,

View File

@ -2054,21 +2054,8 @@ ONNX_XFAIL_SET = {
"ReduceMaxAlongDimUnsignedInt_basic", "ReduceMaxAlongDimUnsignedInt_basic",
# Failure - torch.aten.view lower # Failure - torch.aten.view lower
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic",
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"ViewFlattenAndExpandModule_basic",
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
"ViewSizeDimFollowedByExpandedOnesModule_basic", "ViewSizeDimFollowedByExpandedOnesModule_basic",
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
"ViewSizeDimLedAndFollowedByExpandedOnesModule_basic", "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic",
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeDimLedByExpandedOnesModule_basic", "ViewSizeDimLedByExpandedOnesModule_basic",
# Failure - unknown # Failure - unknown
@ -2105,9 +2092,6 @@ ONNX_XFAIL_SET = {
"IndexTensorHackedTwinModule_basic", "IndexTensorHackedTwinModule_basic",
"IndexTensorModule3dInput_basic", "IndexTensorModule3dInput_basic",
"IndexTensorModule_basic", "IndexTensorModule_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
"IndexTensorSelectDimModule_basic", "IndexTensorSelectDimModule_basic",
"MaskedFillTensorFloatValueModule_basic", "MaskedFillTensorFloatValueModule_basic",
"ReduceAllDimEmpty_basic", "ReduceAllDimEmpty_basic",
@ -2124,5 +2108,19 @@ ONNX_XFAIL_SET = {
ONNX_CRASHING_SET = { ONNX_CRASHING_SET = {
"FakeQuantizePerTensorAffineModule_basic", "FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
# WIP for supporting reshape:
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
"IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic",
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
} }

View File

@ -324,12 +324,14 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::square : (Tensor) -> (Tensor)", "aten::square : (Tensor) -> (Tensor)",
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
"aten::zero : (Tensor) -> (Tensor)", "aten::zero : (Tensor) -> (Tensor)",
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)" "aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)"
]: ]:
emit_with_mutating_variants(key) emit_with_mutating_variants(key)
# Shape manipulations:
emit_with_mutating_variants("aten::unsqueeze : (Tensor, int) -> (Tensor)", has_folder=True)
# Elementwise tensor compute ops that don't have the standard mutating # Elementwise tensor compute ops that don't have the standard mutating
# variants. # variants.
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)

View File

@ -0,0 +1,74 @@
// RUN: torch-mlir-opt <%s --torch-scalarize-shapes -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: @shape_as_tensor
func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[3],si32> {
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]]
// CHECK-DAG: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[I2]]
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I5]], %[[SZ1]], %[[SZ2]]
// CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK: return %[[TENSOR]] : !torch.vtensor<[3],si32>
%0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
return %0 : !torch.vtensor<[3],si32>
}
// -----
// CHECK-LABEL: @shape_as_tensor_dim
func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> {
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]]
// CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32>
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
%dim = torch.constant.int 0
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
%select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
return %select : !torch.vtensor<[],si32>
}
// -----
// CHECK-LABEL: @shape_as_tensor_dim_item
func.func @shape_as_tensor_dim_item(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.int {
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int
// CHECK: return %[[SZ]]
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
%dim = torch.constant.int 0
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
%select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
%out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int
return %out : !torch.int
}
// -----
// CHECK-LABEL: @shape_as_tensor_slice
func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torch.vtensor<[2],si32> {
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1]]
// CHECK-DAG: %[[SZ3:.+]] = torch.aten.size.int %arg0, %[[INT3]]
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[SZ1]], %[[SZ3]]
// CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK: return %[[TENSOR]]
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?,?],f32> -> !torch.vtensor<[4],si32>
%dim = torch.constant.int 0
%start = torch.constant.int 1
%end = torch.constant.int 5
%step = torch.constant.int 2
%slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
return %slice : !torch.vtensor<[2],si32>
}