mirror of https://github.com/llvm/torch-mlir
[torch] Improve shape inference for `torch-to-linalg` path for reshapes (#3055)
Reshaping tensors depend on directly matching individual dimensions to their corresponding dim in the `torch.view` reshape dimensions. This involves decoupling dynamic dimensions from their static counterparts and support cleanup / canonicalization.pull/3062/head
parent
17eeac880a
commit
14b548f968
|
@ -3206,52 +3206,6 @@ def Torch_AtenSquare_Op : Torch_Op<"aten.square_", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
|
|
||||||
AllowsTypeRefinement,
|
|
||||||
ReadOnly
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::unsqueeze : (Tensor, int) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
AnyTorchTensorType:$self,
|
|
||||||
Torch_IntType:$dim
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
AnyTorchTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenUnsqueezeOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
|
||||||
}
|
|
||||||
void AtenUnsqueezeOp::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
|
|
||||||
IsTrailingUnderscoreInplaceVariant,
|
|
||||||
AllowsTypeRefinement
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::unsqueeze_ : (Tensor, int) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
Torch_NonValueTensorType:$self,
|
|
||||||
Torch_IntType:$dim
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
Torch_NonValueTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenUnsqueeze_Op::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
|
||||||
}
|
|
||||||
void AtenUnsqueeze_Op::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenZeroOp : Torch_Op<"aten.zero", [
|
def Torch_AtenZeroOp : Torch_Op<"aten.zero", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -3391,6 +3345,53 @@ def Torch_AtenFill_TensorOp : Torch_Op<"aten.fill_.Tensor", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::unsqueeze : (Tensor, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dim
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenUnsqueezeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenUnsqueezeOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::unsqueeze_ : (Tensor, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_NonValueTensorType:$self,
|
||||||
|
Torch_IntType:$dim
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
Torch_NonValueTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenUnsqueeze_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenUnsqueeze_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
|
def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -105,6 +105,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
|
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>> createScalarizeShapesPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
|
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createFuseQuantizedOpsPass();
|
std::unique_ptr<OperationPass<func::FuncOp>> createFuseQuantizedOpsPass();
|
||||||
|
|
|
@ -235,6 +235,17 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def ScalarizeShapes : Pass<"torch-scalarize-shapes", "func::FuncOp"> {
|
||||||
|
let summary = "Takes common shape computation operations and scalarizes them.";
|
||||||
|
let constructor = "mlir::torch::Torch::createScalarizeShapesPass()";
|
||||||
|
let description = [{
|
||||||
|
Scalarizes shape computations to better propagate static shapes. As some
|
||||||
|
shape operation happen on tensors a single dynamic dimension can prevent
|
||||||
|
propagating static shapes. Scalarization prevents these dynamic
|
||||||
|
dimensions from blocking the statically computable operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
|
def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
|
||||||
let summary = "Recompose torch operations that have been decomposed by TorchScript";
|
let summary = "Recompose torch operations that have been decomposed by TorchScript";
|
||||||
let constructor = "mlir::torch::Torch::createRecomposeComplexOpsPass()";
|
let constructor = "mlir::torch::Torch::createRecomposeComplexOpsPass()";
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
add_subdirectory(TorchOnnxToTorch)
|
add_subdirectory(TorchOnnxToTorch)
|
||||||
|
add_subdirectory(TorchToArith)
|
||||||
add_subdirectory(TorchToLinalg)
|
add_subdirectory(TorchToLinalg)
|
||||||
add_subdirectory(TorchToSCF)
|
add_subdirectory(TorchToSCF)
|
||||||
add_subdirectory(TorchToArith)
|
|
||||||
add_subdirectory(TorchToTensor)
|
add_subdirectory(TorchToTensor)
|
||||||
add_subdirectory(TorchToTosa)
|
add_subdirectory(TorchToTosa)
|
||||||
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
if(TORCH_MLIR_ENABLE_STABLEHLO)
|
||||||
|
@ -12,9 +12,9 @@ add_subdirectory(TorchConversionToMLProgram)
|
||||||
add_subdirectory(Utils)
|
add_subdirectory(Utils)
|
||||||
|
|
||||||
# TODO: Automate this with add_torch_mlir_conversion_library.
|
# TODO: Automate this with add_torch_mlir_conversion_library.
|
||||||
set(linked_libs TorchMLIRTorchToLinalg
|
set(linked_libs TorchMLIRTorchToArith
|
||||||
|
TorchMLIRTorchToLinalg
|
||||||
TorchMLIRTorchToSCF
|
TorchMLIRTorchToSCF
|
||||||
TorchMLIRTorchToArith
|
|
||||||
TorchMLIRTorchToTensor
|
TorchMLIRTorchToTensor
|
||||||
TorchMLIRTorchToTosa
|
TorchMLIRTorchToTosa
|
||||||
TorchMLIRTorchToTMTensor
|
TorchMLIRTorchToTMTensor
|
||||||
|
|
|
@ -1720,7 +1720,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
rewriter.create<Torch::AtenEqIntOp>(binder.getLoc(), dim, zero);
|
rewriter.create<Torch::AtenEqIntOp>(binder.getLoc(), dim, zero);
|
||||||
isZero =
|
isZero =
|
||||||
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isZero);
|
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isZero);
|
||||||
Value adjustment;
|
Value adjustment = zero;
|
||||||
int64_t inputDimsSize = dataSizes.size();
|
int64_t inputDimsSize = dataSizes.size();
|
||||||
if (i < inputDimsSize) {
|
if (i < inputDimsSize) {
|
||||||
adjustment = rewriter.create<Torch::ConstantIntOp>(
|
adjustment = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
@ -1728,11 +1728,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||||
dataSizes[i]));
|
dataSizes[i]));
|
||||||
}
|
}
|
||||||
// Will never have a 0 in the shape tensor input at an index out of
|
|
||||||
// bounds of original input dims Therefore, no need to adjust
|
|
||||||
else {
|
|
||||||
adjustment = zero;
|
|
||||||
}
|
|
||||||
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||||
binder.getLoc(), isZero, adjustment);
|
binder.getLoc(), isZero, adjustment);
|
||||||
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
|
||||||
|
|
|
@ -239,6 +239,23 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenIntBoolOp : public OpConversionPattern<AtenIntBoolOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenIntBoolOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type resultType =
|
||||||
|
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||||
|
Value result =
|
||||||
|
convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(), resultType);
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenFloatScalarOp : public OpConversionPattern<AtenFloatScalarOp> {
|
class ConvertAtenFloatScalarOp : public OpConversionPattern<AtenFloatScalarOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -433,6 +450,9 @@ public:
|
||||||
target.addIllegalOp<AtenFloatScalarOp>();
|
target.addIllegalOp<AtenFloatScalarOp>();
|
||||||
patterns.add<ConvertAtenFloatScalarOp>(typeConverter, context);
|
patterns.add<ConvertAtenFloatScalarOp>(typeConverter, context);
|
||||||
|
|
||||||
|
target.addIllegalOp<AtenIntBoolOp>();
|
||||||
|
patterns.add<ConvertAtenIntBoolOp>(typeConverter, context);
|
||||||
|
|
||||||
target.addIllegalOp<AtenAddOp>();
|
target.addIllegalOp<AtenAddOp>();
|
||||||
patterns.add<ConvertAtenAddOp>(typeConverter, context);
|
patterns.add<ConvertAtenAddOp>(typeConverter, context);
|
||||||
|
|
||||||
|
|
|
@ -190,6 +190,61 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenFullOp : public OpConversionPattern<AtenFullOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenFullOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
return failure();
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
SmallVector<Value> inShape;
|
||||||
|
if (!getListConstructElements(adaptor.getSize(), inShape)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: the size list is not from list construct");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultTy = cast<RankedTensorType>(
|
||||||
|
this->getTypeConverter()->convertType(op.getResult().getType()));
|
||||||
|
if (resultTy.getRank() != static_cast<int64_t>(inShape.size()))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "rank of shape and result shape do not match");
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> filteredShape;
|
||||||
|
for (int i = 0, s = resultTy.getRank(); i < s; ++i) {
|
||||||
|
if (resultTy.isDynamicDim(i)) {
|
||||||
|
filteredShape.push_back(inShape[i]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredShape.push_back(rewriter.getIndexAttr(resultTy.getDimSize(i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value full = adaptor.getFillValue();
|
||||||
|
|
||||||
|
if (full.getType() != resultTy.getElementType()) {
|
||||||
|
if (isa<mlir::FloatType>(full.getType())) {
|
||||||
|
full = rewriter.create<arith::TruncFOp>(loc, resultTy.getElementType(),
|
||||||
|
full);
|
||||||
|
} else if (isa<mlir::IntegerType>(full.getType())) {
|
||||||
|
full = rewriter.create<arith::TruncIOp>(loc, resultTy.getElementType(),
|
||||||
|
full);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Value outTensor = rewriter.create<tensor::EmptyOp>(
|
||||||
|
loc, filteredShape, resultTy.getElementType());
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<linalg::FillOp>(op, full, outTensor);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Converts a tensor with one element to a scalar value.
|
// Converts a tensor with one element to a scalar value.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
|
@ -226,6 +281,9 @@ void mlir::torch::torch_to_linalg::
|
||||||
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
|
patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
|
||||||
target.addIllegalOp<PrimNumToTensorScalarOp>();
|
target.addIllegalOp<PrimNumToTensorScalarOp>();
|
||||||
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
|
patterns.add<ConvertPrimNumToTensorScalarOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenFullOp>();
|
||||||
|
patterns.add<ConvertAtenFullOp>(typeConverter, context);
|
||||||
|
|
||||||
patterns.add<ConvertAtenImplicitLikeOp<AtenScalarImplicitOp>>(typeConverter,
|
patterns.add<ConvertAtenImplicitLikeOp<AtenScalarImplicitOp>>(typeConverter,
|
||||||
context);
|
context);
|
||||||
patterns.add<ConvertAtenImplicitLikeOp<AtenFloatImplicitOp>>(typeConverter,
|
patterns.add<ConvertAtenImplicitLikeOp<AtenFloatImplicitOp>>(typeConverter,
|
||||||
|
|
|
@ -105,6 +105,43 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ConvertAtenTensorOpPattern : public OpConversionPattern<AtenTensorOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AtenTensorOp>::OpConversionPattern;
|
||||||
|
using OpAdaptor = typename AtenTensorOp::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenTensorOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto list = op.getData().getDefiningOp<Torch::PrimListConstructOp>();
|
||||||
|
if (!list)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto typeConverter = getTypeConverter();
|
||||||
|
auto resultTy = cast<ShapedType>(typeConverter->convertType(op.getType()));
|
||||||
|
auto resultETy = resultTy.getElementType();
|
||||||
|
|
||||||
|
SmallVector<Value> values;
|
||||||
|
for (Value operand : list.getOperands()) {
|
||||||
|
Value value = typeConverter->materializeTargetConversion(
|
||||||
|
rewriter, loc, typeConverter->convertType(operand.getType()),
|
||||||
|
operand);
|
||||||
|
|
||||||
|
if (isa<mlir::IntegerType>(resultETy) && value.getType() != resultETy)
|
||||||
|
value = rewriter.create<arith::TruncIOp>(loc, resultETy, value);
|
||||||
|
|
||||||
|
if (isa<mlir::FloatType>(resultETy) && value.getType() != resultETy)
|
||||||
|
value = rewriter.create<arith::TruncFOp>(loc, resultETy, value);
|
||||||
|
|
||||||
|
values.push_back(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, resultTy, values);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class ConvertTorchToTensor
|
class ConvertTorchToTensor
|
||||||
: public ConvertTorchToTensorBase<ConvertTorchToTensor> {
|
: public ConvertTorchToTensorBase<ConvertTorchToTensor> {
|
||||||
public:
|
public:
|
||||||
|
@ -118,6 +155,7 @@ public:
|
||||||
target.addLegalDialect<arith::ArithDialect>();
|
target.addLegalDialect<arith::ArithDialect>();
|
||||||
target.addLegalDialect<tensor::TensorDialect>();
|
target.addLegalDialect<tensor::TensorDialect>();
|
||||||
target.addIllegalOp<Torch::AtenItemOp>();
|
target.addIllegalOp<Torch::AtenItemOp>();
|
||||||
|
target.addIllegalOp<Torch::AtenTensorOp>();
|
||||||
target.addIllegalOp<Torch::Aten_ShapeAsTensorOp>();
|
target.addIllegalOp<Torch::Aten_ShapeAsTensorOp>();
|
||||||
|
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
|
@ -125,8 +163,8 @@ public:
|
||||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
patterns.add<ConvertAtenShapeToTensorPatternOp, ConvertAtenItemOp>(
|
patterns.add<ConvertAtenShapeToTensorPatternOp, ConvertAtenItemOp,
|
||||||
typeConverter, context);
|
ConvertAtenTensorOpPattern>(typeConverter, context);
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
|
|
|
@ -715,16 +715,58 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) {
|
||||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
|
return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenUnsqueezeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenUnsqueezeOp::fold(FoldAdaptor adaptor) {
|
||||||
|
auto selfTy = dyn_cast<BaseTensorType>(getSelf().getType());
|
||||||
|
auto rty = dyn_cast<BaseTensorType>(getType());
|
||||||
|
if (!rty.hasDtype())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
if (auto attr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf())) {
|
||||||
|
auto aty = dyn_cast<RankedTensorType>(attr.getType());
|
||||||
|
if (rty.hasSizes() && rty.areAllSizesKnown() && attr.isSplat()) {
|
||||||
|
auto naty = RankedTensorType::get(rty.getSizes(), aty.getElementType());
|
||||||
|
return DenseElementsAttr::get(naty, attr.getSplatValue<Attribute>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getSelf().getType() != getResult().getType())
|
||||||
|
return nullptr;
|
||||||
|
if (selfTy && rty) {
|
||||||
|
if (selfTy.hasSizes() && rty.hasSizes() &&
|
||||||
|
selfTy.getSizes().size() == rty.getSizes().size())
|
||||||
|
return getSelf();
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenSqueezeOp
|
// AtenSqueezeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
|
||||||
if (getOperand().getType() != getResult().getType())
|
auto selfTy = dyn_cast<BaseTensorType>(getSelf().getType());
|
||||||
|
auto rty = dyn_cast<BaseTensorType>(getType());
|
||||||
|
if (!rty.hasDtype())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
if (auto attr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf())) {
|
||||||
|
auto aty = dyn_cast<RankedTensorType>(attr.getType());
|
||||||
|
if (rty.hasSizes() && rty.areAllSizesKnown() && attr.isSplat()) {
|
||||||
|
auto naty = RankedTensorType::get(rty.getSizes(), aty.getElementType());
|
||||||
|
return DenseElementsAttr::get(naty, attr.getSplatValue<Attribute>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getSelf().getType() != getResult().getType())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
|
if (selfTy && rty) {
|
||||||
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
if (selfTy.hasSizes() && rty.hasSizes() &&
|
||||||
return getOperand();
|
selfTy.getSizes().size() == rty.getSizes().size())
|
||||||
|
return getSelf();
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ add_mlir_library(TorchMLIRTorchPasses
|
||||||
ReifyShapeCalculations.cpp
|
ReifyShapeCalculations.cpp
|
||||||
ReifyDtypeCalculations.cpp
|
ReifyDtypeCalculations.cpp
|
||||||
ReifyAbstractInterpCalculationsUtils.cpp
|
ReifyAbstractInterpCalculationsUtils.cpp
|
||||||
|
ScalarizeShapes.cpp
|
||||||
AbstractInterpLibrary.cpp
|
AbstractInterpLibrary.cpp
|
||||||
SimplifyShapeCalculations.cpp
|
SimplifyShapeCalculations.cpp
|
||||||
SimplifyDtypeCalculations.cpp
|
SimplifyDtypeCalculations.cpp
|
||||||
|
|
|
@ -684,6 +684,13 @@ public:
|
||||||
/*keepDim=*/true),
|
/*keepDim=*/true),
|
||||||
op.getSelf(), dim, start, startPlusOne, /*step=*/one);
|
op.getSelf(), dim, start, startPlusOne, /*step=*/one);
|
||||||
|
|
||||||
|
auto sliceTy = cast<BaseTensorType>(slice.getType());
|
||||||
|
auto resultTy = cast<BaseTensorType>(op.getResult().getType());
|
||||||
|
if (sliceTy.getSizes().size() == resultTy.getSizes().size()) {
|
||||||
|
rewriter.replaceOp(op, slice);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after
|
// `aten.slice.tensor` doesn't squeeze the dim even when it's size 1 after
|
||||||
// slicing, while `aten.select.int` does.
|
// slicing, while `aten.select.int` does.
|
||||||
rewriter.replaceOpWithNewOp<AtenSqueezeDimOp>(op, op.getResult().getType(),
|
rewriter.replaceOpWithNewOp<AtenSqueezeDimOp>(op, op.getResult().getType(),
|
||||||
|
|
|
@ -0,0 +1,345 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
|
||||||
|
auto list = value.getDefiningOp<Torch::PrimListConstructOp>();
|
||||||
|
if (!list)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (auto operand : list.getOperands())
|
||||||
|
vals.push_back(operand);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
|
||||||
|
auto tensor = value.getDefiningOp<Torch::AtenTensorOp>();
|
||||||
|
if (!tensor)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return getListOperands(tensor.getData(), vals);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class PropagateAtenShapeToTensorPattern
|
||||||
|
: public OpRewritePattern<Aten_ShapeAsTensorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<Aten_ShapeAsTensorOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(Aten_ShapeAsTensorOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto self = op.getSelf();
|
||||||
|
auto selfTy = cast<BaseTensorType>(self.getType());
|
||||||
|
if (!selfTy.hasSizes())
|
||||||
|
return rewriter.notifyMatchFailure(op, "self has unknown rank");
|
||||||
|
|
||||||
|
int64_t rank = selfTy.getSizes().size();
|
||||||
|
SmallVector<Value> dims;
|
||||||
|
for (int64_t i = 0; i < rank; ++i) {
|
||||||
|
auto iv = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(i));
|
||||||
|
dims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
|
||||||
|
loc, rewriter.getType<Torch::IntType>(), self, iv));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto dimList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
loc,
|
||||||
|
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
|
||||||
|
dims);
|
||||||
|
|
||||||
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||||
|
loc, rewriter.getBoolAttr(false));
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
|
||||||
|
op, op.getType(), dimList, cstNone, cstNone, cstFalse);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class PropagateAtenIndexSelectPattern
|
||||||
|
: public OpRewritePattern<AtenIndexSelectOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenIndexSelectOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenIndexSelectOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
SmallVector<Value> elements;
|
||||||
|
if (failed(getListFromTensor(op.getSelf(), elements)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires a constant dim");
|
||||||
|
|
||||||
|
DenseElementsAttr idx;
|
||||||
|
if (!matchPattern(op.getIndex(), m_Constant(&idx)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires a constant index");
|
||||||
|
|
||||||
|
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
|
if (!selfTy.hasSizes())
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires known rank");
|
||||||
|
|
||||||
|
auto selfShape = selfTy.getSizes();
|
||||||
|
int64_t selfRank = selfShape.size();
|
||||||
|
dim = dim < 0 ? dim + selfRank : dim;
|
||||||
|
int64_t dimLength = elements.size();
|
||||||
|
if (selfShape[dim] != dimLength)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "dim length does not match number of elements");
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < selfRank; ++i) {
|
||||||
|
if (i == dim)
|
||||||
|
continue;
|
||||||
|
if (selfShape[i] != 1)
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"expects unary non-dim dimension");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> selected;
|
||||||
|
if (idx.isSplat()) {
|
||||||
|
int64_t indexInt = idx.getSplatValue<APInt>().getSExtValue();
|
||||||
|
indexInt = indexInt < 0 ? indexInt + dimLength : indexInt;
|
||||||
|
selected.resize(idx.getNumElements(), elements[indexInt]);
|
||||||
|
} else {
|
||||||
|
for (APInt val : idx.getValues<APInt>()) {
|
||||||
|
int64_t indexInt = val.getSExtValue();
|
||||||
|
selected.push_back(elements[indexInt]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto eTy = elements.front().getType();
|
||||||
|
|
||||||
|
auto dimList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
loc, rewriter.getType<Torch::ListType>(eTy), selected);
|
||||||
|
|
||||||
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||||
|
loc, rewriter.getBoolAttr(false));
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
|
||||||
|
op, op.getType(), dimList, cstNone, cstNone, cstFalse);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Conversion attempts to handle some common propagatable slice cases, namely
|
||||||
|
// splatted values, no-op slices, known list of values, or any case where a
|
||||||
|
// new construction can be generated from a previous set of scalars allowing
|
||||||
|
// the parent tensor to be bypassed.
|
||||||
|
class PropagateAtenSliceTensorPattern
|
||||||
|
: public OpRewritePattern<AtenSliceTensorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenSliceTensorOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenSliceTensorOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
SmallVector<Value> elements;
|
||||||
|
if (failed(getListFromTensor(op.getSelf(), elements)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t dim, start, end, step;
|
||||||
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires a constant dim");
|
||||||
|
|
||||||
|
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires a constant start");
|
||||||
|
|
||||||
|
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires a constant end");
|
||||||
|
|
||||||
|
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires a constant step");
|
||||||
|
|
||||||
|
if (step < 0)
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires a positive step value");
|
||||||
|
|
||||||
|
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
|
auto selfShape = selfTy.getSizes();
|
||||||
|
int64_t selfRank = selfShape.size();
|
||||||
|
|
||||||
|
// Correct for negative indexing:
|
||||||
|
dim = dim < 0 ? dim + selfRank : dim;
|
||||||
|
|
||||||
|
int64_t dimLength = elements.size();
|
||||||
|
start = start < 0 ? start + dimLength : start;
|
||||||
|
end = end < 0 ? end + dimLength : end;
|
||||||
|
|
||||||
|
start = start < 0 ? 0 : start;
|
||||||
|
end = end < 0 ? 0 : end;
|
||||||
|
end = end > dimLength ? dimLength : end;
|
||||||
|
|
||||||
|
if (selfShape[dim] != dimLength)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "dim length does not match number of elements");
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < selfRank; ++i) {
|
||||||
|
if (i == dim)
|
||||||
|
continue;
|
||||||
|
if (selfShape[i] != 1)
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"expects unary non-dim dimension");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> selected;
|
||||||
|
for (int i = start; i < end; i += step)
|
||||||
|
selected.push_back(elements[i]);
|
||||||
|
|
||||||
|
auto eTy = elements.front().getType();
|
||||||
|
auto dimList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
loc, rewriter.getType<Torch::ListType>(eTy), selected);
|
||||||
|
|
||||||
|
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||||
|
loc, rewriter.getBoolAttr(false));
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
|
||||||
|
op, op.getType(), dimList, cstNone, cstNone, cstFalse);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenItemOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenItemOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<Value> elements;
|
||||||
|
if (failed(getListFromTensor(op.getSelf(), elements)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (elements.size() != 1)
|
||||||
|
return rewriter.notifyMatchFailure(op, "expected no elements");
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, elements[0]);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenTensorOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenTensorOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<Value> elements;
|
||||||
|
if (failed(getListOperands(op.getData(), elements)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (elements.size() < 1)
|
||||||
|
return rewriter.notifyMatchFailure(op, "no elements");
|
||||||
|
|
||||||
|
auto front = elements.front();
|
||||||
|
for (auto element : elements)
|
||||||
|
if (element != front)
|
||||||
|
return rewriter.notifyMatchFailure(op, "multiple elements found");
|
||||||
|
|
||||||
|
if (elements.size() != 1)
|
||||||
|
return rewriter.notifyMatchFailure(op, "expected no elements");
|
||||||
|
|
||||||
|
auto resultTy = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
|
||||||
|
return rewriter.notifyMatchFailure(op, "dynamic output shape");
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
llvm::SmallVector<Value> sizes;
|
||||||
|
for (auto size : resultTy.getSizes())
|
||||||
|
sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(size)));
|
||||||
|
|
||||||
|
Value one = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getType<Torch::IntType>(), 1);
|
||||||
|
Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
loc,
|
||||||
|
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
|
||||||
|
one);
|
||||||
|
|
||||||
|
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenFullOp>(op, resultTy, sizeList, front, none,
|
||||||
|
none, none, cstFalse);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<T>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(T op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
for (auto use : op->getResults())
|
||||||
|
if (!use.use_empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
|
||||||
|
public:
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<arith::ArithDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
patterns.insert<PropagateAtenIndexSelectPattern, PropagateAtenItemPattern,
|
||||||
|
PropagateAtenShapeToTensorPattern,
|
||||||
|
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
|
||||||
|
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
|
||||||
|
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
|
||||||
|
RemoveUnusedPattern<Torch::AtenTensorOp>,
|
||||||
|
RemoveUnusedPattern<Torch::ConstantBoolOp>,
|
||||||
|
RemoveUnusedPattern<Torch::ConstantIntOp>,
|
||||||
|
RemoveUnusedPattern<Torch::ConstantNoneOp>,
|
||||||
|
RemoveUnusedPattern<Torch::PrimListConstructOp>>(context);
|
||||||
|
|
||||||
|
context->getLoadedDialect<mlir::arith::ArithDialect>()
|
||||||
|
->getCanonicalizationPatterns(patterns);
|
||||||
|
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||||
|
std::move(patterns)))) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
mlir::torch::Torch::createScalarizeShapesPass() {
|
||||||
|
return std::make_unique<ScalarizeShapesPass>();
|
||||||
|
}
|
|
@ -70,6 +70,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||||
OpPassManager &pm) {
|
OpPassManager &pm) {
|
||||||
// We want to fuse quantized operations together before lowering to linalg.
|
// We want to fuse quantized operations together before lowering to linalg.
|
||||||
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
|
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
|
||||||
|
pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());
|
||||||
|
|
||||||
// Lower to linalg + guards which is the input to codegen backends.
|
// Lower to linalg + guards which is the input to codegen backends.
|
||||||
// We do this first as it tends to involve pattern-matching against constants,
|
// We do this first as it tends to involve pattern-matching against constants,
|
||||||
|
|
|
@ -2054,21 +2054,8 @@ ONNX_XFAIL_SET = {
|
||||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
|
|
||||||
# Failure - torch.aten.view lower
|
# Failure - torch.aten.view lower
|
||||||
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
|
||||||
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
|
||||||
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
|
||||||
"IndexTensorMultiInputContiguousCenter_basic",
|
|
||||||
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
|
|
||||||
"IndexTensorMultiInputNonContiguous_basic",
|
|
||||||
"IndexTensorMultiInputOneDim_basic",
|
|
||||||
"IndexTensorMultiInputThreeIndexers_basic",
|
|
||||||
"IndexTensorMultiInput_basic",
|
|
||||||
"ViewFlattenAndExpandModule_basic",
|
|
||||||
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
|
|
||||||
"ViewSizeDimFollowedByExpandedOnesModule_basic",
|
"ViewSizeDimFollowedByExpandedOnesModule_basic",
|
||||||
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
|
|
||||||
"ViewSizeDimLedAndFollowedByExpandedOnesModule_basic",
|
"ViewSizeDimLedAndFollowedByExpandedOnesModule_basic",
|
||||||
"ViewSizeDimLedByCollapsedOnesModule_basic",
|
|
||||||
"ViewSizeDimLedByExpandedOnesModule_basic",
|
"ViewSizeDimLedByExpandedOnesModule_basic",
|
||||||
|
|
||||||
# Failure - unknown
|
# Failure - unknown
|
||||||
|
@ -2105,9 +2092,6 @@ ONNX_XFAIL_SET = {
|
||||||
"IndexTensorHackedTwinModule_basic",
|
"IndexTensorHackedTwinModule_basic",
|
||||||
"IndexTensorModule3dInput_basic",
|
"IndexTensorModule3dInput_basic",
|
||||||
"IndexTensorModule_basic",
|
"IndexTensorModule_basic",
|
||||||
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
|
||||||
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
|
||||||
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
|
||||||
"IndexTensorSelectDimModule_basic",
|
"IndexTensorSelectDimModule_basic",
|
||||||
"MaskedFillTensorFloatValueModule_basic",
|
"MaskedFillTensorFloatValueModule_basic",
|
||||||
"ReduceAllDimEmpty_basic",
|
"ReduceAllDimEmpty_basic",
|
||||||
|
@ -2124,5 +2108,19 @@ ONNX_XFAIL_SET = {
|
||||||
ONNX_CRASHING_SET = {
|
ONNX_CRASHING_SET = {
|
||||||
"FakeQuantizePerTensorAffineModule_basic",
|
"FakeQuantizePerTensorAffineModule_basic",
|
||||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||||
|
|
||||||
|
# WIP for supporting reshape:
|
||||||
|
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||||
|
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||||
|
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
||||||
|
"IndexTensorMultiInputContiguousCenter_basic",
|
||||||
|
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
|
||||||
|
"IndexTensorMultiInputNonContiguous_basic",
|
||||||
|
"IndexTensorMultiInputOneDim_basic",
|
||||||
|
"IndexTensorMultiInputThreeIndexers_basic",
|
||||||
|
"IndexTensorMultiInput_basic",
|
||||||
|
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
||||||
|
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
||||||
|
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -324,12 +324,14 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||||
"aten::square : (Tensor) -> (Tensor)",
|
"aten::square : (Tensor) -> (Tensor)",
|
||||||
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
|
|
||||||
"aten::zero : (Tensor) -> (Tensor)",
|
"aten::zero : (Tensor) -> (Tensor)",
|
||||||
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)",
|
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||||
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)"
|
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)"
|
||||||
]:
|
]:
|
||||||
emit_with_mutating_variants(key)
|
emit_with_mutating_variants(key)
|
||||||
|
# Shape manipulations:
|
||||||
|
emit_with_mutating_variants("aten::unsqueeze : (Tensor, int) -> (Tensor)", has_folder=True)
|
||||||
|
|
||||||
# Elementwise tensor compute ops that don't have the standard mutating
|
# Elementwise tensor compute ops that don't have the standard mutating
|
||||||
# variants.
|
# variants.
|
||||||
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
|
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
// RUN: torch-mlir-opt <%s --torch-scalarize-shapes -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: @shape_as_tensor
|
||||||
|
func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[3],si32> {
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
|
||||||
|
// CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
|
||||||
|
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]]
|
||||||
|
// CHECK-DAG: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[I2]]
|
||||||
|
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I5]], %[[SZ1]], %[[SZ2]]
|
||||||
|
// CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]]
|
||||||
|
// CHECK: return %[[TENSOR]] : !torch.vtensor<[3],si32>
|
||||||
|
%0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
|
||||||
|
return %0 : !torch.vtensor<[3],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @shape_as_tensor_dim
|
||||||
|
func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> {
|
||||||
|
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
|
||||||
|
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]]
|
||||||
|
// CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]]
|
||||||
|
// CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32>
|
||||||
|
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
|
||||||
|
%dim = torch.constant.int 0
|
||||||
|
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
|
||||||
|
%select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
|
||||||
|
return %select : !torch.vtensor<[],si32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @shape_as_tensor_dim_item
|
||||||
|
func.func @shape_as_tensor_dim_item(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.int {
|
||||||
|
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: return %[[SZ]]
|
||||||
|
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
|
||||||
|
%dim = torch.constant.int 0
|
||||||
|
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
|
||||||
|
%select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
|
||||||
|
%out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int
|
||||||
|
return %out : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @shape_as_tensor_slice
|
||||||
|
func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torch.vtensor<[2],si32> {
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
|
||||||
|
// CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
|
||||||
|
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1]]
|
||||||
|
// CHECK-DAG: %[[SZ3:.+]] = torch.aten.size.int %arg0, %[[INT3]]
|
||||||
|
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[SZ1]], %[[SZ3]]
|
||||||
|
// CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]]
|
||||||
|
// CHECK: return %[[TENSOR]]
|
||||||
|
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?,?],f32> -> !torch.vtensor<[4],si32>
|
||||||
|
%dim = torch.constant.int 0
|
||||||
|
%start = torch.constant.int 1
|
||||||
|
%end = torch.constant.int 5
|
||||||
|
%step = torch.constant.int 2
|
||||||
|
%slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
|
||||||
|
return %slice : !torch.vtensor<[2],si32>
|
||||||
|
}
|
Loading…
Reference in New Issue