llvm: bump tag to e1318078 (#781)

The updated LLVM code includes a patch to create bfloat16 array
attributes, thus enabling a different patch to torch-mlir to flesh out
support for the bfloat16 type.
pull/761/head snapshot-20220426.416
Ashay Rane 2022-04-26 12:27:51 -07:00 committed by GitHub
parent 9ec4712516
commit 9208bf0eb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
60 changed files with 377 additions and 300 deletions

View File

@ -463,7 +463,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
$_op->getAttrs()); $_op->getAttrs());
for (Region &r : $_op->getRegions()) for (Region &r : $_op->getRegions())
r.cloneInto(state.addRegion(), bvm); r.cloneInto(state.addRegion(), bvm);
return b.createOperation(state); return b.create(state);
}] }]
> >
]; ];

View File

@ -31,9 +31,7 @@ class TMTensor_Op<string mnemonic, list<Trait> traits = []> :
TMTensorInterface, TMTensorInterface,
SingleBlockImplicitTerminator<"::mlir::torch::TMTensor::YieldOp"> SingleBlockImplicitTerminator<"::mlir::torch::TMTensor::YieldOp">
])> { ])> {
let verifier = [{ return verify$cppClass(*this); }]; let hasVerifier = 1;
let printer = [{ return print$cppClass(p, *this); }];
let parser = [{ return parse$cppClass(parser, result); }];
code extraTMTensorOpClassDeclaration = [{ code extraTMTensorOpClassDeclaration = [{
SmallVector<Value> getDestinationOperands(OpBuilder &b) { SmallVector<Value> getDestinationOperands(OpBuilder &b) {
SmallVector<Value> dest(outputs().begin(), outputs().end()); SmallVector<Value> dest(outputs().begin(), outputs().end());

View File

@ -10,6 +10,7 @@
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_ #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_ #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {

View File

@ -10,14 +10,15 @@
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_ #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_ #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace torch { namespace torch {
namespace TMTensor { namespace TMTensor {
std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass(); std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorToLoopsPass();
std::unique_ptr<OperationPass<FuncOp>> createTMTensorBufferizePass(); std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorBufferizePass();
void registerPasses(); void registerPasses();

View File

@ -13,12 +13,12 @@
include "mlir/Pass/PassBase.td" include "mlir/Pass/PassBase.td"
def TMTensorToLoops : def TMTensorToLoops :
Pass<"tm-tensor-to-loops", "FuncOp"> { Pass<"tm-tensor-to-loops", "func::FuncOp"> {
let summary = "Convert TMTensor ops to loops and Linalg ops."; let summary = "Convert TMTensor ops to loops and Linalg ops.";
let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()"; let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
} }
def TMTensorBufferize : Pass<"tm-tensor-bufferize", "FuncOp"> { def TMTensorBufferize : Pass<"tm-tensor-bufferize", "func::FuncOp"> {
let summary = "Bufferize the TMTensor dialect"; let summary = "Bufferize the TMTensor dialect";
let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()"; let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()";
} }

View File

@ -88,34 +88,34 @@ OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
// ScanOp // ScanOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verifyScanOp(ScanOp op) { LogicalResult ScanOp::verify() {
if (op.getNumInputs() != 1) { if (getNumInputs() != 1) {
return op.emitOpError("expected one input operands"); return emitOpError("expected one input operands");
} }
if (op.getNumOutputs() != 2) { if (getNumOutputs() != 2) {
return op.emitOpError("expected two output operands"); return emitOpError("expected two output operands");
} }
if (!op.input().getType().isa<ShapedType>()) { if (!input().getType().isa<ShapedType>()) {
return op.emitOpError("expected first input element type to be shaped"); return emitOpError("expected first input element type to be shaped");
} }
auto accumulatorType = op.accumulator().getType().cast<ShapedType>(); auto accumulatorType = accumulator().getType().cast<ShapedType>();
auto inputType = op.input().getType().cast<ShapedType>(); auto inputType = input().getType().cast<ShapedType>();
auto outputType = op.output().getType().cast<ShapedType>(); auto outputType = output().getType().cast<ShapedType>();
ArrayRef<int64_t> inputShapes = inputType.getShape(); ArrayRef<int64_t> inputShapes = inputType.getShape();
ArrayRef<int64_t> outputShapes = outputType.getShape(); ArrayRef<int64_t> outputShapes = outputType.getShape();
if (accumulatorType.getElementType() != inputType.getElementType()) { if (accumulatorType.getElementType() != inputType.getElementType()) {
return op.emitOpError( return emitOpError(
"expected input/accumulator element types to be identical"); "expected input/accumulator element types to be identical");
} }
ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape(); ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
int64_t accumulatorRank = accumulatorType.getRank(); int64_t accumulatorRank = accumulatorType.getRank();
if (accumulatorRank != inputType.getRank() - 1) { if (accumulatorRank != inputType.getRank() - 1) {
return op.emitOpError( return emitOpError(
"expected accumulator rank to be equal to input rank - 1"); "expected accumulator rank to be equal to input rank - 1");
} }
SmallVector<int64_t> expectedAccumulatorShape; SmallVector<int64_t> expectedAccumulatorShape;
for (size_t i = 0; i < (size_t)inputType.getRank(); i++) { for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
if (i != op.dimension()) if (i != dimension())
expectedAccumulatorShape.push_back(inputShapes[i]); expectedAccumulatorShape.push_back(inputShapes[i]);
} }
if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape), if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
@ -124,14 +124,13 @@ static LogicalResult verifyScanOp(ScanOp op) {
std::get<1>(s) != ShapedType::kDynamicSize && std::get<1>(s) != ShapedType::kDynamicSize &&
std::get<0>(s) != std::get<1>(s); std::get<0>(s) != std::get<1>(s);
})) { })) {
return op.emitOpError("incompatible input/accumulator shapes"); return emitOpError("incompatible input/accumulator shapes");
} }
if (inputType.getElementType() != outputType.getElementType()) { if (inputType.getElementType() != outputType.getElementType()) {
return op.emitOpError( return emitOpError("expected input/output element types to be identical");
"expected input/output element types to be identical");
} }
if (inputShapes.size() != outputShapes.size()) { if (inputShapes.size() != outputShapes.size()) {
return op.emitOpError("expected input/output to have identical ranks"); return emitOpError("expected input/output to have identical ranks");
} }
if (llvm::any_of(llvm::zip(inputShapes, outputShapes), if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
[](std::tuple<int64_t, int64_t> s) { [](std::tuple<int64_t, int64_t> s) {
@ -139,7 +138,7 @@ static LogicalResult verifyScanOp(ScanOp op) {
std::get<1>(s) != ShapedType::kDynamicSize && std::get<1>(s) != ShapedType::kDynamicSize &&
std::get<0>(s) != std::get<1>(s); std::get<0>(s) != std::get<1>(s);
})) { })) {
return op.emitOpError("incompatible input/output shapes"); return emitOpError("incompatible input/output shapes");
} }
return success(); return success();
} }
@ -232,11 +231,11 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
}); });
auto &srcBlock = region().front(); auto &srcBlock = region().front();
Region &region = scfIf.getElseRegion(); Region &thisRegion = scfIf.getElseRegion();
BlockAndValueMapping bvm; BlockAndValueMapping bvm;
{ {
OpBuilder::InsertionGuard guard(b); OpBuilder::InsertionGuard guard(b);
auto &block = region.front(); auto &block = thisRegion.front();
b.setInsertionPointToEnd(&block); b.setInsertionPointToEnd(&block);
for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) { for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
bvm.map(std::get<0>(it), std::get<1>(it)); bvm.map(std::get<0>(it), std::get<1>(it));
@ -275,48 +274,47 @@ LogicalResult ScanOp::fold(ArrayRef<Attribute>,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ScatterOp // ScatterOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verifyScatterOp(ScatterOp op) { LogicalResult ScatterOp::verify() {
if (op.inputs().size() != 2) { if (inputs().size() != 2) {
return op.emitOpError("expected two input operands"); return emitOpError("expected two input operands");
} }
if (op.outputs().size() != 1) { if (outputs().size() != 1) {
return op.emitOpError("expected one output operand"); return emitOpError("expected one output operand");
} }
auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) { auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
return t1.getShape()[dim] == t2.getShape()[dim]; return t1.getShape()[dim] == t2.getShape()[dim];
}; };
auto indicesType = op.getIndicesType(); auto indicesType = getIndicesType();
if (indicesType.getRank() != 2 || if (indicesType.getRank() != 2 ||
!indicesType.getElementType().isInteger(32)) { !indicesType.getElementType().isInteger(32)) {
return op.emitOpError( return emitOpError("expected indices to be of rank 2 of i32 element type");
"expected indices to be of rank 2 of i32 element type");
} }
auto indexDepth = op.getIndexDepth(); auto indexDepth = getIndexDepth();
if (indexDepth == ShapedType::kDynamicSize) { if (indexDepth == ShapedType::kDynamicSize) {
return op.emitOpError("expected index depth is static"); return emitOpError("expected index depth is static");
} }
// The first dimension of the indices should match the first dimension of the // The first dimension of the indices should match the first dimension of the
// output. They indicate to the number of updates. // output. They indicate to the number of updates.
auto updateType = op.getUpdateType(); auto updateType = getUpdateType();
if (updateType.getRank() < 1) { if (updateType.getRank() < 1) {
return op.emitOpError("expected update value to be at least rank 1"); return emitOpError("expected update value to be at least rank 1");
} }
if (!checkDimensionsMatch(indicesType, updateType, 0)) { if (!checkDimensionsMatch(indicesType, updateType, 0)) {
return op.emitOpError( return emitOpError(
"mismatch in shape of indices and update value at dim#0"); "mismatch in shape of indices and update value at dim#0");
} }
auto originalType = op.getOriginalType(); auto originalType = getOriginalType();
if (updateType.getRank() - 1 > originalType.getRank()) { if (updateType.getRank() - 1 > originalType.getRank()) {
return op.emitOpError( return emitOpError(
"update value rank exceeds the rank of the original value"); "update value rank exceeds the rank of the original value");
} }
// indexDepth + update dims should cover the original dims. The first dim of // indexDepth + update dims should cover the original dims. The first dim of
// update is the number of updates. // update is the number of updates.
if (originalType.getRank() > indexDepth + updateType.getRank() - 1) { if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
return op.emitOpError( return emitOpError(
"index depth and update value does not cover rank of original value"); "index depth and update value does not cover rank of original value");
} }
@ -331,7 +329,7 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
int64_t updateDim = std::get<1>(it); int64_t updateDim = std::get<1>(it);
if (updateType.getDimSize(updateDim) != if (updateType.getDimSize(updateDim) !=
originalType.getDimSize(originalDim)) { originalType.getDimSize(originalDim)) {
return op.emitOpError("mismatch in shape of update value dim#") return emitOpError("mismatch in shape of update value dim#")
<< updateDim << " and original value at dim#" << originalDim; << updateDim << " and original value at dim#" << originalDim;
} }
} }
@ -345,36 +343,36 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
int64_t updateDim = std::get<1>(it); int64_t updateDim = std::get<1>(it);
if (updateType.getDimSize(updateDim) > if (updateType.getDimSize(updateDim) >
originalType.getDimSize(originalDim)) { originalType.getDimSize(originalDim)) {
return op.emitOpError("indexed shape of update value dim#") return emitOpError("indexed shape of update value dim#")
<< updateDim << " exceeds original value at dim#" << originalDim << updateDim << " exceeds original value at dim#" << originalDim
<< " " << updateType.getDimSize(updateDim) << " " << " " << updateType.getDimSize(updateDim) << " "
<< originalType.getDimSize(originalDim); << originalType.getDimSize(originalDim);
} }
} }
Region &region = op.region(); Region &thisRegion = region();
Block *body = &region.front(); Block *body = &thisRegion.front();
if (body->getNumArguments() != 2) { if (body->getNumArguments() != 2) {
return op.emitOpError("expected region to have two arguments"); return emitOpError("expected region to have two arguments");
} }
Type arg0Type = body->getArgument(0).getType(); Type arg0Type = body->getArgument(0).getType();
Type arg1Type = body->getArgument(1).getType(); Type arg1Type = body->getArgument(1).getType();
if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) { if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
return op.emitOpError( return emitOpError(
"expected region to have scalar argument of integer or float types"); "expected region to have scalar argument of integer or float types");
} }
if (arg0Type != updateType.getElementType()) { if (arg0Type != updateType.getElementType()) {
return op.emitOpError("mismatch in argument 0 of region ") return emitOpError("mismatch in argument 0 of region ")
<< arg0Type << " and element type of update value " << arg0Type << " and element type of update value "
<< updateType.getElementType(); << updateType.getElementType();
} }
if (arg1Type != originalType.getElementType()) { if (arg1Type != originalType.getElementType()) {
return op.emitOpError("mismatch in argument 1 of region ") return emitOpError("mismatch in argument 1 of region ")
<< arg1Type << " and element type of original value " << arg1Type << " and element type of original value "
<< originalType.getElementType(); << originalType.getElementType();
} }
if (arg0Type != arg1Type) { if (arg0Type != arg1Type) {
return op.emitOpError("mismatch in region argument types ") return emitOpError("mismatch in region argument types ")
<< arg0Type << " and " << arg1Type; << arg0Type << " and " << arg1Type;
} }
auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator()); auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator());
@ -455,7 +453,8 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices); Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx); Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
if (starts[i]) cast = b.create<arith::AddIOp>(loc, cast, starts[i]); if (starts[i])
cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
starts[i] = cast; starts[i] = cast;
} }

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinDialect.h"
@ -150,7 +151,7 @@ struct TMTensorBufferizePass
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
torch::TMTensor::createTMTensorBufferizePass() { torch::TMTensor::createTMTensorBufferizePass() {
return std::make_unique<TMTensorBufferizePass>(); return std::make_unique<TMTensorBufferizePass>();
} }

View File

@ -7,6 +7,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
@ -111,7 +112,7 @@ struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
torch::TMTensor::createTMTensorToLoopsPass() { torch::TMTensor::createTMTensorToLoopsPass() {
return std::make_unique<TMTensorToLoopsPass>(); return std::make_unique<TMTensorToLoopsPass>();
} }

@ -1 +1 @@
Subproject commit 8361c5da30588d3d4a48eae648f53be1feb5cfad Subproject commit e1318078a4e160eb723bcbcfcdcc9a1b618f7067

View File

@ -16,17 +16,17 @@ include "mlir/Pass/PassBase.td"
// Torch conversions // Torch conversions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ConvertTorchToStd : Pass<"convert-torch-to-std", "FuncOp"> { def ConvertTorchToStd : Pass<"convert-torch-to-std", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Std ops"; let summary = "Convert recognized Torch ops to Std ops";
let constructor = "mlir::torch::createConvertTorchToStdPass()"; let constructor = "mlir::torch::createConvertTorchToStdPass()";
} }
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "FuncOp"> { def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to SCF ops"; let summary = "Convert recognized Torch ops to SCF ops";
let constructor = "mlir::torch::createConvertTorchToSCFPass()"; let constructor = "mlir::torch::createConvertTorchToSCFPass()";
} }
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> { def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Linalg ops"; let summary = "Convert recognized Torch ops to Linalg ops";
let description = [{ let description = [{
Convert ATen ops to linalg ops. Convert ATen ops to linalg ops.
@ -105,7 +105,7 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
} }
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> { def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops"; let summary = "Convert Torch ops to TOSA ops";
let description = [{ let description = [{
This pass assumes that TOSA ops are responsible for emitting error This pass assumes that TOSA ops are responsible for emitting error
@ -114,7 +114,7 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTosaPass()"; let constructor = "mlir::torch::createConvertTorchToTosaPass()";
} }
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "FuncOp"> { def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops"; let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
let description = [{ let description = [{
Convert ATen ops to tmtensor/linalg ops. Convert ATen ops to tmtensor/linalg ops.

View File

@ -10,12 +10,13 @@
#ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H #ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H #define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory> #include <memory>
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToLinalgPass(); std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
} }
} // namespace mlir } // namespace mlir

View File

@ -10,11 +10,12 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H #ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
#define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H #define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToSCFPass(); std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();
} }
} // namespace mlir } // namespace mlir

View File

@ -10,12 +10,13 @@
#ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H #ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
#define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H #define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory> #include <memory>
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToStdPass(); std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToStdPass();
} }
} // namespace mlir } // namespace mlir

View File

@ -10,11 +10,12 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H #ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
#define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H #define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTMTensorPass(); std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
} }
} // namespace mlir } // namespace mlir

View File

@ -10,12 +10,13 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H #ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory> #include <memory>
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTosaPass(); std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
} }
} // namespace mlir } // namespace mlir

View File

@ -10,6 +10,8 @@
#ifndef TORCH_TYPES #ifndef TORCH_TYPES
#define TORCH_TYPES #define TORCH_TYPES
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/DialectBase.td"
include "torch-mlir/Dialect/Torch/IR/TorchBase.td" include "torch-mlir/Dialect/Torch/IR/TorchBase.td"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -24,28 +26,8 @@ class Torch_Type<string name, string typeMnemonic,
class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type<name, typeMnemonic> { class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type<name, typeMnemonic> {
let parameters = (ins "::mlir::Type":$containedType); let parameters = (ins "::mlir::Type":$containedType);
let hasCustomAssemblyFormat = 1;
let printer = [{
$_printer << "<";
// Print the contained type without the `!torch.` prefix.
printTorchDialectType(getImpl()->containedType, $_printer);
$_printer << ">";
}];
let parser = [{
if ($_parser.parseLess())
return Type();
// Parse the contained type, but forward directly to our internal parsing
// of `torch` dialect types, so that we can parse nested types without
// the `!torch.` prefix.
Type containedType = parseTorchDialectType($_parser);
if (!containedType)
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, containedType);
}];
let builders = [ let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{ TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
return Base::get(containedType.getContext(), containedType); return Base::get(containedType.getContext(), containedType);
@ -59,23 +41,7 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
Represents an instance of a `torch.nn.Module` with the given `className`. Represents an instance of a `torch.nn.Module` with the given `className`.
}]; }];
let parameters = (ins StringRefParameter<"class name">:$className); let parameters = (ins StringRefParameter<"class name">:$className);
let hasCustomAssemblyFormat = 1;
let printer = [{
$_printer << "<\"";
llvm::printEscapedString(getImpl()->className, $_printer.getStream());
$_printer << "\">";
}];
let parser = [{
if ($_parser.parseLess())
return Type();
std::string className;
if ($_parser.parseOptionalString(&className))
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, className);
}];
} }
// For standard ArrayRefs, which require allocation. // For standard ArrayRefs, which require allocation.
@ -186,6 +152,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
"::mlir::Type":$optionalDtype "::mlir::Type":$optionalDtype
); );
let genVerifyDecl = 1; let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
string extraBaseClassDeclaration = [{ string extraBaseClassDeclaration = [{
}]; }];
} }
@ -243,6 +210,7 @@ def Torch_TupleType : Torch_Type<"Tuple", "tuple"> {
let parameters = (ins let parameters = (ins
ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
); );
let hasCustomAssemblyFormat = 1;
} }
def Torch_UnionType : Torch_Type<"Union", "union"> { def Torch_UnionType : Torch_Type<"Union", "union"> {
@ -261,6 +229,7 @@ def Torch_UnionType : Torch_Type<"Union", "union"> {
let parameters = (ins let parameters = (ins
ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
); );
let hasCustomAssemblyFormat = 1;
} }
def Torch_DeviceType : Torch_Type<"Device", "Device"> { def Torch_DeviceType : Torch_Type<"Device", "Device"> {
@ -367,30 +336,7 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
let description = [{ let description = [{
Torch Dict type with key and value type. Torch Dict type with key and value type.
}]; }];
let hasCustomAssemblyFormat = 1;
let printer = [{
$_printer << "<";
printTorchDialectType(getImpl()->keyType, $_printer);
$_printer << ", ";
printTorchDialectType(getImpl()->valueType, $_printer);
$_printer<< ">";
}];
let parser = [{
if ($_parser.parseLess())
return Type();
Type keyType = parseTorchDialectType($_parser);
if (!keyType)
return Type();
if ($_parser.parseComma())
return Type();
Type valueType = parseTorchDialectType($_parser);
if (!valueType)
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, keyType, valueType);
}];
let builders = [ let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType, TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
"::mlir::Type":$valueType), [{ "::mlir::Type":$valueType), [{

View File

@ -10,11 +10,14 @@
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H #ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H #define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory> #include <memory>
namespace mlir { namespace mlir {
class ModuleOp;
namespace torch { namespace torch {
namespace Torch { namespace Torch {
@ -48,25 +51,26 @@ void createTorchShapeRefinementPipeline(
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass(); std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
std::unique_ptr<OperationPass<FuncOp>> createRefineTypesPass(); std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass(); std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
std::unique_ptr<OperationPass<FuncOp>> createReduceOpVariantsPass(); std::unique_ptr<OperationPass<func::FuncOp>> createReduceOpVariantsPass();
std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass(); std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass(); std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<FuncOp>> createDecomposeComplexOpsPass(); std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeComplexOpsPass();
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass(); std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass(); std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
std::unique_ptr<OperationPass<FuncOp>> createSimplifyShapeCalculationsPass(); std::unique_ptr<OperationPass<func::FuncOp>>
createSimplifyShapeCalculationsPass();
std::unique_ptr<OperationPass<FuncOp>> createDropShapeCalculationsPass(); std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
StringRef getShapeLibrary(); StringRef getShapeLibrary();

View File

@ -126,7 +126,7 @@ def AdjustCallingConventions
}]; }];
} }
def RefineTypes : Pass<"torch-refine-types", "FuncOp"> { def RefineTypes : Pass<"torch-refine-types", "func::FuncOp"> {
let summary = "Refine types"; let summary = "Refine types";
let constructor = "mlir::torch::Torch::createRefineTypesPass()"; let constructor = "mlir::torch::Torch::createRefineTypesPass()";
let description = [{ let description = [{
@ -149,7 +149,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
}]; }];
} }
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> { def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
let summary = "Reduces variants of ops to a smaller set of ops."; let summary = "Reduces variants of ops to a smaller set of ops.";
let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()"; let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
let description = [{ let description = [{
@ -165,7 +165,7 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
}]; }];
} }
def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> { def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "func::FuncOp"> {
let summary = "Use value-semantic tensors where possible."; let summary = "Use value-semantic tensors where possible.";
let description = [{ let description = [{
Use value-semantic tensors where possible to make the program more Use value-semantic tensors where possible to make the program more
@ -215,7 +215,7 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
}]; }];
} }
def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> { def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
let summary = "Decompose complicated torch operations"; let summary = "Decompose complicated torch operations";
let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()"; let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
let description = [{ let description = [{
@ -238,7 +238,7 @@ def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp">
}]; }];
} }
def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncOp"> { def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func::FuncOp"> {
let summary = "Simplify reified shape calculations."; let summary = "Simplify reified shape calculations.";
let constructor = "mlir::torch::Torch::createSimplifyShapeCalculationsPass()"; let constructor = "mlir::torch::Torch::createSimplifyShapeCalculationsPass()";
let description = [{ let description = [{
@ -246,7 +246,7 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncO
}]; }];
} }
def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "FuncOp"> { def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"> {
let summary = "Drop reified shape calculations."; let summary = "Drop reified shape calculations.";
let constructor = "mlir::torch::Torch::createDropShapeCalculationsPass()"; let constructor = "mlir::torch::Torch::createDropShapeCalculationsPass()";
let description = [{ let description = [{

View File

@ -10,12 +10,15 @@
#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include <memory> #include <memory>
namespace mlir { namespace mlir {
class ModuleOp;
namespace torch { namespace torch {
namespace TorchConversion { namespace TorchConversion {
@ -36,7 +39,7 @@ createVerifyInvariantsBeforeBackendLoweringPass();
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass(); std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
createFinalizingBackendTypeConversionPass(); createFinalizingBackendTypeConversionPass();
std::unique_ptr<OperationPass<ModuleOp>> std::unique_ptr<OperationPass<ModuleOp>>

View File

@ -43,7 +43,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu
} }
def FinalizingBackendTypeConversion def FinalizingBackendTypeConversion
: Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> { : Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> {
let summary = "Finalizes a partial conversion to builtin tensors"; let summary = "Finalizes a partial conversion to builtin tensors";
let constructor = let constructor =
"mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()"; "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()";

View File

@ -10,10 +10,13 @@
#ifndef TORCHMLIR_REFBACKEND_PASSES_H #ifndef TORCHMLIR_REFBACKEND_PASSES_H
#define TORCHMLIR_REFBACKEND_PASSES_H #define TORCHMLIR_REFBACKEND_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
namespace mlir { namespace mlir {
class ModuleOp;
namespace torch { namespace torch {
namespace RefBackend { namespace RefBackend {
@ -22,13 +25,13 @@ void registerRefBackendPasses();
std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass(); std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass(); std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass(); std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
std::unique_ptr<OperationPass<FuncOp>> createMungeMemrefCopyPass(); std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
std::unique_ptr<OperationPass<FuncOp>> createGeneralizeTensorPadPass(); std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
} // namespace RefBackend } // namespace RefBackend
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -24,18 +24,18 @@ def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> {
let dependentDialects = ["memref::MemRefDialect"]; let dependentDialects = ["memref::MemRefDialect"];
} }
def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "FuncOp"> { def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> {
let summary = "Expand ops into more primitive ops before LLVM lowering."; let summary = "Expand ops into more primitive ops before LLVM lowering.";
let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();"; let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
} }
def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "FuncOp"> { def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
let summary = "Munge memref.copy to linalg.copy"; let summary = "Munge memref.copy to linalg.copy";
let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();"; let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
let dependentDialects = ["memref::MemRefDialect"]; let dependentDialects = ["memref::MemRefDialect"];
} }
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "FuncOp"> { def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
let summary = "Convert tensor.pad to linalg ops"; let summary = "Convert tensor.pad to linalg ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()"; let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
} }

View File

@ -10,6 +10,7 @@
#ifndef TORCHMLIR_CONVERSION_PASSDETAIL_H #ifndef TORCHMLIR_CONVERSION_PASSDETAIL_H
#define TORCHMLIR_CONVERSION_PASSDETAIL_H #define TORCHMLIR_CONVERSION_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {

View File

@ -88,7 +88,7 @@ public:
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToLinalgPass() { mlir::torch::createConvertTorchToLinalgPass() {
return std::make_unique<ConvertTorchToLinalg>(); return std::make_unique<ConvertTorchToLinalg>();
} }

View File

@ -91,7 +91,7 @@ public:
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToSCFPass() { mlir::torch::createConvertTorchToSCFPass() {
return std::make_unique<ConvertTorchToSCF>(); return std::make_unique<ConvertTorchToSCF>();
} }

View File

@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h" #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -221,7 +222,7 @@ public:
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToStdPass() { mlir::torch::createConvertTorchToStdPass() {
return std::make_unique<ConvertTorchToStd>(); return std::make_unique<ConvertTorchToStd>();
} }

View File

@ -10,8 +10,10 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "../PassDetail.h" #include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
@ -566,7 +568,7 @@ public:
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTMTensorPass() { mlir::torch::createConvertTorchToTMTensorPass() {
return std::make_unique<ConvertTorchToTMTensor>(); return std::make_unique<ConvertTorchToTMTensor>();
} }

View File

@ -3170,7 +3170,7 @@ public:
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTosaPass() { mlir::torch::createConvertTorchToTosaPass() {
return std::make_unique<ConvertTorchToTosa>(); return std::make_unique<ConvertTorchToTosa>();
} }

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

View File

@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/InliningUtils.h"
@ -124,7 +125,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
unsigned argIndex, unsigned argIndex,
NamedAttribute namedAttr) { NamedAttribute namedAttr) {
if (namedAttr.getName().getValue() == "torch.type_bound") { if (namedAttr.getName().getValue() == "torch.type_bound") {
auto func = dyn_cast<FuncOp>(op); auto func = dyn_cast<func::FuncOp>(op);
if (!func) if (!func)
return op->emitError() << "'torch.type_bound' must be attached to a func"; return op->emitError() << "'torch.type_bound' must be attached to a func";
TypeAttr attr = namedAttr.getValue().dyn_cast<TypeAttr>(); TypeAttr attr = namedAttr.getValue().dyn_cast<TypeAttr>();
@ -134,7 +135,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
if (!type) if (!type)
return op->emitError() << "'torch.type_bound' must be of " return op->emitError() << "'torch.type_bound' must be of "
"!torch.tensor/!torch.vtensor type"; "!torch.tensor/!torch.vtensor type";
if (!func.getType().getInput(argIndex).isa<BaseTensorType>()) if (!func.getFunctionType().getInput(argIndex).isa<BaseTensorType>())
return op->emitError() << "'torch.type_bound' must be attached to an " return op->emitError() << "'torch.type_bound' must be attached to an "
"argument of !torch.tensor/!torch.vtensor type"; "argument of !torch.tensor/!torch.vtensor type";
return success(); return success();
@ -177,3 +178,100 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// OptionalType and ListType
//===----------------------------------------------------------------------===//
void OptionalType::print(AsmPrinter &printer) const {
printer << "<";
// Print the contained type without the `!torch.` prefix.
printTorchDialectType(getImpl()->containedType, printer);
printer << ">";
}
void ListType::print(AsmPrinter &printer) const {
printer << "<";
// Print the contained type without the `!torch.` prefix.
printTorchDialectType(getImpl()->containedType, printer);
printer << ">";
}
Type OptionalType::parse(AsmParser &odsParser) {
if (odsParser.parseLess())
return Type();
// Parse the contained type, but forward directly to our internal parsing
// of `torch` dialect types, so that we can parse nested types without
// the `!torch.` prefix.
Type containedType = parseTorchDialectType(odsParser);
if (!containedType)
return Type();
if (odsParser.parseGreater())
return Type();
return get(odsParser.getContext(), containedType);
}
Type ListType::parse(AsmParser &odsParser) {
if (odsParser.parseLess())
return Type();
// Parse the contained type, but forward directly to our internal parsing
// of `torch` dialect types, so that we can parse nested types without
// the `!torch.` prefix.
Type containedType = parseTorchDialectType(odsParser);
if (!containedType)
return Type();
if (odsParser.parseGreater())
return Type();
return get(odsParser.getContext(), containedType);
}
//===----------------------------------------------------------------------===//
// DictType
//===----------------------------------------------------------------------===//
void DictType::print(AsmPrinter &printer) const {
printer << "<";
printTorchDialectType(getImpl()->keyType, printer);
printer << ", ";
printTorchDialectType(getImpl()->valueType, printer);
printer << ">";
}
Type DictType::parse(AsmParser &odsParser) {
if (odsParser.parseLess())
return Type();
Type keyType = parseTorchDialectType(odsParser);
if (!keyType)
return Type();
if (odsParser.parseComma())
return Type();
Type valueType = parseTorchDialectType(odsParser);
if (!valueType)
return Type();
if (odsParser.parseGreater())
return Type();
return get(odsParser.getContext(), keyType, valueType);
}
//===----------------------------------------------------------------------===//
// NnModuleType
//===----------------------------------------------------------------------===//
void NnModuleType::print(AsmPrinter &printer) const {
printer << "<\"";
llvm::printEscapedString(getImpl()->className, printer.getStream());
printer << "\">";
}
Type NnModuleType::parse(AsmParser &odsParser) {
if (odsParser.parseLess())
return Type();
std::string className;
if (odsParser.parseOptionalString(&className))
return Type();
if (odsParser.parseGreater())
return Type();
return get(odsParser.getContext(), className);
}

View File

@ -9,6 +9,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
@ -99,7 +100,7 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto func = auto func =
symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, functionAttr()); symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, functionAttr());
if (!func) if (!func)
return emitError() << "'@" << function() return emitError() << "'@" << function()
<< "' does not reference a valid function"; << "' does not reference a valid function";
@ -112,8 +113,8 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
"merely declared)"; "merely declared)";
auto expectedReceiverArgType = NnModuleType::get( auto expectedReceiverArgType = NnModuleType::get(
getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName()); getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
if (func.getType().getNumInputs() == 0 || if (func.getFunctionType().getNumInputs() == 0 ||
func.getType().getInput(0) != expectedReceiverArgType) { func.getFunctionType().getInput(0) != expectedReceiverArgType) {
return emitError() << "the referenced function '" << function() return emitError() << "the referenced function '" << function()
<< "' must have a first argument of type " << "' must have a first argument of type "
<< expectedReceiverArgType; << expectedReceiverArgType;
@ -278,7 +279,7 @@ ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
Region *elseRegion = result.addRegion(); Region *elseRegion = result.addRegion();
auto &builder = parser.getBuilder(); auto &builder = parser.getBuilder();
OpAsmParser::OperandType cond; OpAsmParser::UnresolvedOperand cond;
Type boolType = builder.getType<Torch::BoolType>(); Type boolType = builder.getType<Torch::BoolType>();
if (parser.parseOperand(cond) || if (parser.parseOperand(cond) ||
parser.resolveOperand(cond, boolType, result.operands)) parser.resolveOperand(cond, boolType, result.operands))

View File

@ -35,7 +35,7 @@ ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
OperationState &result, int numOperands, OperationState &result, int numOperands,
int numResults) { int numResults) {
llvm::SMLoc loc = parser.getCurrentLocation(); llvm::SMLoc loc = parser.getCurrentLocation();
SmallVector<OpAsmParser::OperandType> operands; SmallVector<OpAsmParser::UnresolvedOperand> operands;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/numOperands)) if (parser.parseOperandList(operands, /*requiredOperandCount=*/numOperands))
return failure(); return failure();
if (parser.parseOptionalAttrDict(result.attributes)) if (parser.parseOptionalAttrDict(result.attributes))

View File

@ -31,11 +31,12 @@ using namespace mlir::torch::Torch;
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>; using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>;
namespace { namespace {
class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> { class AdjustCallingConventionForFunc
: public OpConversionPattern<func::FuncOp> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
LogicalResult LogicalResult
matchAndRewrite(FuncOp func, OpAdaptor adaptor, matchAndRewrite(func::FuncOp func, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = func.getContext(); MLIRContext *context = func.getContext();
auto typeBoundIdent = StringAttr::get(context, "torch.type_bound"); auto typeBoundIdent = StringAttr::get(context, "torch.type_bound");
@ -70,7 +71,7 @@ public:
typeConverter); typeConverter);
SmallVector<Type> newResultTypes; SmallVector<Type> newResultTypes;
for (auto type : func.getType().getResults()) { for (auto type : func.getFunctionType().getResults()) {
if (auto none = type.dyn_cast<Torch::NoneType>()) { if (auto none = type.dyn_cast<Torch::NoneType>()) {
continue; continue;
} }
@ -186,7 +187,7 @@ public:
}; };
} // namespace } // namespace
static LogicalResult adjustCallingConventions(FuncOp func, static LogicalResult adjustCallingConventions(func::FuncOp func,
TypeBoundMap &typeBoundMap) { TypeBoundMap &typeBoundMap) {
MLIRContext *context = func.getContext(); MLIRContext *context = func.getContext();
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
@ -217,7 +218,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
patterns.add<AdjustCallingConventionForReturn>(typeConverter, context); patterns.add<AdjustCallingConventionForReturn>(typeConverter, context);
ConversionTarget target(*context); ConversionTarget target(*context);
target.addDynamicallyLegalOp<FuncOp>([](FuncOp func) { target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp func) {
for (int i = 0, e = func.getNumArguments(); i != e; i++) { for (int i = 0, e = func.getNumArguments(); i != e; i++) {
if (func.getArgAttr(i, "torch.type_bound")) if (func.getArgAttr(i, "torch.type_bound"))
return false; return false;
@ -225,7 +226,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
return false; return false;
} }
for (int i = 0, e = func.getNumResults(); i != e; i++) { for (int i = 0, e = func.getNumResults(); i != e; i++) {
if (func.getType().getResults()[i].isa<Torch::NoneType>()) if (func.getFunctionType().getResults()[i].isa<Torch::NoneType>())
return false; return false;
} }
return true; return true;
@ -266,7 +267,7 @@ class AdjustCallingConventionsPass
void runOnOperation() override { void runOnOperation() override {
auto module = getOperation(); auto module = getOperation();
TypeBoundMap typeBoundMap; TypeBoundMap typeBoundMap;
for (auto func : module.getOps<FuncOp>()) { for (auto func : module.getOps<func::FuncOp>()) {
for (int i = 0, e = func.getNumArguments(); i != e; i++) { for (int i = 0, e = func.getNumArguments(); i != e; i++) {
auto typeBoundAttr = auto typeBoundAttr =
func.getArgAttrOfType<TypeAttr>(i, "torch.type_bound"); func.getArgAttrOfType<TypeAttr>(i, "torch.type_bound");
@ -275,7 +276,7 @@ class AdjustCallingConventionsPass
typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue(); typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue();
} }
} }
for (auto func : module.getOps<FuncOp>()) { for (auto func : module.getOps<func::FuncOp>()) {
if (failed(adjustCallingConventions(func, typeBoundMap))) if (failed(adjustCallingConventions(func, typeBoundMap)))
return signalPassFailure(); return signalPassFailure();
} }

View File

@ -1781,7 +1781,7 @@ class DecomposeComplexOpsPass
} }
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createDecomposeComplexOpsPass() { mlir::torch::Torch::createDecomposeComplexOpsPass() {
return std::make_unique<DecomposeComplexOpsPass>(); return std::make_unique<DecomposeComplexOpsPass>();
} }

View File

@ -54,7 +54,7 @@ class DropShapeCalculationsPass
patterns.insert<DropShapeCalculateOp>(context); patterns.insert<DropShapeCalculateOp>(context);
ConversionTarget target(*context); ConversionTarget target(*context);
target.addIllegalOp<ShapeCalculateOp>(); target.addIllegalOp<ShapeCalculateOp>();
target.addLegalOp<FuncOp>(); target.addLegalOp<func::FuncOp>();
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) { std::move(patterns)))) {
@ -64,7 +64,7 @@ class DropShapeCalculationsPass
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createDropShapeCalculationsPass() { mlir::torch::Torch::createDropShapeCalculationsPass() {
return std::make_unique<DropShapeCalculationsPass>(); return std::make_unique<DropShapeCalculationsPass>();
} }

View File

@ -88,7 +88,7 @@ public:
return it->second; return it->second;
} }
Optional<LinkageInfo> getFuncLinkageInfo(NnModuleOp instance, Optional<LinkageInfo> getFuncLinkageInfo(NnModuleOp instance,
FuncOp methodFunc) { func::FuncOp methodFunc) {
auto it = funcLinkageInfo.find({instance, methodFunc}); auto it = funcLinkageInfo.find({instance, methodFunc});
if (it == funcLinkageInfo.end()) if (it == funcLinkageInfo.end())
return None; return None;
@ -185,7 +185,7 @@ private:
for (auto method : classType.getOps<MethodOp>()) { for (auto method : classType.getOps<MethodOp>()) {
nameStack.push_back(method.name().str()); nameStack.push_back(method.name().str());
funcLinkageInfo[{nnModule, funcLinkageInfo[{nnModule,
symbolTable.lookup<FuncOp>(method.function())}] = symbolTable.lookup<func::FuncOp>(method.function())}] =
LinkageInfo{llvm::join(nameStack, "."), method.isPrivate()}; LinkageInfo{llvm::join(nameStack, "."), method.isPrivate()};
nameStack.pop_back(); nameStack.pop_back();
} }
@ -251,7 +251,7 @@ private:
// Linkage info for each method in the program. Since we are going to be // Linkage info for each method in the program. Since we are going to be
// monomorphizing all the functions, we also need to key this off of the // monomorphizing all the functions, we also need to key this off of the
// instance (NnModuleOp) that the func is monomorphized for. // instance (NnModuleOp) that the func is monomorphized for.
DenseMap<std::pair<NnModuleOp, FuncOp>, LinkageInfo> funcLinkageInfo; DenseMap<std::pair<NnModuleOp, func::FuncOp>, LinkageInfo> funcLinkageInfo;
// The corresponding GlobalSlotOp for each SlotOp in the program. // The corresponding GlobalSlotOp for each SlotOp in the program.
DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot; DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot;
// A set of values that we have copied into torch.global_slot initializers, // A set of values that we have copied into torch.global_slot initializers,
@ -298,7 +298,7 @@ namespace {
// any notion of "type" that we have in the IR, but still fits the formal // any notion of "type" that we have in the IR, but still fits the formal
// definition. // definition.
struct Monomorphization { struct Monomorphization {
FuncOp func; func::FuncOp func;
std::vector<ArgInstance> argInstances; std::vector<ArgInstance> argInstances;
}; };
} // namespace } // namespace
@ -327,7 +327,7 @@ template <> struct llvm::DenseMapInfo<Monomorphization> {
// //
// This generalizes to a full abstract interpretation of the function, but // This generalizes to a full abstract interpretation of the function, but
// currently only analyzes a subset of ops. // currently only analyzes a subset of ops.
static LogicalResult analyzeInstances(FuncOp func, static LogicalResult analyzeInstances(func::FuncOp func,
ArrayRef<ArgInstance> argInstances, ArrayRef<ArgInstance> argInstances,
BlockAndValueMapping &mapping) { BlockAndValueMapping &mapping) {
for (auto &argInstance : argInstances) for (auto &argInstance : argInstances)
@ -351,7 +351,7 @@ static LogicalResult analyzeInstances(FuncOp func,
static FailureOr<Monomorphization> static FailureOr<Monomorphization>
createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping, createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping,
SymbolTable &symbolTable) { SymbolTable &symbolTable) {
auto func = symbolTable.lookup<FuncOp>(op.getCallee()); auto func = symbolTable.lookup<func::FuncOp>(op.getCallee());
Monomorphization monomorphization; Monomorphization monomorphization;
monomorphization.func = func; monomorphization.func = func;
for (auto operand : llvm::enumerate(op->getOperands())) { for (auto operand : llvm::enumerate(op->getOperands())) {
@ -372,7 +372,7 @@ public:
: module(module), symbolTable(module) {} : module(module), symbolTable(module) {}
LogicalResult LogicalResult
initialize(DenseMap<ClassTypeOp, std::vector<NnModuleOp>> &instances) { initialize(DenseMap<ClassTypeOp, std::vector<NnModuleOp>> &instances) {
for (auto func : module.getOps<FuncOp>()) { for (auto func : module.getOps<func::FuncOp>()) {
Monomorphization monomorphization; Monomorphization monomorphization;
monomorphization.func = func; monomorphization.func = func;
bool canTriviallyMonomorphize = true; bool canTriviallyMonomorphize = true;
@ -455,7 +455,7 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
// Verify that `func` conforms to the subset of allowable method bodies // Verify that `func` conforms to the subset of allowable method bodies
// that we can convert. // that we can convert.
static LogicalResult verifyFuncConformsToSubset(FuncOp func) { static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) {
// TODO: Investingate why WalkResult::interrupt() doesn't propagate properly. // TODO: Investingate why WalkResult::interrupt() doesn't propagate properly.
LogicalResult ret = success(); LogicalResult ret = success();
func.walk([&](Block *block) { func.walk([&](Block *block) {
@ -481,7 +481,7 @@ static LogicalResult verifyFuncConformsToSubset(FuncOp func) {
static LogicalResult static LogicalResult
verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable, verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
MonomorphizationTracker &tracker) { MonomorphizationTracker &tracker) {
DenseMap<FuncOp, int> numMonomorphizations; DenseMap<func::FuncOp, int> numMonomorphizations;
for (auto &monomorphization : tracker.getMonomorphizations()) { for (auto &monomorphization : tracker.getMonomorphizations()) {
numMonomorphizations[monomorphization.func] += 1; numMonomorphizations[monomorphization.func] += 1;
} }
@ -489,7 +489,7 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
for (auto classType : module.getOps<ClassTypeOp>()) { for (auto classType : module.getOps<ClassTypeOp>()) {
for (auto method : classType.getOps<MethodOp>()) { for (auto method : classType.getOps<MethodOp>()) {
if (!method.isPrivate()) { if (!method.isPrivate()) {
if (numMonomorphizations[symbolTable.lookup<FuncOp>( if (numMonomorphizations[symbolTable.lookup<func::FuncOp>(
method.function())] > 1) { method.function())] > 1) {
method.emitError() method.emitError()
<< "public function with multiple monomorphizations"; << "public function with multiple monomorphizations";
@ -503,10 +503,9 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
// Rewrite `func`, given that all values of `NnModuleType` have been mapped in // Rewrite `func`, given that all values of `NnModuleType` have been mapped in
// `mapping` to corresponding global instances. // `mapping` to corresponding global instances.
static LogicalResult static LogicalResult rewriteMonomorphizedFuncClone(
rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping, func::FuncOp func, BlockAndValueMapping mapping, SymbolTable &symbolTable,
SymbolTable &symbolTable, DenseMap<Monomorphization, func::FuncOp> &newFuncs,
DenseMap<Monomorphization, FuncOp> &newFuncs,
ObjectGraphInfo &objectGraphInfo) { ObjectGraphInfo &objectGraphInfo) {
SmallVector<Operation *> toErase; SmallVector<Operation *> toErase;
@ -605,7 +604,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
// static analysis to discover how to monomorphize th eprogram, including // static analysis to discover how to monomorphize th eprogram, including
// tracking instances through control flow, through get/set attr, etc. We // tracking instances through control flow, through get/set attr, etc. We
// implement a very simple subset of cases. // implement a very simple subset of cases.
for (auto func : module.getOps<FuncOp>()) { for (auto func : module.getOps<func::FuncOp>()) {
if (failed(verifyFuncConformsToSubset(func))) if (failed(verifyFuncConformsToSubset(func)))
return failure(); return failure();
} }
@ -637,10 +636,10 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
// Step 4: Clone/rewrite functions to implement the necessary // Step 4: Clone/rewrite functions to implement the necessary
// monomorphizations. // monomorphizations.
DenseMap<Monomorphization, FuncOp> newFuncs; DenseMap<Monomorphization, func::FuncOp> newFuncs;
int uniquifier = 0; int uniquifier = 0;
for (auto &monomorphization : tracker.getMonomorphizations()) { for (auto &monomorphization : tracker.getMonomorphizations()) {
auto newFunc = cast<FuncOp>(monomorphization.func->clone()); auto newFunc = cast<func::FuncOp>(monomorphization.func->clone());
newFuncs[monomorphization] = newFunc; newFuncs[monomorphization] = newFunc;
Optional<LinkageInfo> linkageInfo = None; Optional<LinkageInfo> linkageInfo = None;
// If it is potentially a method, check its linkage info. // If it is potentially a method, check its linkage info.
@ -675,14 +674,14 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
} }
// Step 5: Clean up object graph. // Step 5: Clean up object graph.
DenseSet<FuncOp> liveFuncs; DenseSet<func::FuncOp> liveFuncs;
for (auto &kv : newFuncs) { for (auto &kv : newFuncs) {
liveFuncs.insert(kv.second); liveFuncs.insert(kv.second);
} }
for (auto &op : llvm::make_early_inc_range(module.getOps())) { for (auto &op : llvm::make_early_inc_range(module.getOps())) {
if (isa<GlobalSlotOp>(&op)) if (isa<GlobalSlotOp>(&op))
continue; continue;
if (auto func = dyn_cast<FuncOp>(op)) { if (auto func = dyn_cast<func::FuncOp>(op)) {
if (liveFuncs.contains(func)) if (liveFuncs.contains(func))
continue; continue;
} }

View File

@ -355,7 +355,7 @@ class MaximizeValueSemanticsPass
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createMaximizeValueSemanticsPass() { mlir::torch::Torch::createMaximizeValueSemanticsPass() {
return std::make_unique<MaximizeValueSemanticsPass>(); return std::make_unique<MaximizeValueSemanticsPass>();
} }

View File

@ -10,9 +10,11 @@
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H #ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H #define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
class ModuleOp;
namespace torch { namespace torch {
namespace Torch { namespace Torch {

View File

@ -94,7 +94,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
if (options.optimize) { if (options.optimize) {
// Eliminate the PrimTupleIndexOp generated from the // Eliminate the PrimTupleIndexOp generated from the
// adjustCallingConventions // adjustCallingConventions
pm.addNestedPass<FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Inline global slots, which for most inference scenarios deletes them. // Inline global slots, which for most inference scenarios deletes them.
// This also exposes more information to intraprocedural transformations // This also exposes more information to intraprocedural transformations
// below like MaximizeValueSemantics and RefineTypes. // below like MaximizeValueSemantics and RefineTypes.
@ -105,14 +105,14 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
} }
// Reduce variants of ops to a smaller set of primitives. // Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<FuncOp>(createReduceOpVariantsPass()); pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
if (options.optimize) { if (options.optimize) {
// OPT-ONLY: Right now we rely on this to eliminate certain branches that // OPT-ONLY: Right now we rely on this to eliminate certain branches that
// guard unreachable code that backends can't handle yet, such as lists, // guard unreachable code that backends can't handle yet, such as lists,
// RaiseException, unimplemented tensor ops, and only-used-in-training // RaiseException, unimplemented tensor ops, and only-used-in-training
// operations on `torch.global_slot`'s. // operations on `torch.global_slot`'s.
pm.addNestedPass<FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// OPT-ONLY: We may have deleted some `torch.global_slot.get` / // OPT-ONLY: We may have deleted some `torch.global_slot.get` /
// `torch.global_slot.get` ops, which may have left more // `torch.global_slot.get` ops, which may have left more
// `torch.global_slot`'s unused. // `torch.global_slot`'s unused.
@ -124,7 +124,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's. // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass()); pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
// Do shape refinement. // Do shape refinement.
// This must be run before RefineTypes (which primarily does dtype inference), // This must be run before RefineTypes (which primarily does dtype inference),
@ -132,7 +132,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// operand. // operand.
createTorchShapeRefinementPipeline(pm, options); createTorchShapeRefinementPipeline(pm, options);
// Refine types in the program, which mainly means inferring dtypes of ops. // Refine types in the program, which mainly means inferring dtypes of ops.
pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass()); pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by // Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends. // the previous pass. Doing this is ABI-compatible for our backends.
@ -142,7 +142,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// This can fold away some branches given the information got from // This can fold away some branches given the information got from
// RefineTypes before doing maximize value sematics which only works with // RefineTypes before doing maximize value sematics which only works with
// basic blocks. // basic blocks.
pm.addNestedPass<FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
} }
if (options.optimize) { if (options.optimize) {
@ -152,9 +152,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// branches that guard unreachable code that backends can't handle yet, such // branches that guard unreachable code that backends can't handle yet, such
// as lists, RaiseException, unimplemented aten ops, and // as lists, RaiseException, unimplemented aten ops, and
// only-used-in-training operations on `torch.global_slot`'s. // only-used-in-training operations on `torch.global_slot`'s.
pm.addNestedPass<FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
} }
pm.addNestedPass<FuncOp>(Torch::createDecomposeComplexOpsPass()); pm.addNestedPass<func::FuncOp>(Torch::createDecomposeComplexOpsPass());
// TODO: VerifyTorchBackendContractPass. // TODO: VerifyTorchBackendContractPass.
} }
@ -172,11 +172,11 @@ void mlir::torch::Torch::createTorchShapeRefinementPipeline(
// as hard as possible" kind of thing, so it's inherently somewhat brittle. // as hard as possible" kind of thing, so it's inherently somewhat brittle.
// The idea is to keep strengthening what we do here to support the shape // The idea is to keep strengthening what we do here to support the shape
// library. We don't need to support arbitrary programs, thankfully. // library. We don't need to support arbitrary programs, thankfully.
pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass()); pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
// Run CSE, then see if we can simplify further. // Run CSE, then see if we can simplify further.
pm.addNestedPass<FuncOp>(createCSEPass()); pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass()); pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
// Drop shape calculations, leaving behind the shape-refined program. // Drop shape calculations, leaving behind the shape-refined program.
pm.addNestedPass<FuncOp>(Torch::createDropShapeCalculationsPass()); pm.addNestedPass<func::FuncOp>(Torch::createDropShapeCalculationsPass());
} }

View File

@ -33,10 +33,10 @@ public:
auto classType = symbolTable.lookup<ClassTypeOp>( auto classType = symbolTable.lookup<ClassTypeOp>(
op.receiver().getType().cast<NnModuleType>().getClassName()); op.receiver().getType().cast<NnModuleType>().getClassName());
assert(classType && "malformed module -- missing ClassTypeOp"); assert(classType && "malformed module -- missing ClassTypeOp");
FuncOp func; func::FuncOp func;
for (auto method : classType.getOps<MethodOp>()) { for (auto method : classType.getOps<MethodOp>()) {
if (method.name() == op.name()) { if (method.name() == op.name()) {
func = symbolTable.lookup<FuncOp>(method.function()); func = symbolTable.lookup<func::FuncOp>(method.function());
break; break;
} }
} }

View File

@ -207,7 +207,7 @@ public:
assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 && assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 &&
"Torch JIT operators shouldn't have regions or successors"); "Torch JIT operators shouldn't have regions or successors");
Operation *newOp = rewriter.createOperation(state); Operation *newOp = rewriter.create(state);
auto tensor = auto tensor =
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0)); rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
createOverwriteTensorContents(rewriter, op->getLoc(), tensor, createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
@ -276,7 +276,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createReduceOpVariantsPass() { mlir::torch::Torch::createReduceOpVariantsPass() {
return std::make_unique<ReduceOpVariantsPass>(); return std::make_unique<ReduceOpVariantsPass>();
} }

View File

@ -25,7 +25,7 @@ class RefinePublicReturnPass
: public RefinePublicReturnBase<RefinePublicReturnPass> { : public RefinePublicReturnBase<RefinePublicReturnPass> {
void runOnOperation() override { void runOnOperation() override {
auto module = getOperation(); auto module = getOperation();
module.walk([&](FuncOp func) { module.walk([&](func::FuncOp func) {
if (func.getVisibility() != SymbolTable::Visibility::Public) if (func.getVisibility() != SymbolTable::Visibility::Public)
return; return;
if (func.isExternal()) if (func.isExternal())
@ -40,7 +40,7 @@ class RefinePublicReturnPass
}); });
} }
void rewriteSignature(FuncOp func) { void rewriteSignature(func::FuncOp func) {
// Find the unique return op. // Find the unique return op.
func::ReturnOp returnOp; func::ReturnOp returnOp;
WalkResult walkResult = func.walk([&](func::ReturnOp op) { WalkResult walkResult = func.walk([&](func::ReturnOp op) {
@ -90,7 +90,7 @@ class RefinePublicReturnPass
returnOp->setOperands(newOperands); returnOp->setOperands(newOperands);
// Update the function type. // Update the function type.
auto funcType = func.getType(); auto funcType = func.getFunctionType();
func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(), func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(),
ValueRange(newOperands).getTypes())); ValueRange(newOperands).getTypes()));
} }

View File

@ -1191,7 +1191,7 @@ static bool isSafeToRefineOperandInPlace(OpOperand *use, Type newOperandType) {
return operationIsValidWithRefinedType(use, newOperandType); return operationIsValidWithRefinedType(use, newOperandType);
} }
void optimize(FuncOp func, TypeAnalyzer &analyzer) { void optimize(func::FuncOp func, TypeAnalyzer &analyzer) {
func.walk([&](Operation *op) { func.walk([&](Operation *op) {
auto convertValuesToMostRefinedType = [&](ValueRange values, OpBuilder &b) { auto convertValuesToMostRefinedType = [&](ValueRange values, OpBuilder &b) {
for (Value v : values) { for (Value v : values) {
@ -1336,7 +1336,7 @@ class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRefineTypesPass() { mlir::torch::Torch::createRefineTypesPass() {
return std::make_unique<RefineTypesPass>(); return std::make_unique<RefineTypesPass>();
} }

View File

@ -169,7 +169,7 @@ static Value adjustShapeFunctionArg(Value operand, Type desiredType,
// from the shape library. // from the shape library.
static LogicalResult static LogicalResult
populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands, populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands,
mlir::FuncOp shapeFunction) { mlir::func::FuncOp shapeFunction) {
// Create a call to the shape function in the `shapeCalculation` region. // Create a call to the shape function in the `shapeCalculation` region.
// We will import the callee from the shape library later. // We will import the callee from the shape library later.
OpBuilder b(op.getContext()); OpBuilder b(op.getContext());
@ -241,7 +241,7 @@ class ReifyShapeCalculationsPass
name = name.drop_front(strlen("valsem.")); name = name.drop_front(strlen("valsem."));
auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str(); auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
auto shapeFunction = auto shapeFunction =
shapeLibrary->lookupSymbol<FuncOp>(shapeFunctionName); shapeLibrary->lookupSymbol<func::FuncOp>(shapeFunctionName);
if (!shapeFunction) if (!shapeFunction)
return; return;
neededShapeFunctions.push_back(shapeFunctionName); neededShapeFunctions.push_back(shapeFunctionName);
@ -276,7 +276,7 @@ class ReifyShapeCalculationsPass
auto symName = worklist.pop_back_val(); auto symName = worklist.pop_back_val();
if (importedFunctions.count(symName)) if (importedFunctions.count(symName))
continue; continue;
auto func = shapeLibrary->lookupSymbol<mlir::FuncOp>(symName); auto func = shapeLibrary->lookupSymbol<mlir::func::FuncOp>(symName);
assert(func && "broken shape library"); assert(func && "broken shape library");
// Move the shape function from the library to the module this pass // Move the shape function from the library to the module this pass
// is running on. (this mutates the library, but we re-parse it each time // is running on. (this mutates the library, but we re-parse it each time

View File

@ -415,7 +415,7 @@ class SimplifyShapeCalculationsPass
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createSimplifyShapeCalculationsPass() { mlir::torch::Torch::createSimplifyShapeCalculationsPass() {
return std::make_unique<SimplifyShapeCalculationsPass>(); return std::make_unique<SimplifyShapeCalculationsPass>();
} }

View File

@ -45,9 +45,10 @@ struct FuncBackendTypeConversionPass
typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter); TorchConversion::setupBackendTypeConversion(target, typeConverter);
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, typeConverter); populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { patterns, typeConverter);
return typeConverter.isSignatureLegal(op.getType()) && target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody()); typeConverter.isLegal(&op.getBody());
}); });
populateCallOpTypeConversionPattern(patterns, typeConverter); populateCallOpTypeConversionPattern(patterns, typeConverter);
@ -155,7 +156,7 @@ struct FinalizingBackendTypeConversionPass
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
return std::make_unique<FinalizingBackendTypeConversionPass>(); return std::make_unique<FinalizingBackendTypeConversionPass>();
} }

View File

@ -10,9 +10,12 @@
#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
class ModuleOp;
namespace torch { namespace torch {
namespace TorchConversion { namespace TorchConversion {

View File

@ -59,26 +59,27 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
// We do this first as it tends to involve pattern-matching against constants, // We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model) // (e.g. dimensions which must be constant in a ranked programming model)
// and those constants get somewhat obscured by TorchToStd. // and those constants get somewhat obscured by TorchToStd.
pm.addNestedPass<FuncOp>(createConvertTorchToTMTensorPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<FuncOp>(memref::createExpandOpsPass()); pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
if (options.optimize) { if (options.optimize) {
// Clean up any non-canonical code introduced above.. // Clean up any non-canonical code introduced above..
pm.addNestedPass<FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Resolve `dim` ops on tensors (which currently live in the `memref` // Resolve `dim` ops on tensors (which currently live in the `memref`
// dialect for some reason -- we don't have memrefs at this level). // dialect for some reason -- we don't have memrefs at this level).
pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass()); pm.addNestedPass<func::FuncOp>(
memref::createResolveShapedTypeResultDimsPass());
// The resolution of `dim` ops tends to create identical ops. CSE them. // The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<FuncOp>(createCSEPass()); pm.addNestedPass<func::FuncOp>(createCSEPass());
} }
// Finish the type conversion from `torch` types to the types of the // Finish the type conversion from `torch` types to the types of the
// linalg-on-tensors backend contract. // linalg-on-tensors backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<FuncOp>( pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass()); TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that linalg on tensors backends // Verify that we have lowered to the form that linalg on tensors backends
@ -93,21 +94,21 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
pm.addPass( pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass()); TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass()); pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works // Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass()); pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
if (options.optimize) { if (options.optimize) {
// Clean up any non-canonical code introduced above.. // Clean up any non-canonical code introduced above..
pm.addNestedPass<FuncOp>(createCanonicalizerPass()); pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them. // The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<FuncOp>(createCSEPass()); pm.addNestedPass<func::FuncOp>(createCSEPass());
} }
// Finish the type conversion from `torch` types to the types of the // Finish the type conversion from `torch` types to the types of the
// TOSA backend contract. // TOSA backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<FuncOp>( pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass()); TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that TOSA backends // Verify that we have lowered to the form that TOSA backends

View File

@ -9,6 +9,8 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
@ -60,7 +62,7 @@ class VerifyLinalgOnTensorsBackendContractPass
ConversionTarget target(*context); ConversionTarget target(*context);
// Structural operations. // Structural operations.
target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>( target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
opHasLegalTypes); opHasLegalTypes);
target.addDynamicallyLegalOp<GetNextSeedOp>(opHasLegalTypes); target.addDynamicallyLegalOp<GetNextSeedOp>(opHasLegalTypes);

View File

@ -9,6 +9,7 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@ -39,7 +40,7 @@ class VerifyTosaBackendContractPass
ConversionTarget target(*context); ConversionTarget target(*context);
// Structural operations. // Structural operations.
target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>( target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
opHasLegalTypes); opHasLegalTypes);
// Basic scalar operations. // Basic scalar operations.
target.addLegalDialect<tosa::TosaDialect>(); target.addLegalDialect<tosa::TosaDialect>();

View File

@ -10,6 +10,7 @@
#ifndef REFBACKEND_PASSDETAIL_H #ifndef REFBACKEND_PASSDETAIL_H
#define REFBACKEND_PASSDETAIL_H #define REFBACKEND_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"

View File

@ -15,7 +15,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@ -68,7 +68,7 @@ static bool isArgMemRefTypeValid(Type type) {
return false; return false;
} }
static void addEmitCInterfaceAttr(FuncOp func) { static void addEmitCInterfaceAttr(func::FuncOp func) {
func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext())); func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
} }
@ -115,7 +115,7 @@ static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op,
} }
static LogicalResult mungeFunction( static LogicalResult mungeFunction(
FuncOp func, func::FuncOp func,
std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) { std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
// Only need to call mungeFunction for functions callable from outside of the // Only need to call mungeFunction for functions callable from outside of the
// module. // module.
@ -188,15 +188,15 @@ class MungeCallingConventions
auto module = getOperation(); auto module = getOperation();
OpBuilder b(module.getBodyRegion()); OpBuilder b(module.getBodyRegion());
std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs; std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
for (auto func : module.getOps<FuncOp>()) { for (auto func : module.getOps<func::FuncOp>()) {
if (failed(mungeFunction(func, invokedConsumeFuncReturnFuncs))) if (failed(mungeFunction(func, invokedConsumeFuncReturnFuncs)))
return signalPassFailure(); return signalPassFailure();
} }
// Create FuncOp for consumeFuncReturnFuncs that are used. // Create FuncOp for consumeFuncReturnFuncs that are used.
for (auto &p : invokedConsumeFuncReturnFuncs) { for (auto &p : invokedConsumeFuncReturnFuncs) {
auto consumeFuncReturnFunc = auto consumeFuncReturnFunc = b.create<func::FuncOp>(
b.create<FuncOp>(module.getLoc(), p.first, module.getLoc(), p.first,
FunctionType::get(module.getContext(), p.second, {}), FunctionType::get(module.getContext(), p.second, {}),
b.getStringAttr("private")); b.getStringAttr("private"));
addEmitCInterfaceAttr(consumeFuncReturnFunc); addEmitCInterfaceAttr(consumeFuncReturnFunc);
@ -309,7 +309,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::RefBackend::createExpandOpsForLLVMPass() { mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
return std::make_unique<ExpandOpsForLLVM>(); return std::make_unique<ExpandOpsForLLVM>();
} }
@ -366,7 +366,7 @@ class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::RefBackend::createMungeMemrefCopyPass() { mlir::torch::RefBackend::createMungeMemrefCopyPass() {
return std::make_unique<MungeMemrefCopy>(); return std::make_unique<MungeMemrefCopy>();
} }
@ -390,7 +390,7 @@ class GeneralizeTensorPad
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::RefBackend::createGeneralizeTensorPadPass() { mlir::torch::RefBackend::createGeneralizeTensorPadPass() {
return std::make_unique<GeneralizeTensorPad>(); return std::make_unique<GeneralizeTensorPad>();
} }

View File

@ -36,8 +36,8 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
MlirAttribute symNameAttr = mlirStringAttrGet( MlirAttribute symNameAttr = mlirStringAttrGet(
context, toMlirStringRef(function->qualname().qualifiedName())); context, toMlirStringRef(function->qualname().qualifiedName()));
MlirOperation func = createMlirOperation( MlirOperation func = createMlirOperation(
"builtin.func", loc, mlirRegionCreate(), "func.func", loc, mlirRegionCreate(),
toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)), toMlirNamedAttribute("function_type", mlirTypeAttrGet(functionType)),
toMlirNamedAttribute("sym_name", symNameAttr)); toMlirNamedAttribute("sym_name", symNameAttr));
std::vector<MlirAttribute> argAttrDicts; std::vector<MlirAttribute> argAttrDicts;
for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) { for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) {

View File

@ -29,7 +29,7 @@ import torch
from torch.jit import ScriptFunction from torch.jit import ScriptFunction
from torch_mlir import ir from torch_mlir import ir
from torch_mlir.dialects.builtin import FuncOp from torch_mlir.dialects.func import FuncOp
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder

View File

@ -115,15 +115,15 @@ class RefBackendInvoker:
LOWERING_PIPELINE = ",".join([ LOWERING_PIPELINE = ",".join([
"builtin.func(refback-generalize-tensor-pad)", "func.func(refback-generalize-tensor-pad)",
# Bufferize. # Bufferize.
"builtin.func(scf-bufferize)", "func.func(scf-bufferize)",
"builtin.func(tm-tensor-bufferize)", "func.func(tm-tensor-bufferize)",
"builtin.func(linalg-bufferize)", "func.func(linalg-bufferize)",
"func-bufferize", "func-bufferize",
"arith-bufferize", "arith-bufferize",
"builtin.func(tensor-bufferize)", "func.func(tensor-bufferize)",
"builtin.func(finalizing-bufferize)", "func.func(finalizing-bufferize)",
# Munge to make it ExecutionEngine compatible. # Munge to make it ExecutionEngine compatible.
# Specifically, we rewrite calling convention boundaries to be in terms # Specifically, we rewrite calling convention boundaries to be in terms
# of unranked memref, and we rewrite the return to actually be a # of unranked memref, and we rewrite the return to actually be a
@ -135,17 +135,17 @@ LOWERING_PIPELINE = ",".join([
# global seed used in stateful rng. # global seed used in stateful rng.
"refback-insert-rng-globals", "refback-insert-rng-globals",
# Lower to LLVM # Lower to LLVM
"builtin.func(tm-tensor-to-loops)", "func.func(tm-tensor-to-loops)",
"builtin.func(refback-munge-memref-copy)", "func.func(refback-munge-memref-copy)",
"builtin.func(convert-linalg-to-loops)", "func.func(convert-linalg-to-loops)",
"builtin.func(lower-affine)", "func.func(lower-affine)",
"convert-scf-to-cf", "convert-scf-to-cf",
"builtin.func(refback-expand-ops-for-llvm)", "func.func(refback-expand-ops-for-llvm)",
"builtin.func(arith-expand)", "func.func(arith-expand)",
"builtin.func(convert-math-to-llvm)", "func.func(convert-math-to-llvm)",
"convert-linalg-to-llvm", "convert-linalg-to-llvm",
"convert-memref-to-llvm", "convert-memref-to-llvm",
"builtin.func(convert-arith-to-llvm)", "func.func(convert-arith-to-llvm)",
"convert-func-to-llvm", "convert-func-to-llvm",
"convert-cf-to-llvm", "convert-cf-to-llvm",
"reconcile-unrealized-casts", "reconcile-unrealized-casts",

View File

@ -38,25 +38,25 @@ class LinalgOnTensorsTosaBackend(TosaBackend):
""" """
# TOSA legalization may emit tosa.const() ops. These are legalized # TOSA legalization may emit tosa.const() ops. These are legalized
# by tosa-to-standard to arith.constants. This mechanical transformation # by tosa-to-arith to arith.constants. This mechanical transformation
# must be done prior to TOSA-to-LinAlg so that the latter does not fail. # must be done prior to TOSA-to-LinAlg so that the latter does not fail.
# This is an artifact of legalizations spread across a collection of simple # This is an artifact of legalizations spread across a collection of simple
# ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg, # ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg,
# that depend on TOSA as well as TOSA-to-Standard. # that depend on TOSA as well as TOSA-to-Standard.
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
imported_module, imported_module,
"builtin.func(tosa-to-standard)", "func.func(tosa-to-arith)",
"Lowering TOSA to Standard") "Lowering TOSA to Arith")
# Named ops must be legalized prior to general tosa-to-linalg # Named ops must be legalized prior to general tosa-to-linalg
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
imported_module, imported_module,
"builtin.func(tosa-to-linalg-named)", "func.func(tosa-to-linalg-named)",
"Lowering TOSA to Linalg-on-Tensors for Named Ops") "Lowering TOSA to Linalg-on-Tensors for Named Ops")
run_pipeline_with_repro_report( run_pipeline_with_repro_report(
imported_module, imported_module,
"builtin.func(tosa-to-linalg)", "func.func(tosa-to-linalg)",
"Lowering TOSA to Linalg-on-Tensors") "Lowering TOSA to Linalg-on-Tensors")
return self.refbackend.compile(imported_module) return self.refbackend.compile(imported_module)

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @forward // CHECK-LABEL: func @forward
builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1 %int1 = torch.constant.int 1
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
%int3 = torch.constant.int 3 %int3 = torch.constant.int 3

View File

@ -42,7 +42,7 @@ torch.nn_module {
torch.slot "t1", %t : !torch.tensor torch.slot "t1", %t : !torch.tensor
torch.slot "t2", %t : !torch.tensor torch.slot "t2", %t : !torch.tensor
} : !torch.nn.Module<"c"> } : !torch.nn.Module<"c">
builtin.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor { func.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor {
%t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor %t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor
%t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor %t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor
%cst = torch.constant.int 1 %cst = torch.constant.int 1
@ -63,7 +63,7 @@ torch.nn_module {
torch.slot "t1", %t : !torch.tensor torch.slot "t1", %t : !torch.tensor
torch.slot "t2", %t : !torch.tensor torch.slot "t2", %t : !torch.tensor
} : !torch.nn.Module<"c"> } : !torch.nn.Module<"c">
builtin.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) { func.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) {
torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
return return

View File

@ -4,8 +4,8 @@
torch.class_type @c {} torch.class_type @c {}
%0 = torch.nn_module { %0 = torch.nn_module {
// expected-error @+1 {{'builtin.func' op is not allowed inside 'torch.nn_module'}} // expected-error @+1 {{'func.func' op is not allowed inside 'torch.nn_module'}}
builtin.func @f() func.func @f()
} : !torch.nn.Module<"c"> } : !torch.nn.Module<"c">
// ----- // -----
@ -32,8 +32,8 @@ torch.class_type @c {
// ----- // -----
torch.class_type @c { torch.class_type @c {
// expected-error @+1 {{'builtin.func' op is not allowed inside `torch.class_type`}} // expected-error @+1 {{'func.func' op is not allowed inside `torch.class_type`}}
builtin.func @f() func.func @f()
} }
// ----- // -----
@ -60,7 +60,7 @@ torch.class_type @c {
torch.method "f", @f torch.method "f", @f
} }
builtin.func @f(%arg0: !torch.nn.Module<"c">) { func.func @f(%arg0: !torch.nn.Module<"c">) {
return return
} }
@ -71,11 +71,11 @@ torch.class_type @c {
torch.method "f", @f torch.method "f", @f
} }
builtin.func private @f(%arg0: !torch.nn.Module<"c">) func.func private @f(%arg0: !torch.nn.Module<"c">)
// ----- // -----
builtin.func private @f() { func.func private @f() {
return return
} }
torch.class_type @c { torch.class_type @c {
@ -85,7 +85,7 @@ torch.class_type @c {
// ----- // -----
builtin.func private @f(!torch.nn.Module<"other_c">) { func.func private @f(%arg0: !torch.nn.Module<"other_c">) {
return return
} }
torch.class_type @c { torch.class_type @c {
@ -101,21 +101,21 @@ torch.class_type @c {
// ----- // -----
// expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}} // expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}}
builtin.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>}) func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
// ----- // -----
// expected-error @+1 {{'torch.type_bound' must be TypeAttr}} // expected-error @+1 {{'torch.type_bound' must be TypeAttr}}
builtin.func @f(%arg0: i32 {torch.type_bound = 1}) func.func @f(%arg0: i32 {torch.type_bound = 1})
// ----- // -----
// expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}} // expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}}
builtin.func @f(%arg0: i32 {torch.type_bound = i32}) func.func @f(%arg0: i32 {torch.type_bound = i32})
// ----- // -----
builtin.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor { func.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
// expected-error @+1 {{operand type '!torch.optional<tensor>' and result type '!torch.tensor' are cast incompatible}} // expected-error @+1 {{operand type '!torch.optional<tensor>' and result type '!torch.tensor' are cast incompatible}}
%0 = torch.derefine %arg0 : !torch.optional<tensor> to !torch.tensor %0 = torch.derefine %arg0 : !torch.optional<tensor> to !torch.tensor
return %0 : !torch.tensor return %0 : !torch.tensor
@ -123,7 +123,7 @@ builtin.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
// ----- // -----
builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<tensor> { func.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<tensor> {
// expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<tensor>' are cast incompatible}} // expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<tensor>' are cast incompatible}}
%0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<tensor> %0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<tensor>
return %0 : !torch.optional<tensor> return %0 : !torch.optional<tensor>
@ -132,11 +132,11 @@ builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !
// ----- // -----
// expected-error @+1 {{invalid dtype 'tuple<>' for !torch.tensor type}} // expected-error @+1 {{invalid dtype 'tuple<>' for !torch.tensor type}}
builtin.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>> func.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
// ----- // -----
builtin.func @torch.tensor() { func.func @torch.tensor() {
// Incompatible shape. // Incompatible shape.
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}} // expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32> %0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
@ -145,7 +145,7 @@ builtin.func @torch.tensor() {
// ----- // -----
builtin.func @torch.tensor() { func.func @torch.tensor() {
// Incompatible dtype. // Incompatible dtype.
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}} // expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64> %0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
@ -154,7 +154,7 @@ builtin.func @torch.tensor() {
// ----- // -----
builtin.func @torch.tensor() { func.func @torch.tensor() {
// Incompatible type. // Incompatible type.
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}} // expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1 %0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1
@ -163,7 +163,7 @@ builtin.func @torch.tensor() {
// ----- // -----
builtin.func @torch.prim.ListConstruct() { func.func @torch.prim.ListConstruct() {
%int2 = torch.constant.int 2 %int2 = torch.constant.int 2
// expected-error@+1 {{operand types should have the same type as the list contained type}} // expected-error@+1 {{operand types should have the same type as the list contained type}}
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<tensor> torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<tensor>
@ -172,7 +172,7 @@ builtin.func @torch.prim.ListConstruct() {
// ----- // -----
builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> { func.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32> %0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
// expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}} // expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}}
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32> torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32>

View File

@ -9,7 +9,7 @@
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64> // CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
// CHECK: return // CHECK: return
builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>, func.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[1],f64>, %t1: !torch.vtensor<[1],f64>,
%alpha: !torch.float) { %alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk> %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
@ -25,7 +25,7 @@ builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64> // CHECK-SAME: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
// CHECK: return // CHECK: return
builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>, func.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
%t1: !torch.vtensor<[1],f64>, %t1: !torch.vtensor<[1],f64>,
%alpha: !torch.float) { %alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk> %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
@ -41,7 +41,7 @@ builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],f32> // CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return // CHECK: return
builtin.func @tensor_tensor$same_category_zero_rank_wider( func.func @tensor_tensor$same_category_zero_rank_wider(
%t0: !torch.vtensor<[1],f32>, %t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[],f64>, %t1: !torch.vtensor<[],f64>,
%alpha: !torch.int) { %alpha: !torch.int) {
@ -58,7 +58,7 @@ builtin.func @tensor_tensor$same_category_zero_rank_wider(
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32> // CHECK-SAME: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return // CHECK: return
builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>, func.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>,
%t1: !torch.vtensor<[],f32>, %t1: !torch.vtensor<[],f32>,
%alpha: !torch.int) { %alpha: !torch.int) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk> %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk>
@ -73,7 +73,7 @@ builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],f32> // CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],f32>
// CHECK: return // CHECK: return
builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>, func.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[1],f32>, %t1: !torch.vtensor<[1],f32>,
%alpha: !torch.float) { %alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk> %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk>
@ -89,7 +89,7 @@ builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],f32> // CHECK-SAME: !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return // CHECK: return
builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) { func.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk> %1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk>
return return
} }
@ -103,7 +103,7 @@ builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : // CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32> // CHECK-SAME: !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
// CHECK: return // CHECK: return
builtin.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) { func.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk> %1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk>
return return
} }