diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td
index 7c4283549..6c65bd31d 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td
+++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td
@@ -463,7 +463,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
             $_op->getAttrs());
         for (Region &r : $_op->getRegions())
           r.cloneInto(state.addRegion(), bvm);
-        return b.createOperation(state);
+        return b.create(state);
       }]
     >
   ];
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
index ae1319772..46d212d3c 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
+++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
@@ -31,9 +31,7 @@ class TMTensor_Op<string mnemonic, list<Trait> traits = []> :
        TMTensorInterface,
        SingleBlockImplicitTerminator<"::mlir::torch::TMTensor::YieldOp">
   ])> {
-  let verifier = [{ return verify$cppClass(*this); }];
-  let printer = [{ return print$cppClass(p, *this); }];
-  let parser = [{ return parse$cppClass(parser, result); }];
+  let hasVerifier = 1;
   code extraTMTensorOpClassDeclaration = [{
     SmallVector<Value> getDestinationOperands(OpBuilder &b) {
       SmallVector<Value> dest(outputs().begin(), outputs().end());
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h
index 99236d706..391280bbc 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h
+++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h
@@ -10,6 +10,7 @@
 #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
 #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h
index f4eea1b67..fda401bea 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h
+++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h
@@ -10,14 +10,15 @@
 #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
 #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
 namespace torch {
 namespace TMTensor {
 
-std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();
-std::unique_ptr<OperationPass<FuncOp>> createTMTensorBufferizePass();
+std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorToLoopsPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorBufferizePass();
 
 void registerPasses();
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td
index b5250e4f4..72e39b976 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td
+++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.td
@@ -13,12 +13,12 @@ include "mlir/Pass/PassBase.td"
 
 def TMTensorToLoops :
-    Pass<"tm-tensor-to-loops", "FuncOp"> {
+    Pass<"tm-tensor-to-loops", "func::FuncOp"> {
   let summary = "Convert TMTensor ops to loops and Linalg ops.";
   let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
 }
 
-def TMTensorBufferize : Pass<"tm-tensor-bufferize", "FuncOp"> {
+def TMTensorBufferize : Pass<"tm-tensor-bufferize", "func::FuncOp"> {
   let summary = "Bufferize the TMTensor dialect";
   let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()";
 }
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
index dcbfed50c..6e888a857 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
+++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp
@@ -88,34 +88,34 @@ OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
 // ScanOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyScanOp(ScanOp op) {
-  if (op.getNumInputs() != 1) {
-    return op.emitOpError("expected one input operands");
+LogicalResult ScanOp::verify() {
+  if (getNumInputs() != 1) {
+    return emitOpError("expected one input operands");
   }
-  if (op.getNumOutputs() != 2) {
-    return op.emitOpError("expected two output operands");
+  if (getNumOutputs() != 2) {
+    return emitOpError("expected two output operands");
   }
-  if (!op.input().getType().isa<ShapedType>()) {
-    return op.emitOpError("expected first input element type to be shaped");
+  if (!input().getType().isa<ShapedType>()) {
+    return emitOpError("expected first input element type to be shaped");
   }
-  auto accumulatorType = op.accumulator().getType().cast<ShapedType>();
-  auto inputType = op.input().getType().cast<ShapedType>();
-  auto outputType = op.output().getType().cast<ShapedType>();
+  auto accumulatorType = accumulator().getType().cast<ShapedType>();
+  auto inputType = input().getType().cast<ShapedType>();
+  auto outputType = output().getType().cast<ShapedType>();
   ArrayRef<int64_t> inputShapes = inputType.getShape();
   ArrayRef<int64_t> outputShapes = outputType.getShape();
   if (accumulatorType.getElementType() != inputType.getElementType()) {
-    return op.emitOpError(
+    return emitOpError(
         "expected input/accumulator element types to be identical");
   }
   ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
   int64_t accumulatorRank = accumulatorType.getRank();
   if (accumulatorRank != inputType.getRank() - 1) {
-    return op.emitOpError(
+    return emitOpError(
         "expected accumulator rank to be equal to input rank - 1");
   }
   SmallVector<int64_t> expectedAccumulatorShape;
   for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
-    if (i != op.dimension())
+    if (i != dimension())
       expectedAccumulatorShape.push_back(inputShapes[i]);
   }
   if (llvm::any_of(llvm::zip(expectedAccumulatorShape,
accumulatorShape), @@ -124,14 +124,13 @@ static LogicalResult verifyScanOp(ScanOp op) { std::get<1>(s) != ShapedType::kDynamicSize && std::get<0>(s) != std::get<1>(s); })) { - return op.emitOpError("incompatible input/accumulator shapes"); + return emitOpError("incompatible input/accumulator shapes"); } if (inputType.getElementType() != outputType.getElementType()) { - return op.emitOpError( - "expected input/output element types to be identical"); + return emitOpError("expected input/output element types to be identical"); } if (inputShapes.size() != outputShapes.size()) { - return op.emitOpError("expected input/output to have identical ranks"); + return emitOpError("expected input/output to have identical ranks"); } if (llvm::any_of(llvm::zip(inputShapes, outputShapes), [](std::tuple s) { @@ -139,7 +138,7 @@ static LogicalResult verifyScanOp(ScanOp op) { std::get<1>(s) != ShapedType::kDynamicSize && std::get<0>(s) != std::get<1>(s); })) { - return op.emitOpError("incompatible input/output shapes"); + return emitOpError("incompatible input/output shapes"); } return success(); } @@ -232,11 +231,11 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, }); auto &srcBlock = region().front(); - Region ®ion = scfIf.getElseRegion(); + Region &thisRegion = scfIf.getElseRegion(); BlockAndValueMapping bvm; { OpBuilder::InsertionGuard guard(b); - auto &block = region.front(); + auto &block = thisRegion.front(); b.setInsertionPointToEnd(&block); for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) { bvm.map(std::get<0>(it), std::get<1>(it)); @@ -275,48 +274,47 @@ LogicalResult ScanOp::fold(ArrayRef, //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// -static LogicalResult verifyScatterOp(ScatterOp op) { - if (op.inputs().size() != 2) { - return op.emitOpError("expected two input operands"); +LogicalResult ScatterOp::verify() { + if (inputs().size() != 2) { + return emitOpError("expected two input operands"); } - if (op.outputs().size() != 1) { - return op.emitOpError("expected one output operand"); + if (outputs().size() != 1) { + return emitOpError("expected one output operand"); } auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) { return t1.getShape()[dim] == t2.getShape()[dim]; }; - auto indicesType = op.getIndicesType(); + auto indicesType = getIndicesType(); if (indicesType.getRank() != 2 || !indicesType.getElementType().isInteger(32)) { - return op.emitOpError( - "expected indices to be of rank 2 of i32 element type"); + return emitOpError("expected indices to be of rank 2 of i32 element type"); } - auto indexDepth = op.getIndexDepth(); + auto indexDepth = getIndexDepth(); if (indexDepth == ShapedType::kDynamicSize) { - return op.emitOpError("expected index depth is static"); + return emitOpError("expected index depth is static"); } // The first dimension of the indices should match the first dimension of the // output. They indicate to the number of updates. 
- auto updateType = op.getUpdateType(); + auto updateType = getUpdateType(); if (updateType.getRank() < 1) { - return op.emitOpError("expected update value to be at least rank 1"); + return emitOpError("expected update value to be at least rank 1"); } if (!checkDimensionsMatch(indicesType, updateType, 0)) { - return op.emitOpError( + return emitOpError( "mismatch in shape of indices and update value at dim#0"); } - auto originalType = op.getOriginalType(); + auto originalType = getOriginalType(); if (updateType.getRank() - 1 > originalType.getRank()) { - return op.emitOpError( + return emitOpError( "update value rank exceeds the rank of the original value"); } // indexDepth + update dims should cover the original dims. The first dim of // update is the number of updates. if (originalType.getRank() > indexDepth + updateType.getRank() - 1) { - return op.emitOpError( + return emitOpError( "index depth and update value does not cover rank of original value"); } @@ -331,7 +329,7 @@ static LogicalResult verifyScatterOp(ScatterOp op) { int64_t updateDim = std::get<1>(it); if (updateType.getDimSize(updateDim) != originalType.getDimSize(originalDim)) { - return op.emitOpError("mismatch in shape of update value dim#") + return emitOpError("mismatch in shape of update value dim#") << updateDim << " and original value at dim#" << originalDim; } } @@ -345,36 +343,36 @@ static LogicalResult verifyScatterOp(ScatterOp op) { int64_t updateDim = std::get<1>(it); if (updateType.getDimSize(updateDim) > originalType.getDimSize(originalDim)) { - return op.emitOpError("indexed shape of update value dim#") + return emitOpError("indexed shape of update value dim#") << updateDim << " exceeds original value at dim#" << originalDim << " " << updateType.getDimSize(updateDim) << " " << originalType.getDimSize(originalDim); } } - Region ®ion = op.region(); - Block *body = ®ion.front(); + Region &thisRegion = region(); + Block *body = &thisRegion.front(); if (body->getNumArguments() != 2) { - return op.emitOpError("expected region to have two arguments"); + return emitOpError("expected region to have two arguments"); } Type arg0Type = body->getArgument(0).getType(); Type arg1Type = body->getArgument(1).getType(); if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) { - return op.emitOpError( + return emitOpError( "expected region to have scalar argument of integer or float types"); } if (arg0Type != updateType.getElementType()) { - return op.emitOpError("mismatch in argument 0 of region ") + return emitOpError("mismatch in argument 0 of region ") << arg0Type << " and element type of update value " << updateType.getElementType(); } if (arg1Type != originalType.getElementType()) { - return op.emitOpError("mismatch in argument 1 of region ") + return emitOpError("mismatch in argument 1 of region ") << arg1Type << " and element type of original value " << originalType.getElementType(); } if (arg0Type != arg1Type) { - return op.emitOpError("mismatch in region argument types ") + return emitOpError("mismatch in region argument types ") << arg0Type << " and " << arg1Type; } auto yieldOp = cast(body->getTerminator()); @@ -455,7 +453,8 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, Value idx = b.create(loc, indices(), loadIndices); Value cast = b.create(loc, b.getIndexType(), idx); - if (starts[i]) cast = b.create(loc, cast, starts[i]); + if (starts[i]) + cast = b.create(loc, cast, starts[i]); starts[i] = cast; } diff --git 
a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp
index 3dfc2802a..e39c8413b 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp
+++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Func/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -150,7 +151,7 @@ struct TMTensorBufferizePass
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 torch::TMTensor::createTMTensorBufferizePass() {
   return std::make_unique<TMTensorBufferizePass>();
 }
diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp
index fcab8e61c..a9b5c7cd2 100644
--- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp
+++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp
@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
@@ -111,7 +112,7 @@ struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 torch::TMTensor::createTMTensorToLoopsPass() {
   return std::make_unique<TMTensorToLoopsPass>();
 }
diff --git a/externals/llvm-project b/externals/llvm-project
index 8361c5da3..e1318078a 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit 8361c5da30588d3d4a48eae648f53be1feb5cfad
+Subproject commit e1318078a4e160eb723bcbcfcdcc9a1b618f7067
diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td
index 9760154c2..02a376d19 100644
--- a/include/torch-mlir/Conversion/Passes.td
+++ b/include/torch-mlir/Conversion/Passes.td
@@ -16,17 +16,17 @@ include "mlir/Pass/PassBase.td"
 // Torch conversions
 //===----------------------------------------------------------------------===//
 
-def ConvertTorchToStd : Pass<"convert-torch-to-std", "FuncOp"> {
+def ConvertTorchToStd : Pass<"convert-torch-to-std", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to Std ops";
   let constructor = "mlir::torch::createConvertTorchToStdPass()";
 }
 
-def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "FuncOp"> {
+def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to SCF ops";
   let constructor = "mlir::torch::createConvertTorchToSCFPass()";
 }
 
-def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
+def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to Linalg ops";
   let description = [{
     Convert ATen ops to linalg ops.
@@ -105,7 +105,7 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
   let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
 }
 
-def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
+def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
   let summary = "Convert Torch ops to TOSA ops";
   let description = [{
     This pass assumes that TOSA ops are responsible for emitting error
@@ -114,7 +114,7 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
   let constructor = "mlir::torch::createConvertTorchToTosaPass()";
 }
 
-def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "FuncOp"> {
+def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
   let description = [{
     Convert ATen ops to tmtensor/linalg ops.
diff --git a/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h b/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h
index 6df74b700..a4ab67a64 100644
--- a/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h
+++ b/include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h
@@ -10,12 +10,13 @@
 #ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
 #define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 #include <memory>
 
 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToLinalgPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
 }
 } // namespace mlir
diff --git a/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h b/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h
index a9b3ce8c3..7b869dae4 100644
--- a/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h
+++ b/include/torch-mlir/Conversion/TorchToSCF/TorchToSCF.h
@@ -10,11 +10,12 @@
 #ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
 #define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToSCFPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();
 }
 } // namespace mlir
diff --git a/include/torch-mlir/Conversion/TorchToStd/TorchToStd.h b/include/torch-mlir/Conversion/TorchToStd/TorchToStd.h
index 371924725..3285bd5f0 100644
--- a/include/torch-mlir/Conversion/TorchToStd/TorchToStd.h
+++ b/include/torch-mlir/Conversion/TorchToStd/TorchToStd.h
@@ -10,12 +10,13 @@
 #ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
 #define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 #include <memory>
 
 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToStdPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToStdPass();
 }
 } // namespace mlir
diff --git a/include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h b/include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h
index dfa92fdef..2b42c3291 100644
--- a/include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h
+++ b/include/torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h
@@ -10,11 +10,12 @@
 #ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
 #define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTMTensorPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
 }
 } // namespace mlir
diff --git 
a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index 41a53a696..a6d774a64 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -10,12 +10,13 @@ #ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #include namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToTosaPass(); +std::unique_ptr> createConvertTorchToTosaPass(); } } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index c643eb646..04cc55010 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -10,6 +10,8 @@ #ifndef TORCH_TYPES #define TORCH_TYPES +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/DialectBase.td" include "torch-mlir/Dialect/Torch/IR/TorchBase.td" //===----------------------------------------------------------------------===// @@ -24,28 +26,8 @@ class Torch_Type : Torch_Type { let parameters = (ins "::mlir::Type":$containedType); + let hasCustomAssemblyFormat = 1; - let printer = [{ - $_printer << "<"; - // Print the contained type without the `!torch.` prefix. - printTorchDialectType(getImpl()->containedType, $_printer); - $_printer << ">"; - }]; - - let parser = [{ - if ($_parser.parseLess()) - return Type(); - - // Parse the contained type, but forward directly to our internal parsing - // of `torch` dialect types, so that we can parse nested types without - // the `!torch.` prefix. - Type containedType = parseTorchDialectType($_parser); - if (!containedType) - return Type(); - if ($_parser.parseGreater()) - return Type(); - return get($_ctxt, containedType); - }]; let builders = [ TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{ return Base::get(containedType.getContext(), containedType); @@ -59,23 +41,7 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> { Represents an instance of a `torch.nn.Module` with the given `className`. }]; let parameters = (ins StringRefParameter<"class name">:$className); - - let printer = [{ - $_printer << "<\""; - llvm::printEscapedString(getImpl()->className, $_printer.getStream()); - $_printer << "\">"; - }]; - - let parser = [{ - if ($_parser.parseLess()) - return Type(); - std::string className; - if ($_parser.parseOptionalString(&className)) - return Type(); - if ($_parser.parseGreater()) - return Type(); - return get($_ctxt, className); - }]; + let hasCustomAssemblyFormat = 1; } // For standard ArrayRefs, which require allocation. 
@@ -186,6 +152,7 @@ class AnyTorchTensorType "::mlir::Type":$optionalDtype ); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; string extraBaseClassDeclaration = [{ }]; } @@ -243,6 +210,7 @@ def Torch_TupleType : Torch_Type<"Tuple", "tuple"> { let parameters = (ins ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes ); + let hasCustomAssemblyFormat = 1; } def Torch_UnionType : Torch_Type<"Union", "union"> { @@ -261,6 +229,7 @@ def Torch_UnionType : Torch_Type<"Union", "union"> { let parameters = (ins ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes ); + let hasCustomAssemblyFormat = 1; } def Torch_DeviceType : Torch_Type<"Device", "Device"> { @@ -367,30 +336,7 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> { let description = [{ Torch Dict type with key and value type. }]; - - let printer = [{ - $_printer << "<"; - printTorchDialectType(getImpl()->keyType, $_printer); - $_printer << ", "; - printTorchDialectType(getImpl()->valueType, $_printer); - $_printer<< ">"; - }]; - - let parser = [{ - if ($_parser.parseLess()) - return Type(); - Type keyType = parseTorchDialectType($_parser); - if (!keyType) - return Type(); - if ($_parser.parseComma()) - return Type(); - Type valueType = parseTorchDialectType($_parser); - if (!valueType) - return Type(); - if ($_parser.parseGreater()) - return Type(); - return get($_ctxt, keyType, valueType); - }]; + let hasCustomAssemblyFormat = 1; let builders = [ TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType, "::mlir::Type":$valueType), [{ diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index a951a50b4..0f86364d0 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -10,11 +10,14 @@ #ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H #define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #include namespace mlir { +class ModuleOp; + namespace torch { namespace Torch { @@ -48,25 +51,26 @@ void createTorchShapeRefinementPipeline( std::unique_ptr> createAdjustCallingConventionsPass(); -std::unique_ptr> createRefineTypesPass(); +std::unique_ptr> createRefineTypesPass(); std::unique_ptr> createInlineGlobalSlotsPass(); -std::unique_ptr> createReduceOpVariantsPass(); +std::unique_ptr> createReduceOpVariantsPass(); -std::unique_ptr> createMaximizeValueSemanticsPass(); +std::unique_ptr> createMaximizeValueSemanticsPass(); std::unique_ptr> createRefinePublicReturnPass(); -std::unique_ptr> createDecomposeComplexOpsPass(); +std::unique_ptr> createDecomposeComplexOpsPass(); std::unique_ptr> createPreprocessShapeLibraryPass(); std::unique_ptr> createReifyShapeCalculationsPass(); -std::unique_ptr> createSimplifyShapeCalculationsPass(); +std::unique_ptr> +createSimplifyShapeCalculationsPass(); -std::unique_ptr> createDropShapeCalculationsPass(); +std::unique_ptr> createDropShapeCalculationsPass(); StringRef getShapeLibrary(); diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 3a3ebd12e..448f490a5 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -126,7 +126,7 @@ def AdjustCallingConventions }]; } -def RefineTypes : Pass<"torch-refine-types", "FuncOp"> { +def RefineTypes : Pass<"torch-refine-types", "func::FuncOp"> { let summary = "Refine types"; let 
constructor = "mlir::torch::Torch::createRefineTypesPass()"; let description = [{ @@ -149,7 +149,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> { }]; } -def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> { +def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> { let summary = "Reduces variants of ops to a smaller set of ops."; let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()"; let description = [{ @@ -165,7 +165,7 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> { }]; } -def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> { +def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "func::FuncOp"> { let summary = "Use value-semantic tensors where possible."; let description = [{ Use value-semantic tensors where possible to make the program more @@ -215,7 +215,7 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> { }]; } -def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> { +def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> { let summary = "Decompose complicated torch operations"; let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()"; let description = [{ @@ -238,7 +238,7 @@ def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> }]; } -def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncOp"> { +def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func::FuncOp"> { let summary = "Simplify reified shape calculations."; let constructor = "mlir::torch::Torch::createSimplifyShapeCalculationsPass()"; let description = [{ @@ -246,7 +246,7 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncO }]; } -def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "FuncOp"> { +def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"> { let summary = "Drop reified shape calculations."; let constructor = "mlir::torch::Torch::createDropShapeCalculationsPass()"; let description = [{ diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index 14e589836..ce6ef9da1 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -10,12 +10,15 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include namespace mlir { +class ModuleOp; + namespace torch { namespace TorchConversion { @@ -36,7 +39,7 @@ createVerifyInvariantsBeforeBackendLoweringPass(); std::unique_ptr> createFuncBackendTypeConversionPass(); -std::unique_ptr> +std::unique_ptr> createFinalizingBackendTypeConversionPass(); std::unique_ptr> diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 8b7df27bc..cbadd0b92 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -43,7 +43,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu } def FinalizingBackendTypeConversion - : 
Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> { + : Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> { let summary = "Finalizes a partial conversion to builtin tensors"; let constructor = "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()"; diff --git a/include/torch-mlir/RefBackend/Passes.h b/include/torch-mlir/RefBackend/Passes.h index 5238286b4..0b749df5c 100644 --- a/include/torch-mlir/RefBackend/Passes.h +++ b/include/torch-mlir/RefBackend/Passes.h @@ -10,10 +10,13 @@ #ifndef TORCHMLIR_REFBACKEND_PASSES_H #define TORCHMLIR_REFBACKEND_PASSES_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" namespace mlir { +class ModuleOp; + namespace torch { namespace RefBackend { @@ -22,13 +25,13 @@ void registerRefBackendPasses(); std::unique_ptr> createMungeCallingConventionsPass(); -std::unique_ptr> createExpandOpsForLLVMPass(); +std::unique_ptr> createExpandOpsForLLVMPass(); std::unique_ptr> createInsertRngGlobalsPass(); -std::unique_ptr> createMungeMemrefCopyPass(); +std::unique_ptr> createMungeMemrefCopyPass(); -std::unique_ptr> createGeneralizeTensorPadPass(); +std::unique_ptr> createGeneralizeTensorPadPass(); } // namespace RefBackend } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/RefBackend/Passes.td b/include/torch-mlir/RefBackend/Passes.td index 71588802b..518bc62f0 100644 --- a/include/torch-mlir/RefBackend/Passes.td +++ b/include/torch-mlir/RefBackend/Passes.td @@ -24,18 +24,18 @@ def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> { let dependentDialects = ["memref::MemRefDialect"]; } -def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "FuncOp"> { +def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> { let summary = "Expand ops into more primitive ops before LLVM lowering."; let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();"; } -def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "FuncOp"> { +def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> { let summary = "Munge memref.copy to linalg.copy"; let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();"; let dependentDialects = ["memref::MemRefDialect"]; } -def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "FuncOp"> { +def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> { let summary = "Convert tensor.pad to linalg ops"; let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()"; } diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h index 114578740..2e98b37d4 100644 --- a/lib/Conversion/PassDetail.h +++ b/lib/Conversion/PassDetail.h @@ -10,6 +10,7 @@ #ifndef TORCHMLIR_CONVERSION_PASSDETAIL_H #define TORCHMLIR_CONVERSION_PASSDETAIL_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 4509ccb22..f8ebc349f 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -88,7 +88,7 @@ public: }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::createConvertTorchToLinalgPass() { return std::make_unique(); } diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index d0b4cae61..b570e1d7f 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ 
b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
@@ -91,7 +91,7 @@ public:
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToSCFPass() {
   return std::make_unique<ConvertTorchToSCF>();
 }
diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp
index eafdd6287..a2275b9db 100644
--- a/lib/Conversion/TorchToStd/TorchToStd.cpp
+++ b/lib/Conversion/TorchToStd/TorchToStd.cpp
@@ -10,6 +10,7 @@
 #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
 
 #include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -221,7 +222,7 @@ public:
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToStdPass() {
   return std::make_unique<ConvertTorchToStd>();
 }
diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index 4c3f62d01..60b03d576 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -10,8 +10,10 @@
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
 
 #include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/MLIRContext.h"
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
@@ -566,7 +568,7 @@ public:
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToTMTensorPass() {
   return std::make_unique<ConvertTorchToTMTensor>();
 }
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 694eaa7a5..65fddcca4 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3170,7 +3170,7 @@ public:
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToTosaPass() {
   return std::make_unique<ConvertTorchToTosa>();
 }
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index 81f47e39e..0c7cc96a8 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp
index 99c2bc05b..1347026bb 100644
--- a/lib/Dialect/Torch/IR/TorchDialect.cpp
+++ b/lib/Dialect/Torch/IR/TorchDialect.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -124,7 +125,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
                                                      unsigned argIndex,
                                                      NamedAttribute namedAttr) {
   if (namedAttr.getName().getValue() == "torch.type_bound") {
-    auto func = dyn_cast<FuncOp>(op);
+    auto func = dyn_cast<func::FuncOp>(op);
     if (!func)
       return op->emitError() << "'torch.type_bound' must be attached to a func";
     TypeAttr attr =
namedAttr.getValue().dyn_cast(); @@ -134,7 +135,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op, if (!type) return op->emitError() << "'torch.type_bound' must be of " "!torch.tensor/!torch.vtensor type"; - if (!func.getType().getInput(argIndex).isa()) + if (!func.getFunctionType().getInput(argIndex).isa()) return op->emitError() << "'torch.type_bound' must be attached to an " "argument of !torch.tensor/!torch.vtensor type"; return success(); @@ -177,3 +178,100 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder, return nullptr; } + +//===----------------------------------------------------------------------===// +// OptionalType and ListType +//===----------------------------------------------------------------------===// + +void OptionalType::print(AsmPrinter &printer) const { + printer << "<"; + // Print the contained type without the `!torch.` prefix. + printTorchDialectType(getImpl()->containedType, printer); + printer << ">"; +} + +void ListType::print(AsmPrinter &printer) const { + printer << "<"; + // Print the contained type without the `!torch.` prefix. + printTorchDialectType(getImpl()->containedType, printer); + printer << ">"; +} + +Type OptionalType::parse(AsmParser &odsParser) { + if (odsParser.parseLess()) + return Type(); + + // Parse the contained type, but forward directly to our internal parsing + // of `torch` dialect types, so that we can parse nested types without + // the `!torch.` prefix. + Type containedType = parseTorchDialectType(odsParser); + if (!containedType) + return Type(); + if (odsParser.parseGreater()) + return Type(); + return get(odsParser.getContext(), containedType); +} + +Type ListType::parse(AsmParser &odsParser) { + if (odsParser.parseLess()) + return Type(); + + // Parse the contained type, but forward directly to our internal parsing + // of `torch` dialect types, so that we can parse nested types without + // the `!torch.` prefix. 
+ Type containedType = parseTorchDialectType(odsParser); + if (!containedType) + return Type(); + if (odsParser.parseGreater()) + return Type(); + return get(odsParser.getContext(), containedType); +} + +//===----------------------------------------------------------------------===// +// DictType +//===----------------------------------------------------------------------===// + +void DictType::print(AsmPrinter &printer) const { + printer << "<"; + printTorchDialectType(getImpl()->keyType, printer); + printer << ", "; + printTorchDialectType(getImpl()->valueType, printer); + printer << ">"; +} + +Type DictType::parse(AsmParser &odsParser) { + if (odsParser.parseLess()) + return Type(); + Type keyType = parseTorchDialectType(odsParser); + if (!keyType) + return Type(); + if (odsParser.parseComma()) + return Type(); + Type valueType = parseTorchDialectType(odsParser); + if (!valueType) + return Type(); + if (odsParser.parseGreater()) + return Type(); + return get(odsParser.getContext(), keyType, valueType); +} + +//===----------------------------------------------------------------------===// +// NnModuleType +//===----------------------------------------------------------------------===// + +void NnModuleType::print(AsmPrinter &printer) const { + printer << "<\""; + llvm::printEscapedString(getImpl()->className, printer.getStream()); + printer << "\">"; +} + +Type NnModuleType::parse(AsmParser &odsParser) { + if (odsParser.parseLess()) + return Type(); + std::string className; + if (odsParser.parseOptionalString(&className)) + return Type(); + if (odsParser.parseGreater()) + return Type(); + return get(odsParser.getContext(), className); +} diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e586f5499..a56fd9521 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" @@ -99,7 +100,7 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) { LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto func = - symbolTable.lookupNearestSymbolFrom(*this, functionAttr()); + symbolTable.lookupNearestSymbolFrom(*this, functionAttr()); if (!func) return emitError() << "'@" << function() << "' does not reference a valid function"; @@ -112,8 +113,8 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { "merely declared)"; auto expectedReceiverArgType = NnModuleType::get( getContext(), getOperation()->getParentOfType().getName()); - if (func.getType().getNumInputs() == 0 || - func.getType().getInput(0) != expectedReceiverArgType) { + if (func.getFunctionType().getNumInputs() == 0 || + func.getFunctionType().getInput(0) != expectedReceiverArgType) { return emitError() << "the referenced function '" << function() << "' must have a first argument of type " << expectedReceiverArgType; @@ -278,7 +279,7 @@ ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) { Region *elseRegion = result.addRegion(); auto &builder = parser.getBuilder(); - OpAsmParser::OperandType cond; + OpAsmParser::UnresolvedOperand cond; Type boolType = builder.getType(); if (parser.parseOperand(cond) || parser.resolveOperand(cond, boolType, result.operands)) diff --git a/lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp b/lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp index 
c76947bfd..aee78ab93 100644 --- a/lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp +++ b/lib/Dialect/Torch/IR/UtilsForODSGenerated.cpp @@ -35,7 +35,7 @@ ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser, OperationState &result, int numOperands, int numResults) { llvm::SMLoc loc = parser.getCurrentLocation(); - SmallVector operands; + SmallVector operands; if (parser.parseOperandList(operands, /*requiredOperandCount=*/numOperands)) return failure(); if (parser.parseOptionalAttrDict(result.attributes)) diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index d241dcd38..265a5d8ab 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -31,11 +31,12 @@ using namespace mlir::torch::Torch; using TypeBoundMap = DenseMap, Type>; namespace { -class AdjustCallingConventionForFunc : public OpConversionPattern { +class AdjustCallingConventionForFunc + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(FuncOp func, OpAdaptor adaptor, + matchAndRewrite(func::FuncOp func, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = func.getContext(); auto typeBoundIdent = StringAttr::get(context, "torch.type_bound"); @@ -70,7 +71,7 @@ public: typeConverter); SmallVector newResultTypes; - for (auto type : func.getType().getResults()) { + for (auto type : func.getFunctionType().getResults()) { if (auto none = type.dyn_cast()) { continue; } @@ -186,7 +187,7 @@ public: }; } // namespace -static LogicalResult adjustCallingConventions(FuncOp func, +static LogicalResult adjustCallingConventions(func::FuncOp func, TypeBoundMap &typeBoundMap) { MLIRContext *context = func.getContext(); RewritePatternSet patterns(context); @@ -217,7 +218,7 @@ static LogicalResult adjustCallingConventions(FuncOp func, patterns.add(typeConverter, context); ConversionTarget target(*context); - target.addDynamicallyLegalOp([](FuncOp func) { + target.addDynamicallyLegalOp([](func::FuncOp func) { for (int i = 0, e = func.getNumArguments(); i != e; i++) { if (func.getArgAttr(i, "torch.type_bound")) return false; @@ -225,7 +226,7 @@ static LogicalResult adjustCallingConventions(FuncOp func, return false; } for (int i = 0, e = func.getNumResults(); i != e; i++) { - if (func.getType().getResults()[i].isa()) + if (func.getFunctionType().getResults()[i].isa()) return false; } return true; @@ -266,7 +267,7 @@ class AdjustCallingConventionsPass void runOnOperation() override { auto module = getOperation(); TypeBoundMap typeBoundMap; - for (auto func : module.getOps()) { + for (auto func : module.getOps()) { for (int i = 0, e = func.getNumArguments(); i != e; i++) { auto typeBoundAttr = func.getArgAttrOfType(i, "torch.type_bound"); @@ -275,7 +276,7 @@ class AdjustCallingConventionsPass typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue(); } } - for (auto func : module.getOps()) { + for (auto func : module.getOps()) { if (failed(adjustCallingConventions(func, typeBoundMap))) return signalPassFailure(); } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d3dbb9f01..23279279d 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1781,7 +1781,7 @@ class DecomposeComplexOpsPass } }; } // namespace -std::unique_ptr> 
+std::unique_ptr> mlir::torch::Torch::createDecomposeComplexOpsPass() { return std::make_unique(); } diff --git a/lib/Dialect/Torch/Transforms/DropShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/DropShapeCalculations.cpp index 3a5fca9fb..aeb4b3dfb 100644 --- a/lib/Dialect/Torch/Transforms/DropShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/DropShapeCalculations.cpp @@ -54,7 +54,7 @@ class DropShapeCalculationsPass patterns.insert(context); ConversionTarget target(*context); target.addIllegalOp(); - target.addLegalOp(); + target.addLegalOp(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -64,7 +64,7 @@ class DropShapeCalculationsPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::Torch::createDropShapeCalculationsPass() { return std::make_unique(); } diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index c4172f9be..09f7d7d24 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -88,7 +88,7 @@ public: return it->second; } Optional getFuncLinkageInfo(NnModuleOp instance, - FuncOp methodFunc) { + func::FuncOp methodFunc) { auto it = funcLinkageInfo.find({instance, methodFunc}); if (it == funcLinkageInfo.end()) return None; @@ -185,7 +185,7 @@ private: for (auto method : classType.getOps()) { nameStack.push_back(method.name().str()); funcLinkageInfo[{nnModule, - symbolTable.lookup(method.function())}] = + symbolTable.lookup(method.function())}] = LinkageInfo{llvm::join(nameStack, "."), method.isPrivate()}; nameStack.pop_back(); } @@ -251,7 +251,7 @@ private: // Linkage info for each method in the program. Since we are going to be // monomorphizing all the functions, we also need to key this off of the // instance (NnModuleOp) that the func is monomorphized for. - DenseMap, LinkageInfo> funcLinkageInfo; + DenseMap, LinkageInfo> funcLinkageInfo; // The corresponding GlobalSlotOp for each SlotOp in the program. DenseMap slotToGlobalSlot; // A set of values that we have copied into torch.global_slot initializers, @@ -298,7 +298,7 @@ namespace { // any notion of "type" that we have in the IR, but still fits the formal // definition. struct Monomorphization { - FuncOp func; + func::FuncOp func; std::vector argInstances; }; } // namespace @@ -327,7 +327,7 @@ template <> struct llvm::DenseMapInfo { // // This generalizes to a full abstract interpretation of the function, but // currently only analyzes a subset of ops. 
-static LogicalResult analyzeInstances(FuncOp func, +static LogicalResult analyzeInstances(func::FuncOp func, ArrayRef argInstances, BlockAndValueMapping &mapping) { for (auto &argInstance : argInstances) @@ -351,7 +351,7 @@ static LogicalResult analyzeInstances(FuncOp func, static FailureOr createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping, SymbolTable &symbolTable) { - auto func = symbolTable.lookup(op.getCallee()); + auto func = symbolTable.lookup(op.getCallee()); Monomorphization monomorphization; monomorphization.func = func; for (auto operand : llvm::enumerate(op->getOperands())) { @@ -372,7 +372,7 @@ public: : module(module), symbolTable(module) {} LogicalResult initialize(DenseMap> &instances) { - for (auto func : module.getOps()) { + for (auto func : module.getOps()) { Monomorphization monomorphization; monomorphization.func = func; bool canTriviallyMonomorphize = true; @@ -455,7 +455,7 @@ static LogicalResult verifyNnModuleValueUses(Value value) { // Verify that `func` conforms to the subset of allowable method bodies // that we can convert. -static LogicalResult verifyFuncConformsToSubset(FuncOp func) { +static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) { // TODO: Investingate why WalkResult::interrupt() doesn't propagate properly. LogicalResult ret = success(); func.walk([&](Block *block) { @@ -481,7 +481,7 @@ static LogicalResult verifyFuncConformsToSubset(FuncOp func) { static LogicalResult verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable, MonomorphizationTracker &tracker) { - DenseMap numMonomorphizations; + DenseMap numMonomorphizations; for (auto &monomorphization : tracker.getMonomorphizations()) { numMonomorphizations[monomorphization.func] += 1; } @@ -489,7 +489,7 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable, for (auto classType : module.getOps()) { for (auto method : classType.getOps()) { if (!method.isPrivate()) { - if (numMonomorphizations[symbolTable.lookup( + if (numMonomorphizations[symbolTable.lookup( method.function())] > 1) { method.emitError() << "public function with multiple monomorphizations"; @@ -503,11 +503,10 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable, // Rewrite `func`, given that all values of `NnModuleType` have been mapped in // `mapping` to corresponding global instances. -static LogicalResult -rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping, - SymbolTable &symbolTable, - DenseMap &newFuncs, - ObjectGraphInfo &objectGraphInfo) { +static LogicalResult rewriteMonomorphizedFuncClone( + func::FuncOp func, BlockAndValueMapping mapping, SymbolTable &symbolTable, + DenseMap &newFuncs, + ObjectGraphInfo &objectGraphInfo) { SmallVector toErase; auto handlePrimSetAttr = [&](PrimSetAttrOp op) { @@ -605,7 +604,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { // static analysis to discover how to monomorphize th eprogram, including // tracking instances through control flow, through get/set attr, etc. We // implement a very simple subset of cases. - for (auto func : module.getOps()) { + for (auto func : module.getOps()) { if (failed(verifyFuncConformsToSubset(func))) return failure(); } @@ -637,10 +636,10 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { // Step 4: Clone/rewrite functions to implement the necessary // monomorphizations. 
- DenseMap newFuncs; + DenseMap newFuncs; int uniquifier = 0; for (auto &monomorphization : tracker.getMonomorphizations()) { - auto newFunc = cast(monomorphization.func->clone()); + auto newFunc = cast(monomorphization.func->clone()); newFuncs[monomorphization] = newFunc; Optional linkageInfo = None; // If it is potentially a method, check its linkage info. @@ -675,14 +674,14 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) { } // Step 5: Clean up object graph. - DenseSet liveFuncs; + DenseSet liveFuncs; for (auto &kv : newFuncs) { liveFuncs.insert(kv.second); } for (auto &op : llvm::make_early_inc_range(module.getOps())) { if (isa(&op)) continue; - if (auto func = dyn_cast(op)) { + if (auto func = dyn_cast(op)) { if (liveFuncs.contains(func)) continue; } diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 80cbda21e..4112f3eac 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -355,7 +355,7 @@ class MaximizeValueSemanticsPass } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::Torch::createMaximizeValueSemanticsPass() { return std::make_unique(); } diff --git a/lib/Dialect/Torch/Transforms/PassDetail.h b/lib/Dialect/Torch/Transforms/PassDetail.h index befa81149..85fc116fe 100644 --- a/lib/Dialect/Torch/Transforms/PassDetail.h +++ b/lib/Dialect/Torch/Transforms/PassDetail.h @@ -10,9 +10,11 @@ #ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H #define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { +class ModuleOp; namespace torch { namespace Torch { diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 325d467ad..94644c317 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -94,7 +94,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( if (options.optimize) { // Eliminate the PrimTupleIndexOp generated from the // adjustCallingConventions - pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCanonicalizerPass()); // Inline global slots, which for most inference scenarios deletes them. // This also exposes more information to intraprocedural transformations // below like MaximizeValueSemantics and RefineTypes. @@ -105,14 +105,14 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( } // Reduce variants of ops to a smaller set of primitives. - pm.addNestedPass(createReduceOpVariantsPass()); + pm.addNestedPass(createReduceOpVariantsPass()); if (options.optimize) { // OPT-ONLY: Right now we rely on this to eliminate certain branches that // guard unreachable code that backends can't handle yet, such as lists, // RaiseException, unimplemented tensor ops, and only-used-in-training // operations on `torch.global_slot`'s. - pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCanonicalizerPass()); // OPT-ONLY: We may have deleted some `torch.global_slot.get` / // `torch.global_slot.get` ops, which may have left more // `torch.global_slot`'s unused. @@ -124,7 +124,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( //===--------------------------------------------------------------------===// // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's. 
-  pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
 
   // Do shape refinement.
   // This must be run before RefineTypes (which primarily does dtype inference),
   // operand.
   createTorchShapeRefinementPipeline(pm, options);
   // Refine types in the program, which mainly means inferring dtypes of ops.
-  pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
 
   // Propagate to ABI return types the shape/dtype information discovered by
   // the previous pass. Doing this is ABI-compatible for our backends.
@@ -142,7 +142,7 @@
     // This can fold away some branches given the information got from
     // RefineTypes before doing maximize value sematics which only works with
    // basic blocks.
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   }
 
   if (options.optimize) {
@@ -152,9 +152,9 @@
     // branches that guard unreachable code that backends can't handle yet, such
     // as lists, RaiseException, unimplemented aten ops, and
     // only-used-in-training operations on `torch.global_slot`'s.
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   }
-  pm.addNestedPass<FuncOp>(Torch::createDecomposeComplexOpsPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createDecomposeComplexOpsPass());
 
   // TODO: VerifyTorchBackendContractPass.
 }
@@ -172,11 +172,11 @@ void mlir::torch::Torch::createTorchShapeRefinementPipeline(
   // as hard as possible" kind of thing, so it's inherently somewhat brittle.
   // The idea is to keep strengthening what we do here to support the shape
   // library. We don't need to support arbitrary programs, thankfully.
-  pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
   // Run CSE, then see if we can simplify further.
-  pm.addNestedPass<FuncOp>(createCSEPass());
-  pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass());
+  pm.addNestedPass<func::FuncOp>(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
 
   // Drop shape calculations, leaving behind the shape-refined program.
- pm.addNestedPass(Torch::createDropShapeCalculationsPass()); + pm.addNestedPass(Torch::createDropShapeCalculationsPass()); } diff --git a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp index 21747bb8b..939141765 100644 --- a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp @@ -33,10 +33,10 @@ public: auto classType = symbolTable.lookup( op.receiver().getType().cast().getClassName()); assert(classType && "malformed module -- missing ClassTypeOp"); - FuncOp func; + func::FuncOp func; for (auto method : classType.getOps()) { if (method.name() == op.name()) { - func = symbolTable.lookup(method.function()); + func = symbolTable.lookup(method.function()); break; } } diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 86f070a0b..d47798b4a 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -207,7 +207,7 @@ public: assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 && "Torch JIT operators shouldn't have regions or successors"); - Operation *newOp = rewriter.createOperation(state); + Operation *newOp = rewriter.create(state); auto tensor = rewriter.create(op->getLoc(), newOp->getResult(0)); createOverwriteTensorContents(rewriter, op->getLoc(), tensor, @@ -276,7 +276,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::Torch::createReduceOpVariantsPass() { return std::make_unique(); } diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp index 7f21e9a02..4adf61346 100644 --- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp +++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp @@ -25,7 +25,7 @@ class RefinePublicReturnPass : public RefinePublicReturnBase { void runOnOperation() override { auto module = getOperation(); - module.walk([&](FuncOp func) { + module.walk([&](func::FuncOp func) { if (func.getVisibility() != SymbolTable::Visibility::Public) return; if (func.isExternal()) @@ -40,7 +40,7 @@ class RefinePublicReturnPass }); } - void rewriteSignature(FuncOp func) { + void rewriteSignature(func::FuncOp func) { // Find the unique return op. func::ReturnOp returnOp; WalkResult walkResult = func.walk([&](func::ReturnOp op) { @@ -90,7 +90,7 @@ class RefinePublicReturnPass returnOp->setOperands(newOperands); // Update the function type. 
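The ReduceOpVariants hunk follows the upstream rename of OpBuilder::createOperation to a create overload taking an OperationState. A small sketch of generic op construction under the new spelling; the helper name and operands are placeholders:

#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"

using namespace mlir;

// Build an op generically from an OperationState; this was previously
// spelled builder.createOperation(state).
static Operation *buildGenericOp(OpBuilder &builder, Location loc,
                                 StringRef opName, ValueRange operands,
                                 TypeRange resultTypes) {
  OperationState state(loc, opName);
  state.addOperands(operands);
  state.addTypes(resultTypes);
  return builder.create(state);
}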
- auto funcType = func.getType(); + auto funcType = func.getFunctionType(); func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(), ValueRange(newOperands).getTypes())); } diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 93f7e78d0..ccb86b799 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -1191,7 +1191,7 @@ static bool isSafeToRefineOperandInPlace(OpOperand *use, Type newOperandType) { return operationIsValidWithRefinedType(use, newOperandType); } -void optimize(FuncOp func, TypeAnalyzer &analyzer) { +void optimize(func::FuncOp func, TypeAnalyzer &analyzer) { func.walk([&](Operation *op) { auto convertValuesToMostRefinedType = [&](ValueRange values, OpBuilder &b) { for (Value v : values) { @@ -1336,7 +1336,7 @@ class RefineTypesPass : public RefineTypesBase { }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::Torch::createRefineTypesPass() { return std::make_unique(); } diff --git a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp index 2214aa320..aae54cb5f 100644 --- a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp @@ -169,7 +169,7 @@ static Value adjustShapeFunctionArg(Value operand, Type desiredType, // from the shape library. static LogicalResult populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands, - mlir::FuncOp shapeFunction) { + mlir::func::FuncOp shapeFunction) { // Create a call to the shape function in the `shapeCalculation` region. // We will import the callee from the shape library later. OpBuilder b(op.getContext()); @@ -241,7 +241,7 @@ class ReifyShapeCalculationsPass name = name.drop_front(strlen("valsem.")); auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str(); auto shapeFunction = - shapeLibrary->lookupSymbol(shapeFunctionName); + shapeLibrary->lookupSymbol(shapeFunctionName); if (!shapeFunction) return; neededShapeFunctions.push_back(shapeFunctionName); @@ -276,7 +276,7 @@ class ReifyShapeCalculationsPass auto symName = worklist.pop_back_val(); if (importedFunctions.count(symName)) continue; - auto func = shapeLibrary->lookupSymbol(symName); + auto func = shapeLibrary->lookupSymbol(symName); assert(func && "broken shape library"); // Move the shape function from the library to the module this pass // is running on. 
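RefinePublicReturn switches from func.getType() to func.getFunctionType(), the accessor used now that the signature is reached through FunctionOpInterface. A sketch of rewriting a function's result types with the new accessor; the helper is hypothetical and omits updating the return op, which the real pass does first:

#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;

// Replace a function's result types in place; the signature is now read
// through getFunctionType() instead of the old getType().
static void setResultTypes(func::FuncOp func, TypeRange newResultTypes) {
  FunctionType oldType = func.getFunctionType();
  func.setType(FunctionType::get(oldType.getContext(), oldType.getInputs(),
                                 newResultTypes));
}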
(this mutates the library, but we re-parse it each time diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 91ed9faee..d3fb3ed99 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -415,7 +415,7 @@ class SimplifyShapeCalculationsPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::Torch::createSimplifyShapeCalculationsPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index e428e7624..3794602a8 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -45,9 +45,10 @@ struct FuncBackendTypeConversionPass typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp([&](FuncOp op) { - return typeConverter.isSignatureLegal(op.getType()) && + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && typeConverter.isLegal(&op.getBody()); }); populateCallOpTypeConversionPattern(patterns, typeConverter); @@ -155,7 +156,7 @@ struct FinalizingBackendTypeConversionPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/PassDetail.h b/lib/Dialect/TorchConversion/Transforms/PassDetail.h index c5a0da964..224ad8e2d 100644 --- a/lib/Dialect/TorchConversion/Transforms/PassDetail.h +++ b/lib/Dialect/TorchConversion/Transforms/PassDetail.h @@ -10,9 +10,12 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { +class ModuleOp; + namespace torch { namespace TorchConversion { diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 208c59afc..b11e7f630 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -59,26 +59,27 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( // We do this first as it tends to involve pattern-matching against constants, // (e.g. dimensions which must be constant in a ranked programming model) // and those constants get somewhat obscured by TorchToStd. - pm.addNestedPass(createConvertTorchToTMTensorPass()); - pm.addNestedPass(createConvertTorchToLinalgPass()); - pm.addNestedPass(createConvertTorchToStdPass()); - pm.addNestedPass(createConvertTorchToSCFPass()); - pm.addNestedPass(memref::createExpandOpsPass()); + pm.addNestedPass(createConvertTorchToTMTensorPass()); + pm.addNestedPass(createConvertTorchToLinalgPass()); + pm.addNestedPass(createConvertTorchToStdPass()); + pm.addNestedPass(createConvertTorchToSCFPass()); + pm.addNestedPass(memref::createExpandOpsPass()); if (options.optimize) { // Clean up any non-canonical code introduced above.. 
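FuncBackendTypeConversionPass re-anchors the signature-conversion pattern and the dynamic legality check on func::FuncOp, and reads the signature via getFunctionType(). A condensed sketch of that setup, with the type-converter configuration elided:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Convert function signatures and mark a func.func legal only once its
// signature and body use converted types.
static void configureFuncTypeConversion(ConversionTarget &target,
                                        TypeConverter &typeConverter,
                                        RewritePatternSet &patterns) {
  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                 typeConverter);
  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
    return typeConverter.isSignatureLegal(op.getFunctionType()) &&
           typeConverter.isLegal(&op.getBody());
  });
}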
- pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCanonicalizerPass()); // Resolve `dim` ops on tensors (which currently live in the `memref` // dialect for some reason -- we don't have memrefs at this level). - pm.addNestedPass(memref::createResolveShapedTypeResultDimsPass()); + pm.addNestedPass( + memref::createResolveShapedTypeResultDimsPass()); // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); + pm.addNestedPass(createCSEPass()); } // Finish the type conversion from `torch` types to the types of the // linalg-on-tensors backend contract. pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); - pm.addNestedPass( + pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); // Verify that we have lowered to the form that linalg on tensors backends @@ -93,21 +94,21 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( pm.addPass( TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); - pm.addNestedPass(createConvertTorchToTosaPass()); + pm.addNestedPass(createConvertTorchToTosaPass()); // Perform rank broadcasting so TosaToLinalg pass works - pm.addNestedPass(createTosaMakeBroadcastablePass()); + pm.addNestedPass(createTosaMakeBroadcastablePass()); if (options.optimize) { // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCanonicalizerPass()); // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); + pm.addNestedPass(createCSEPass()); } // Finish the type conversion from `torch` types to the types of the // TOSA backend contract. pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); - pm.addNestedPass( + pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); // Verify that we have lowered to the form that TOSA backends diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index ecfc1858c..75e7981c8 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -9,6 +9,8 @@ #include "PassDetail.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -60,7 +62,7 @@ class VerifyLinalgOnTensorsBackendContractPass ConversionTarget target(*context); // Structural operations. - target.addDynamicallyLegalOp( + target.addDynamicallyLegalOp( opHasLegalTypes); target.addDynamicallyLegalOp(opHasLegalTypes); diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp index 16a5a57a7..e86948ebb 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp @@ -9,6 +9,7 @@ #include "PassDetail.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -39,7 +40,7 @@ class VerifyTosaBackendContractPass ConversionTarget target(*context); // Structural operations. 
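The backend-contract verifiers follow the same move: the structural ops they check are the func-dialect ones, and the newly included Arithmetic header matches a dialect the verifier treats as legal. A rough sketch of such a verification target; the exact op and dialect lists are illustrative, not copied from the pass:

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Placeholder predicate; the real verifiers check operand/result types
// against the backend contract.
static bool opHasLegalTypes(Operation *op) { return true; }

// Structural ops are checked per-instance; whole dialects that belong to the
// contract are marked legal outright.
static void buildContractTarget(ConversionTarget &target) {
  target.addDynamicallyLegalOp<func::FuncOp, func::ReturnOp, func::CallOp>(
      opHasLegalTypes);
  target.addLegalDialect<arith::ArithmeticDialect, linalg::LinalgDialect,
                         tensor::TensorDialect>();
}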
- target.addDynamicallyLegalOp( + target.addDynamicallyLegalOp( opHasLegalTypes); // Basic scalar operations. target.addLegalDialect(); diff --git a/lib/RefBackend/PassDetail.h b/lib/RefBackend/PassDetail.h index 630e0868b..aad2c3691 100644 --- a/lib/RefBackend/PassDetail.h +++ b/lib/RefBackend/PassDetail.h @@ -10,6 +10,7 @@ #ifndef REFBACKEND_PASSDETAIL_H #define REFBACKEND_PASSDETAIL_H +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 21366655d..313aeab72 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -15,7 +15,7 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" @@ -68,7 +68,7 @@ static bool isArgMemRefTypeValid(Type type) { return false; } -static void addEmitCInterfaceAttr(FuncOp func) { +static void addEmitCInterfaceAttr(func::FuncOp func) { func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext())); } @@ -115,7 +115,7 @@ static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op, } static LogicalResult mungeFunction( - FuncOp func, + func::FuncOp func, std::map> &invokedConsumeFuncReturnFuncs) { // Only need to call mungeFunction for functions callable from outside of the // module. @@ -188,17 +188,17 @@ class MungeCallingConventions auto module = getOperation(); OpBuilder b(module.getBodyRegion()); std::map> invokedConsumeFuncReturnFuncs; - for (auto func : module.getOps()) { + for (auto func : module.getOps()) { if (failed(mungeFunction(func, invokedConsumeFuncReturnFuncs))) return signalPassFailure(); } // Create FuncOp for consumeFuncReturnFuncs that are used. 
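The RefBackend changes only re-type the walk over the module's functions and the helper that tags them for C-interface emission. A small sketch of that shape, with pass boilerplate omitted:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Tag every public function so the LLVM lowering emits a C-compatible
// wrapper for it.
static void markPublicFuncsForCInterface(ModuleOp module) {
  for (auto func : module.getOps<func::FuncOp>()) {
    if (func.getVisibility() != SymbolTable::Visibility::Public)
      continue;
    func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
  }
}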
for (auto &p : invokedConsumeFuncReturnFuncs) { - auto consumeFuncReturnFunc = - b.create(module.getLoc(), p.first, - FunctionType::get(module.getContext(), p.second, {}), - b.getStringAttr("private")); + auto consumeFuncReturnFunc = b.create( + module.getLoc(), p.first, + FunctionType::get(module.getContext(), p.second, {}), + b.getStringAttr("private")); addEmitCInterfaceAttr(consumeFuncReturnFunc); } } @@ -309,7 +309,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase { }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::RefBackend::createExpandOpsForLLVMPass() { return std::make_unique(); } @@ -366,7 +366,7 @@ class MungeMemrefCopy : public MungeMemrefCopyBase { }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::RefBackend::createMungeMemrefCopyPass() { return std::make_unique(); } @@ -390,7 +390,7 @@ class GeneralizeTensorPad }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::RefBackend::createGeneralizeTensorPadPass() { return std::make_unique(); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp index 753c31efe..6abebc2c8 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp @@ -36,8 +36,8 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp( MlirAttribute symNameAttr = mlirStringAttrGet( context, toMlirStringRef(function->qualname().qualifiedName())); MlirOperation func = createMlirOperation( - "builtin.func", loc, mlirRegionCreate(), - toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)), + "func.func", loc, mlirRegionCreate(), + toMlirNamedAttribute("function_type", mlirTypeAttrGet(functionType)), toMlirNamedAttribute("sym_name", symNameAttr)); std::vector argAttrDicts; for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) { diff --git a/python/torch_mlir/eager_mode/ir_building.py b/python/torch_mlir/eager_mode/ir_building.py index 31db5c3e5..aa7059faa 100644 --- a/python/torch_mlir/eager_mode/ir_building.py +++ b/python/torch_mlir/eager_mode/ir_building.py @@ -29,7 +29,7 @@ import torch from torch.jit import ScriptFunction from torch_mlir import ir -from torch_mlir.dialects.builtin import FuncOp +from torch_mlir.dialects.func import FuncOp from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 01b91807e..39f20e04f 100644 --- a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -115,15 +115,15 @@ class RefBackendInvoker: LOWERING_PIPELINE = ",".join([ - "builtin.func(refback-generalize-tensor-pad)", + "func.func(refback-generalize-tensor-pad)", # Bufferize. - "builtin.func(scf-bufferize)", - "builtin.func(tm-tensor-bufferize)", - "builtin.func(linalg-bufferize)", + "func.func(scf-bufferize)", + "func.func(tm-tensor-bufferize)", + "func.func(linalg-bufferize)", "func-bufferize", "arith-bufferize", - "builtin.func(tensor-bufferize)", - "builtin.func(finalizing-bufferize)", + "func.func(tensor-bufferize)", + "func.func(finalizing-bufferize)", # Munge to make it ExecutionEngine compatible. 
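Runtime-support declarations are now built as func.func ops, and the JIT importer creates the raw operation under its new name with a function_type attribute. A sketch of declaring an external helper with the new builder; the names are placeholders:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Declare a private, body-less helper function, e.g. a runtime-support
// routine that lowered code will call.
static func::FuncOp declareRuntimeHelper(OpBuilder &b, ModuleOp module,
                                         StringRef name, TypeRange argTypes) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointToStart(module.getBody());
  auto fnType = FunctionType::get(module.getContext(), argTypes, {});
  auto fn = b.create<func::FuncOp>(module.getLoc(), name, fnType);
  fn.setPrivate();
  return fn;
}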
# Specifically, we rewrite calling convention boundaries to be in terms # of unranked memref, and we rewrite the return to actually be a @@ -135,17 +135,17 @@ LOWERING_PIPELINE = ",".join([ # global seed used in stateful rng. "refback-insert-rng-globals", # Lower to LLVM - "builtin.func(tm-tensor-to-loops)", - "builtin.func(refback-munge-memref-copy)", - "builtin.func(convert-linalg-to-loops)", - "builtin.func(lower-affine)", + "func.func(tm-tensor-to-loops)", + "func.func(refback-munge-memref-copy)", + "func.func(convert-linalg-to-loops)", + "func.func(lower-affine)", "convert-scf-to-cf", - "builtin.func(refback-expand-ops-for-llvm)", - "builtin.func(arith-expand)", - "builtin.func(convert-math-to-llvm)", + "func.func(refback-expand-ops-for-llvm)", + "func.func(arith-expand)", + "func.func(convert-math-to-llvm)", "convert-linalg-to-llvm", "convert-memref-to-llvm", - "builtin.func(convert-arith-to-llvm)", + "func.func(convert-arith-to-llvm)", "convert-func-to-llvm", "convert-cf-to-llvm", "reconcile-unrealized-casts", diff --git a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py index e26ccd34f..eb98371fb 100644 --- a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py +++ b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py @@ -38,25 +38,25 @@ class LinalgOnTensorsTosaBackend(TosaBackend): """ # TOSA legalization may emit tosa.const() ops. These are legalized - # by tosa-to-standard to arith.constants. This mechanical transformation + # by tosa-to-arith to arith.constants. This mechanical transformation # must be done prior to TOSA-to-LinAlg so that the latter does not fail. # This is an artifact of legalizations spread across a collection of simple # ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg, # that depend on TOSA as well as TOSA-to-Standard. 
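Every nesting anchor in these textual pipelines changes from builtin.func(...) to func.func(...), since function bodies now belong to that op. The same spelling applies when parsing a pipeline from C++; a minimal sketch with example pass names:

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

using namespace mlir;

// Parse a textual pipeline whose function-scoped passes are anchored on
// func.func (formerly spelled builtin.func).
static LogicalResult buildTextualPipeline(PassManager &pm) {
  return parsePassPipeline("func.func(canonicalize,cse)", pm);
}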
run_pipeline_with_repro_report( imported_module, - "builtin.func(tosa-to-standard)", - "Lowering TOSA to Standard") + "func.func(tosa-to-arith)", + "Lowering TOSA to Arith") # Named ops must be legalized prior to general tosa-to-linalg run_pipeline_with_repro_report( imported_module, - "builtin.func(tosa-to-linalg-named)", + "func.func(tosa-to-linalg-named)", "Lowering TOSA to Linalg-on-Tensors for Named Ops") run_pipeline_with_repro_report( imported_module, - "builtin.func(tosa-to-linalg)", + "func.func(tosa-to-linalg)", "Lowering TOSA to Linalg-on-Tensors") return self.refbackend.compile(imported_module) diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index d7f3daa7c..be3575904 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -1,7 +1,7 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @forward -builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir index a67ba4d46..43734b6e8 100644 --- a/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir +++ b/test/Dialect/Torch/GlobalizeObjectGraph/error.mlir @@ -42,7 +42,7 @@ torch.nn_module { torch.slot "t1", %t : !torch.tensor torch.slot "t2", %t : !torch.tensor } : !torch.nn.Module<"c"> -builtin.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor { +func.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor { %t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor %t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor %cst = torch.constant.int 1 @@ -63,7 +63,7 @@ torch.nn_module { torch.slot "t1", %t : !torch.tensor torch.slot "t2", %t : !torch.tensor } : !torch.nn.Module<"c"> -builtin.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) { +func.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) { torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor return diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index a8401b498..d960713f2 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -4,8 +4,8 @@ torch.class_type @c {} %0 = torch.nn_module { - // expected-error @+1 {{'builtin.func' op is not allowed inside 'torch.nn_module'}} - builtin.func @f() + // expected-error @+1 {{'func.func' op is not allowed inside 'torch.nn_module'}} + func.func @f() } : !torch.nn.Module<"c"> // ----- @@ -32,8 +32,8 @@ torch.class_type @c { // ----- torch.class_type @c { - // expected-error @+1 {{'builtin.func' op is not allowed inside `torch.class_type`}} - builtin.func @f() + // expected-error @+1 {{'func.func' op is not allowed inside `torch.class_type`}} + func.func @f() } // ----- @@ -60,7 +60,7 @@ torch.class_type @c { torch.method "f", @f } -builtin.func @f(%arg0: !torch.nn.Module<"c">) { +func.func @f(%arg0: !torch.nn.Module<"c">) { return } @@ -71,11 +71,11 @@ torch.class_type @c { torch.method "f", @f } -builtin.func private @f(%arg0: !torch.nn.Module<"c">) 
+func.func private @f(%arg0: !torch.nn.Module<"c">) // ----- -builtin.func private @f() { +func.func private @f() { return } torch.class_type @c { @@ -85,7 +85,7 @@ torch.class_type @c { // ----- -builtin.func private @f(!torch.nn.Module<"other_c">) { +func.func private @f(%arg0: !torch.nn.Module<"other_c">) { return } torch.class_type @c { @@ -101,21 +101,21 @@ torch.class_type @c { // ----- // expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}} -builtin.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>}) +func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>}) // ----- // expected-error @+1 {{'torch.type_bound' must be TypeAttr}} -builtin.func @f(%arg0: i32 {torch.type_bound = 1}) +func.func @f(%arg0: i32 {torch.type_bound = 1}) // ----- // expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}} -builtin.func @f(%arg0: i32 {torch.type_bound = i32}) +func.func @f(%arg0: i32 {torch.type_bound = i32}) // ----- -builtin.func @derefine(%arg0: !torch.optional) -> !torch.tensor { +func.func @derefine(%arg0: !torch.optional) -> !torch.tensor { // expected-error @+1 {{operand type '!torch.optional' and result type '!torch.tensor' are cast incompatible}} %0 = torch.derefine %arg0 : !torch.optional to !torch.tensor return %0 : !torch.tensor @@ -123,7 +123,7 @@ builtin.func @derefine(%arg0: !torch.optional) -> !torch.tensor { // ----- -builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional { +func.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional { // expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional' are cast incompatible}} %0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional return %0 : !torch.optional @@ -132,11 +132,11 @@ builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> ! // ----- // expected-error @+1 {{invalid dtype 'tuple<>' for !torch.tensor type}} -builtin.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>> +func.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>> // ----- -builtin.func @torch.tensor() { +func.func @torch.tensor() { // Incompatible shape. // expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}} %0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32> @@ -145,7 +145,7 @@ builtin.func @torch.tensor() { // ----- -builtin.func @torch.tensor() { +func.func @torch.tensor() { // Incompatible dtype. // expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}} %0 = torch.tensor.literal(dense<42.0> : tensor) : !torch.vtensor<[],f64> @@ -154,7 +154,7 @@ builtin.func @torch.tensor() { // ----- -builtin.func @torch.tensor() { +func.func @torch.tensor() { // Incompatible type. 
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}} %0 = torch.tensor.literal(dense<42.0> : tensor) : i1 @@ -163,7 +163,7 @@ builtin.func @torch.tensor() { // ----- -builtin.func @torch.prim.ListConstruct() { +func.func @torch.prim.ListConstruct() { %int2 = torch.constant.int 2 // expected-error@+1 {{operand types should have the same type as the list contained type}} torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list @@ -172,7 +172,7 @@ builtin.func @torch.prim.ListConstruct() { // ----- -builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> { +func.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> { %0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32> // expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}} torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32> diff --git a/test/Dialect/Torch/promote-types.mlir b/test/Dialect/Torch/promote-types.mlir index 1e49447aa..5a2484a10 100644 --- a/test/Dialect/Torch/promote-types.mlir +++ b/test/Dialect/Torch/promote-types.mlir @@ -9,7 +9,7 @@ // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64> // CHECK: return -builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>, +func.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>, %t1: !torch.vtensor<[1],f64>, %alpha: !torch.float) { %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk> @@ -25,7 +25,7 @@ builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1 // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK-SAME: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64> // CHECK: return -builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>, +func.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>, %t1: !torch.vtensor<[1],f64>, %alpha: !torch.float) { %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk> @@ -41,7 +41,7 @@ builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>, // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],f32> // CHECK: return -builtin.func @tensor_tensor$same_category_zero_rank_wider( +func.func @tensor_tensor$same_category_zero_rank_wider( %t0: !torch.vtensor<[1],f32>, %t1: !torch.vtensor<[],f64>, %alpha: !torch.int) { @@ -58,7 +58,7 @@ builtin.func @tensor_tensor$same_category_zero_rank_wider( // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK-SAME: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32> // CHECK: return -builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>, +func.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>, %t1: 
!torch.vtensor<[],f32>,
%alpha: !torch.int) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk>
@@ -73,7 +73,7 @@ builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],f32>
// CHECK: return
-builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
+func.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[1],f32>,
%alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk>
@@ -89,7 +89,7 @@ builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return
-builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
+func.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk>
return
}
@@ -103,7 +103,7 @@ builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
// CHECK: return
-builtin.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
+func.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk>
return
}