mirror of https://github.com/llvm/torch-mlir
llvm: bump tag to e1318078 (#781)
The updated LLVM code includes a patch to create bfloat16 array attributes, thus enabling a different patch to torch-mlir to flesh out support for the bfloat16 type.pull/761/head snapshot-20220426.416
parent
9ec4712516
commit
9208bf0eb6
|
@ -463,7 +463,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
|
|||
$_op->getAttrs());
|
||||
for (Region &r : $_op->getRegions())
|
||||
r.cloneInto(state.addRegion(), bvm);
|
||||
return b.createOperation(state);
|
||||
return b.create(state);
|
||||
}]
|
||||
>
|
||||
];
|
||||
|
|
|
@ -31,9 +31,7 @@ class TMTensor_Op<string mnemonic, list<Trait> traits = []> :
|
|||
TMTensorInterface,
|
||||
SingleBlockImplicitTerminator<"::mlir::torch::TMTensor::YieldOp">
|
||||
])> {
|
||||
let verifier = [{ return verify$cppClass(*this); }];
|
||||
let printer = [{ return print$cppClass(p, *this); }];
|
||||
let parser = [{ return parse$cppClass(parser, result); }];
|
||||
let hasVerifier = 1;
|
||||
code extraTMTensorOpClassDeclaration = [{
|
||||
SmallVector<Value> getDestinationOperands(OpBuilder &b) {
|
||||
SmallVector<Value> dest(outputs().begin(), outputs().end());
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
|
||||
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -10,14 +10,15 @@
|
|||
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
|
||||
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
namespace TMTensor {
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createTMTensorBufferizePass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorToLoopsPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorBufferizePass();
|
||||
|
||||
void registerPasses();
|
||||
|
||||
|
|
|
@ -13,12 +13,12 @@
|
|||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TMTensorToLoops :
|
||||
Pass<"tm-tensor-to-loops", "FuncOp"> {
|
||||
Pass<"tm-tensor-to-loops", "func::FuncOp"> {
|
||||
let summary = "Convert TMTensor ops to loops and Linalg ops.";
|
||||
let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
|
||||
}
|
||||
|
||||
def TMTensorBufferize : Pass<"tm-tensor-bufferize", "FuncOp"> {
|
||||
def TMTensorBufferize : Pass<"tm-tensor-bufferize", "func::FuncOp"> {
|
||||
let summary = "Bufferize the TMTensor dialect";
|
||||
let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()";
|
||||
}
|
||||
|
|
|
@ -88,34 +88,34 @@ OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
|
|||
// ScanOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyScanOp(ScanOp op) {
|
||||
if (op.getNumInputs() != 1) {
|
||||
return op.emitOpError("expected one input operands");
|
||||
LogicalResult ScanOp::verify() {
|
||||
if (getNumInputs() != 1) {
|
||||
return emitOpError("expected one input operands");
|
||||
}
|
||||
if (op.getNumOutputs() != 2) {
|
||||
return op.emitOpError("expected two output operands");
|
||||
if (getNumOutputs() != 2) {
|
||||
return emitOpError("expected two output operands");
|
||||
}
|
||||
if (!op.input().getType().isa<ShapedType>()) {
|
||||
return op.emitOpError("expected first input element type to be shaped");
|
||||
if (!input().getType().isa<ShapedType>()) {
|
||||
return emitOpError("expected first input element type to be shaped");
|
||||
}
|
||||
auto accumulatorType = op.accumulator().getType().cast<ShapedType>();
|
||||
auto inputType = op.input().getType().cast<ShapedType>();
|
||||
auto outputType = op.output().getType().cast<ShapedType>();
|
||||
auto accumulatorType = accumulator().getType().cast<ShapedType>();
|
||||
auto inputType = input().getType().cast<ShapedType>();
|
||||
auto outputType = output().getType().cast<ShapedType>();
|
||||
ArrayRef<int64_t> inputShapes = inputType.getShape();
|
||||
ArrayRef<int64_t> outputShapes = outputType.getShape();
|
||||
if (accumulatorType.getElementType() != inputType.getElementType()) {
|
||||
return op.emitOpError(
|
||||
return emitOpError(
|
||||
"expected input/accumulator element types to be identical");
|
||||
}
|
||||
ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
|
||||
int64_t accumulatorRank = accumulatorType.getRank();
|
||||
if (accumulatorRank != inputType.getRank() - 1) {
|
||||
return op.emitOpError(
|
||||
return emitOpError(
|
||||
"expected accumulator rank to be equal to input rank - 1");
|
||||
}
|
||||
SmallVector<int64_t> expectedAccumulatorShape;
|
||||
for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
|
||||
if (i != op.dimension())
|
||||
if (i != dimension())
|
||||
expectedAccumulatorShape.push_back(inputShapes[i]);
|
||||
}
|
||||
if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
|
||||
|
@ -124,14 +124,13 @@ static LogicalResult verifyScanOp(ScanOp op) {
|
|||
std::get<1>(s) != ShapedType::kDynamicSize &&
|
||||
std::get<0>(s) != std::get<1>(s);
|
||||
})) {
|
||||
return op.emitOpError("incompatible input/accumulator shapes");
|
||||
return emitOpError("incompatible input/accumulator shapes");
|
||||
}
|
||||
if (inputType.getElementType() != outputType.getElementType()) {
|
||||
return op.emitOpError(
|
||||
"expected input/output element types to be identical");
|
||||
return emitOpError("expected input/output element types to be identical");
|
||||
}
|
||||
if (inputShapes.size() != outputShapes.size()) {
|
||||
return op.emitOpError("expected input/output to have identical ranks");
|
||||
return emitOpError("expected input/output to have identical ranks");
|
||||
}
|
||||
if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
|
||||
[](std::tuple<int64_t, int64_t> s) {
|
||||
|
@ -139,7 +138,7 @@ static LogicalResult verifyScanOp(ScanOp op) {
|
|||
std::get<1>(s) != ShapedType::kDynamicSize &&
|
||||
std::get<0>(s) != std::get<1>(s);
|
||||
})) {
|
||||
return op.emitOpError("incompatible input/output shapes");
|
||||
return emitOpError("incompatible input/output shapes");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -232,11 +231,11 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
|
|||
});
|
||||
|
||||
auto &srcBlock = region().front();
|
||||
Region ®ion = scfIf.getElseRegion();
|
||||
Region &thisRegion = scfIf.getElseRegion();
|
||||
BlockAndValueMapping bvm;
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
auto &block = region.front();
|
||||
auto &block = thisRegion.front();
|
||||
b.setInsertionPointToEnd(&block);
|
||||
for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
|
||||
bvm.map(std::get<0>(it), std::get<1>(it));
|
||||
|
@ -275,48 +274,47 @@ LogicalResult ScanOp::fold(ArrayRef<Attribute>,
|
|||
//===----------------------------------------------------------------------===//
|
||||
// ScatterOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
static LogicalResult verifyScatterOp(ScatterOp op) {
|
||||
if (op.inputs().size() != 2) {
|
||||
return op.emitOpError("expected two input operands");
|
||||
LogicalResult ScatterOp::verify() {
|
||||
if (inputs().size() != 2) {
|
||||
return emitOpError("expected two input operands");
|
||||
}
|
||||
if (op.outputs().size() != 1) {
|
||||
return op.emitOpError("expected one output operand");
|
||||
if (outputs().size() != 1) {
|
||||
return emitOpError("expected one output operand");
|
||||
}
|
||||
auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
|
||||
return t1.getShape()[dim] == t2.getShape()[dim];
|
||||
};
|
||||
|
||||
auto indicesType = op.getIndicesType();
|
||||
auto indicesType = getIndicesType();
|
||||
if (indicesType.getRank() != 2 ||
|
||||
!indicesType.getElementType().isInteger(32)) {
|
||||
return op.emitOpError(
|
||||
"expected indices to be of rank 2 of i32 element type");
|
||||
return emitOpError("expected indices to be of rank 2 of i32 element type");
|
||||
}
|
||||
auto indexDepth = op.getIndexDepth();
|
||||
auto indexDepth = getIndexDepth();
|
||||
if (indexDepth == ShapedType::kDynamicSize) {
|
||||
return op.emitOpError("expected index depth is static");
|
||||
return emitOpError("expected index depth is static");
|
||||
}
|
||||
|
||||
// The first dimension of the indices should match the first dimension of the
|
||||
// output. They indicate to the number of updates.
|
||||
auto updateType = op.getUpdateType();
|
||||
auto updateType = getUpdateType();
|
||||
if (updateType.getRank() < 1) {
|
||||
return op.emitOpError("expected update value to be at least rank 1");
|
||||
return emitOpError("expected update value to be at least rank 1");
|
||||
}
|
||||
if (!checkDimensionsMatch(indicesType, updateType, 0)) {
|
||||
return op.emitOpError(
|
||||
return emitOpError(
|
||||
"mismatch in shape of indices and update value at dim#0");
|
||||
}
|
||||
auto originalType = op.getOriginalType();
|
||||
auto originalType = getOriginalType();
|
||||
if (updateType.getRank() - 1 > originalType.getRank()) {
|
||||
return op.emitOpError(
|
||||
return emitOpError(
|
||||
"update value rank exceeds the rank of the original value");
|
||||
}
|
||||
|
||||
// indexDepth + update dims should cover the original dims. The first dim of
|
||||
// update is the number of updates.
|
||||
if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
|
||||
return op.emitOpError(
|
||||
return emitOpError(
|
||||
"index depth and update value does not cover rank of original value");
|
||||
}
|
||||
|
||||
|
@ -331,7 +329,7 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
|
|||
int64_t updateDim = std::get<1>(it);
|
||||
if (updateType.getDimSize(updateDim) !=
|
||||
originalType.getDimSize(originalDim)) {
|
||||
return op.emitOpError("mismatch in shape of update value dim#")
|
||||
return emitOpError("mismatch in shape of update value dim#")
|
||||
<< updateDim << " and original value at dim#" << originalDim;
|
||||
}
|
||||
}
|
||||
|
@ -345,36 +343,36 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
|
|||
int64_t updateDim = std::get<1>(it);
|
||||
if (updateType.getDimSize(updateDim) >
|
||||
originalType.getDimSize(originalDim)) {
|
||||
return op.emitOpError("indexed shape of update value dim#")
|
||||
return emitOpError("indexed shape of update value dim#")
|
||||
<< updateDim << " exceeds original value at dim#" << originalDim
|
||||
<< " " << updateType.getDimSize(updateDim) << " "
|
||||
<< originalType.getDimSize(originalDim);
|
||||
}
|
||||
}
|
||||
|
||||
Region ®ion = op.region();
|
||||
Block *body = ®ion.front();
|
||||
Region &thisRegion = region();
|
||||
Block *body = &thisRegion.front();
|
||||
if (body->getNumArguments() != 2) {
|
||||
return op.emitOpError("expected region to have two arguments");
|
||||
return emitOpError("expected region to have two arguments");
|
||||
}
|
||||
Type arg0Type = body->getArgument(0).getType();
|
||||
Type arg1Type = body->getArgument(1).getType();
|
||||
if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
|
||||
return op.emitOpError(
|
||||
return emitOpError(
|
||||
"expected region to have scalar argument of integer or float types");
|
||||
}
|
||||
if (arg0Type != updateType.getElementType()) {
|
||||
return op.emitOpError("mismatch in argument 0 of region ")
|
||||
return emitOpError("mismatch in argument 0 of region ")
|
||||
<< arg0Type << " and element type of update value "
|
||||
<< updateType.getElementType();
|
||||
}
|
||||
if (arg1Type != originalType.getElementType()) {
|
||||
return op.emitOpError("mismatch in argument 1 of region ")
|
||||
return emitOpError("mismatch in argument 1 of region ")
|
||||
<< arg1Type << " and element type of original value "
|
||||
<< originalType.getElementType();
|
||||
}
|
||||
if (arg0Type != arg1Type) {
|
||||
return op.emitOpError("mismatch in region argument types ")
|
||||
return emitOpError("mismatch in region argument types ")
|
||||
<< arg0Type << " and " << arg1Type;
|
||||
}
|
||||
auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator());
|
||||
|
@ -455,7 +453,8 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
|
|||
Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
|
||||
Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
|
||||
|
||||
if (starts[i]) cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
|
||||
if (starts[i])
|
||||
cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
|
||||
starts[i] = cast;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/Dialect/Func/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
|
@ -150,7 +151,7 @@ struct TMTensorBufferizePass
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
torch::TMTensor::createTMTensorBufferizePass() {
|
||||
return std::make_unique<TMTensorBufferizePass>();
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
|
@ -111,7 +112,7 @@ struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
torch::TMTensor::createTMTensorToLoopsPass() {
|
||||
return std::make_unique<TMTensorToLoopsPass>();
|
||||
}
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 8361c5da30588d3d4a48eae648f53be1feb5cfad
|
||||
Subproject commit e1318078a4e160eb723bcbcfcdcc9a1b618f7067
|
|
@ -16,17 +16,17 @@ include "mlir/Pass/PassBase.td"
|
|||
// Torch conversions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertTorchToStd : Pass<"convert-torch-to-std", "FuncOp"> {
|
||||
def ConvertTorchToStd : Pass<"convert-torch-to-std", "func::FuncOp"> {
|
||||
let summary = "Convert recognized Torch ops to Std ops";
|
||||
let constructor = "mlir::torch::createConvertTorchToStdPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "FuncOp"> {
|
||||
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
|
||||
let summary = "Convert recognized Torch ops to SCF ops";
|
||||
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
|
||||
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
|
||||
let summary = "Convert recognized Torch ops to Linalg ops";
|
||||
let description = [{
|
||||
Convert ATen ops to linalg ops.
|
||||
|
@ -105,7 +105,7 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
|
|||
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
|
||||
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
|
||||
let summary = "Convert Torch ops to TOSA ops";
|
||||
let description = [{
|
||||
This pass assumes that TOSA ops are responsible for emitting error
|
||||
|
@ -114,7 +114,7 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
|
|||
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "FuncOp"> {
|
||||
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
|
||||
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
|
||||
let description = [{
|
||||
Convert ATen ops to tmtensor/linalg ops.
|
||||
|
|
|
@ -10,12 +10,13 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
|
||||
#define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToLinalgPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -10,11 +10,12 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToSCFPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -10,12 +10,13 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
|
||||
#define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToStdPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToStdPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -10,11 +10,12 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTMTensorPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -10,12 +10,13 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
||||
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTosaPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
#ifndef TORCH_TYPES
|
||||
#define TORCH_TYPES
|
||||
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/DialectBase.td"
|
||||
include "torch-mlir/Dialect/Torch/IR/TorchBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -24,28 +26,8 @@ class Torch_Type<string name, string typeMnemonic,
|
|||
|
||||
class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type<name, typeMnemonic> {
|
||||
let parameters = (ins "::mlir::Type":$containedType);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let printer = [{
|
||||
$_printer << "<";
|
||||
// Print the contained type without the `!torch.` prefix.
|
||||
printTorchDialectType(getImpl()->containedType, $_printer);
|
||||
$_printer << ">";
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
if ($_parser.parseLess())
|
||||
return Type();
|
||||
|
||||
// Parse the contained type, but forward directly to our internal parsing
|
||||
// of `torch` dialect types, so that we can parse nested types without
|
||||
// the `!torch.` prefix.
|
||||
Type containedType = parseTorchDialectType($_parser);
|
||||
if (!containedType)
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, containedType);
|
||||
}];
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
|
||||
return Base::get(containedType.getContext(), containedType);
|
||||
|
@ -59,23 +41,7 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
|
|||
Represents an instance of a `torch.nn.Module` with the given `className`.
|
||||
}];
|
||||
let parameters = (ins StringRefParameter<"class name">:$className);
|
||||
|
||||
let printer = [{
|
||||
$_printer << "<\"";
|
||||
llvm::printEscapedString(getImpl()->className, $_printer.getStream());
|
||||
$_printer << "\">";
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
if ($_parser.parseLess())
|
||||
return Type();
|
||||
std::string className;
|
||||
if ($_parser.parseOptionalString(&className))
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, className);
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
// For standard ArrayRefs, which require allocation.
|
||||
|
@ -186,6 +152,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
|
|||
"::mlir::Type":$optionalDtype
|
||||
);
|
||||
let genVerifyDecl = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
string extraBaseClassDeclaration = [{
|
||||
}];
|
||||
}
|
||||
|
@ -243,6 +210,7 @@ def Torch_TupleType : Torch_Type<"Tuple", "tuple"> {
|
|||
let parameters = (ins
|
||||
ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def Torch_UnionType : Torch_Type<"Union", "union"> {
|
||||
|
@ -261,6 +229,7 @@ def Torch_UnionType : Torch_Type<"Union", "union"> {
|
|||
let parameters = (ins
|
||||
ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def Torch_DeviceType : Torch_Type<"Device", "Device"> {
|
||||
|
@ -367,30 +336,7 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
|
|||
let description = [{
|
||||
Torch Dict type with key and value type.
|
||||
}];
|
||||
|
||||
let printer = [{
|
||||
$_printer << "<";
|
||||
printTorchDialectType(getImpl()->keyType, $_printer);
|
||||
$_printer << ", ";
|
||||
printTorchDialectType(getImpl()->valueType, $_printer);
|
||||
$_printer<< ">";
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
if ($_parser.parseLess())
|
||||
return Type();
|
||||
Type keyType = parseTorchDialectType($_parser);
|
||||
if (!keyType)
|
||||
return Type();
|
||||
if ($_parser.parseComma())
|
||||
return Type();
|
||||
Type valueType = parseTorchDialectType($_parser);
|
||||
if (!valueType)
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, keyType, valueType);
|
||||
}];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
|
||||
"::mlir::Type":$valueType), [{
|
||||
|
|
|
@ -10,11 +10,14 @@
|
|||
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
|
||||
|
@ -48,25 +51,26 @@ void createTorchShapeRefinementPipeline(
|
|||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createRefineTypesPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createReduceOpVariantsPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createReduceOpVariantsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createDecomposeComplexOpsPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeComplexOpsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createSimplifyShapeCalculationsPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createSimplifyShapeCalculationsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createDropShapeCalculationsPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
|
||||
|
||||
StringRef getShapeLibrary();
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ def AdjustCallingConventions
|
|||
}];
|
||||
}
|
||||
|
||||
def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
|
||||
def RefineTypes : Pass<"torch-refine-types", "func::FuncOp"> {
|
||||
let summary = "Refine types";
|
||||
let constructor = "mlir::torch::Torch::createRefineTypesPass()";
|
||||
let description = [{
|
||||
|
@ -149,7 +149,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
|
||||
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
|
||||
let summary = "Reduces variants of ops to a smaller set of ops.";
|
||||
let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
|
||||
let description = [{
|
||||
|
@ -165,7 +165,7 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> {
|
||||
def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "func::FuncOp"> {
|
||||
let summary = "Use value-semantic tensors where possible.";
|
||||
let description = [{
|
||||
Use value-semantic tensors where possible to make the program more
|
||||
|
@ -215,7 +215,7 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> {
|
||||
def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
|
||||
let summary = "Decompose complicated torch operations";
|
||||
let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
|
||||
let description = [{
|
||||
|
@ -238,7 +238,7 @@ def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp">
|
|||
}];
|
||||
}
|
||||
|
||||
def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncOp"> {
|
||||
def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func::FuncOp"> {
|
||||
let summary = "Simplify reified shape calculations.";
|
||||
let constructor = "mlir::torch::Torch::createSimplifyShapeCalculationsPass()";
|
||||
let description = [{
|
||||
|
@ -246,7 +246,7 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncO
|
|||
}];
|
||||
}
|
||||
|
||||
def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "FuncOp"> {
|
||||
def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"> {
|
||||
let summary = "Drop reified shape calculations.";
|
||||
let constructor = "mlir::torch::Torch::createDropShapeCalculationsPass()";
|
||||
let description = [{
|
||||
|
|
|
@ -10,12 +10,15 @@
|
|||
#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
|
||||
#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
|
||||
namespace torch {
|
||||
namespace TorchConversion {
|
||||
|
||||
|
@ -36,7 +39,7 @@ createVerifyInvariantsBeforeBackendLoweringPass();
|
|||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createFinalizingBackendTypeConversionPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
|
|
|
@ -43,7 +43,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu
|
|||
}
|
||||
|
||||
def FinalizingBackendTypeConversion
|
||||
: Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
|
||||
: Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> {
|
||||
let summary = "Finalizes a partial conversion to builtin tensors";
|
||||
let constructor =
|
||||
"mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()";
|
||||
|
|
|
@ -10,10 +10,13 @@
|
|||
#ifndef TORCHMLIR_REFBACKEND_PASSES_H
|
||||
#define TORCHMLIR_REFBACKEND_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
|
||||
namespace torch {
|
||||
namespace RefBackend {
|
||||
|
||||
|
@ -22,13 +25,13 @@ void registerRefBackendPasses();
|
|||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createMungeMemrefCopyPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createGeneralizeTensorPadPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
|
||||
} // namespace RefBackend
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -24,18 +24,18 @@ def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> {
|
|||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "FuncOp"> {
|
||||
def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> {
|
||||
let summary = "Expand ops into more primitive ops before LLVM lowering.";
|
||||
let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
|
||||
}
|
||||
|
||||
def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "FuncOp"> {
|
||||
def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
|
||||
let summary = "Munge memref.copy to linalg.copy";
|
||||
let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
|
||||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "FuncOp"> {
|
||||
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
|
||||
let summary = "Convert tensor.pad to linalg ops";
|
||||
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#ifndef TORCHMLIR_CONVERSION_PASSDETAIL_H
|
||||
#define TORCHMLIR_CONVERSION_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -88,7 +88,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToLinalgPass() {
|
||||
return std::make_unique<ConvertTorchToLinalg>();
|
||||
}
|
||||
|
|
|
@ -91,7 +91,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToSCFPass() {
|
||||
return std::make_unique<ConvertTorchToSCF>();
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
@ -221,7 +222,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToStdPass() {
|
||||
return std::make_unique<ConvertTorchToStd>();
|
||||
}
|
||||
|
|
|
@ -10,8 +10,10 @@
|
|||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
||||
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
||||
|
@ -566,7 +568,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToTMTensorPass() {
|
||||
return std::make_unique<ConvertTorchToTMTensor>();
|
||||
}
|
||||
|
|
|
@ -3170,7 +3170,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToTosaPass() {
|
||||
return std::make_unique<ConvertTorchToTosa>();
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
|
@ -124,7 +125,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
|||
unsigned argIndex,
|
||||
NamedAttribute namedAttr) {
|
||||
if (namedAttr.getName().getValue() == "torch.type_bound") {
|
||||
auto func = dyn_cast<FuncOp>(op);
|
||||
auto func = dyn_cast<func::FuncOp>(op);
|
||||
if (!func)
|
||||
return op->emitError() << "'torch.type_bound' must be attached to a func";
|
||||
TypeAttr attr = namedAttr.getValue().dyn_cast<TypeAttr>();
|
||||
|
@ -134,7 +135,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
|||
if (!type)
|
||||
return op->emitError() << "'torch.type_bound' must be of "
|
||||
"!torch.tensor/!torch.vtensor type";
|
||||
if (!func.getType().getInput(argIndex).isa<BaseTensorType>())
|
||||
if (!func.getFunctionType().getInput(argIndex).isa<BaseTensorType>())
|
||||
return op->emitError() << "'torch.type_bound' must be attached to an "
|
||||
"argument of !torch.tensor/!torch.vtensor type";
|
||||
return success();
|
||||
|
@ -177,3 +178,100 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
|
|||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OptionalType and ListType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void OptionalType::print(AsmPrinter &printer) const {
|
||||
printer << "<";
|
||||
// Print the contained type without the `!torch.` prefix.
|
||||
printTorchDialectType(getImpl()->containedType, printer);
|
||||
printer << ">";
|
||||
}
|
||||
|
||||
void ListType::print(AsmPrinter &printer) const {
|
||||
printer << "<";
|
||||
// Print the contained type without the `!torch.` prefix.
|
||||
printTorchDialectType(getImpl()->containedType, printer);
|
||||
printer << ">";
|
||||
}
|
||||
|
||||
Type OptionalType::parse(AsmParser &odsParser) {
|
||||
if (odsParser.parseLess())
|
||||
return Type();
|
||||
|
||||
// Parse the contained type, but forward directly to our internal parsing
|
||||
// of `torch` dialect types, so that we can parse nested types without
|
||||
// the `!torch.` prefix.
|
||||
Type containedType = parseTorchDialectType(odsParser);
|
||||
if (!containedType)
|
||||
return Type();
|
||||
if (odsParser.parseGreater())
|
||||
return Type();
|
||||
return get(odsParser.getContext(), containedType);
|
||||
}
|
||||
|
||||
Type ListType::parse(AsmParser &odsParser) {
|
||||
if (odsParser.parseLess())
|
||||
return Type();
|
||||
|
||||
// Parse the contained type, but forward directly to our internal parsing
|
||||
// of `torch` dialect types, so that we can parse nested types without
|
||||
// the `!torch.` prefix.
|
||||
Type containedType = parseTorchDialectType(odsParser);
|
||||
if (!containedType)
|
||||
return Type();
|
||||
if (odsParser.parseGreater())
|
||||
return Type();
|
||||
return get(odsParser.getContext(), containedType);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DictType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void DictType::print(AsmPrinter &printer) const {
|
||||
printer << "<";
|
||||
printTorchDialectType(getImpl()->keyType, printer);
|
||||
printer << ", ";
|
||||
printTorchDialectType(getImpl()->valueType, printer);
|
||||
printer << ">";
|
||||
}
|
||||
|
||||
Type DictType::parse(AsmParser &odsParser) {
|
||||
if (odsParser.parseLess())
|
||||
return Type();
|
||||
Type keyType = parseTorchDialectType(odsParser);
|
||||
if (!keyType)
|
||||
return Type();
|
||||
if (odsParser.parseComma())
|
||||
return Type();
|
||||
Type valueType = parseTorchDialectType(odsParser);
|
||||
if (!valueType)
|
||||
return Type();
|
||||
if (odsParser.parseGreater())
|
||||
return Type();
|
||||
return get(odsParser.getContext(), keyType, valueType);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NnModuleType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void NnModuleType::print(AsmPrinter &printer) const {
|
||||
printer << "<\"";
|
||||
llvm::printEscapedString(getImpl()->className, printer.getStream());
|
||||
printer << "\">";
|
||||
}
|
||||
|
||||
Type NnModuleType::parse(AsmParser &odsParser) {
|
||||
if (odsParser.parseLess())
|
||||
return Type();
|
||||
std::string className;
|
||||
if (odsParser.parseOptionalString(&className))
|
||||
return Type();
|
||||
if (odsParser.parseGreater())
|
||||
return Type();
|
||||
return get(odsParser.getContext(), className);
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
@ -99,7 +100,7 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
|
|||
|
||||
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
auto func =
|
||||
symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, functionAttr());
|
||||
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, functionAttr());
|
||||
if (!func)
|
||||
return emitError() << "'@" << function()
|
||||
<< "' does not reference a valid function";
|
||||
|
@ -112,8 +113,8 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
"merely declared)";
|
||||
auto expectedReceiverArgType = NnModuleType::get(
|
||||
getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
|
||||
if (func.getType().getNumInputs() == 0 ||
|
||||
func.getType().getInput(0) != expectedReceiverArgType) {
|
||||
if (func.getFunctionType().getNumInputs() == 0 ||
|
||||
func.getFunctionType().getInput(0) != expectedReceiverArgType) {
|
||||
return emitError() << "the referenced function '" << function()
|
||||
<< "' must have a first argument of type "
|
||||
<< expectedReceiverArgType;
|
||||
|
@ -278,7 +279,7 @@ ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
Region *elseRegion = result.addRegion();
|
||||
|
||||
auto &builder = parser.getBuilder();
|
||||
OpAsmParser::OperandType cond;
|
||||
OpAsmParser::UnresolvedOperand cond;
|
||||
Type boolType = builder.getType<Torch::BoolType>();
|
||||
if (parser.parseOperand(cond) ||
|
||||
parser.resolveOperand(cond, boolType, result.operands))
|
||||
|
|
|
@ -35,7 +35,7 @@ ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
|
|||
OperationState &result, int numOperands,
|
||||
int numResults) {
|
||||
llvm::SMLoc loc = parser.getCurrentLocation();
|
||||
SmallVector<OpAsmParser::OperandType> operands;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> operands;
|
||||
if (parser.parseOperandList(operands, /*requiredOperandCount=*/numOperands))
|
||||
return failure();
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
|
|
|
@ -31,11 +31,12 @@ using namespace mlir::torch::Torch;
|
|||
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>;
|
||||
|
||||
namespace {
|
||||
class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
|
||||
class AdjustCallingConventionForFunc
|
||||
: public OpConversionPattern<func::FuncOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(FuncOp func, OpAdaptor adaptor,
|
||||
matchAndRewrite(func::FuncOp func, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MLIRContext *context = func.getContext();
|
||||
auto typeBoundIdent = StringAttr::get(context, "torch.type_bound");
|
||||
|
@ -70,7 +71,7 @@ public:
|
|||
typeConverter);
|
||||
|
||||
SmallVector<Type> newResultTypes;
|
||||
for (auto type : func.getType().getResults()) {
|
||||
for (auto type : func.getFunctionType().getResults()) {
|
||||
if (auto none = type.dyn_cast<Torch::NoneType>()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -186,7 +187,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
static LogicalResult adjustCallingConventions(FuncOp func,
|
||||
static LogicalResult adjustCallingConventions(func::FuncOp func,
|
||||
TypeBoundMap &typeBoundMap) {
|
||||
MLIRContext *context = func.getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
|
@ -217,7 +218,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
|
|||
patterns.add<AdjustCallingConventionForReturn>(typeConverter, context);
|
||||
|
||||
ConversionTarget target(*context);
|
||||
target.addDynamicallyLegalOp<FuncOp>([](FuncOp func) {
|
||||
target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp func) {
|
||||
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
|
||||
if (func.getArgAttr(i, "torch.type_bound"))
|
||||
return false;
|
||||
|
@ -225,7 +226,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
|
|||
return false;
|
||||
}
|
||||
for (int i = 0, e = func.getNumResults(); i != e; i++) {
|
||||
if (func.getType().getResults()[i].isa<Torch::NoneType>())
|
||||
if (func.getFunctionType().getResults()[i].isa<Torch::NoneType>())
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -266,7 +267,7 @@ class AdjustCallingConventionsPass
|
|||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
TypeBoundMap typeBoundMap;
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
for (auto func : module.getOps<func::FuncOp>()) {
|
||||
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
|
||||
auto typeBoundAttr =
|
||||
func.getArgAttrOfType<TypeAttr>(i, "torch.type_bound");
|
||||
|
@ -275,7 +276,7 @@ class AdjustCallingConventionsPass
|
|||
typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue();
|
||||
}
|
||||
}
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
for (auto func : module.getOps<func::FuncOp>()) {
|
||||
if (failed(adjustCallingConventions(func, typeBoundMap)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -1781,7 +1781,7 @@ class DecomposeComplexOpsPass
|
|||
}
|
||||
};
|
||||
} // namespace
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createDecomposeComplexOpsPass() {
|
||||
return std::make_unique<DecomposeComplexOpsPass>();
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ class DropShapeCalculationsPass
|
|||
patterns.insert<DropShapeCalculateOp>(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<ShapeCalculateOp>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
target.addLegalOp<func::FuncOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
@ -64,7 +64,7 @@ class DropShapeCalculationsPass
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createDropShapeCalculationsPass() {
|
||||
return std::make_unique<DropShapeCalculationsPass>();
|
||||
}
|
||||
|
|
|
@ -88,7 +88,7 @@ public:
|
|||
return it->second;
|
||||
}
|
||||
Optional<LinkageInfo> getFuncLinkageInfo(NnModuleOp instance,
|
||||
FuncOp methodFunc) {
|
||||
func::FuncOp methodFunc) {
|
||||
auto it = funcLinkageInfo.find({instance, methodFunc});
|
||||
if (it == funcLinkageInfo.end())
|
||||
return None;
|
||||
|
@ -185,7 +185,7 @@ private:
|
|||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
nameStack.push_back(method.name().str());
|
||||
funcLinkageInfo[{nnModule,
|
||||
symbolTable.lookup<FuncOp>(method.function())}] =
|
||||
symbolTable.lookup<func::FuncOp>(method.function())}] =
|
||||
LinkageInfo{llvm::join(nameStack, "."), method.isPrivate()};
|
||||
nameStack.pop_back();
|
||||
}
|
||||
|
@ -251,7 +251,7 @@ private:
|
|||
// Linkage info for each method in the program. Since we are going to be
|
||||
// monomorphizing all the functions, we also need to key this off of the
|
||||
// instance (NnModuleOp) that the func is monomorphized for.
|
||||
DenseMap<std::pair<NnModuleOp, FuncOp>, LinkageInfo> funcLinkageInfo;
|
||||
DenseMap<std::pair<NnModuleOp, func::FuncOp>, LinkageInfo> funcLinkageInfo;
|
||||
// The corresponding GlobalSlotOp for each SlotOp in the program.
|
||||
DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot;
|
||||
// A set of values that we have copied into torch.global_slot initializers,
|
||||
|
@ -298,7 +298,7 @@ namespace {
|
|||
// any notion of "type" that we have in the IR, but still fits the formal
|
||||
// definition.
|
||||
struct Monomorphization {
|
||||
FuncOp func;
|
||||
func::FuncOp func;
|
||||
std::vector<ArgInstance> argInstances;
|
||||
};
|
||||
} // namespace
|
||||
|
@ -327,7 +327,7 @@ template <> struct llvm::DenseMapInfo<Monomorphization> {
|
|||
//
|
||||
// This generalizes to a full abstract interpretation of the function, but
|
||||
// currently only analyzes a subset of ops.
|
||||
static LogicalResult analyzeInstances(FuncOp func,
|
||||
static LogicalResult analyzeInstances(func::FuncOp func,
|
||||
ArrayRef<ArgInstance> argInstances,
|
||||
BlockAndValueMapping &mapping) {
|
||||
for (auto &argInstance : argInstances)
|
||||
|
@ -351,7 +351,7 @@ static LogicalResult analyzeInstances(FuncOp func,
|
|||
static FailureOr<Monomorphization>
|
||||
createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping,
|
||||
SymbolTable &symbolTable) {
|
||||
auto func = symbolTable.lookup<FuncOp>(op.getCallee());
|
||||
auto func = symbolTable.lookup<func::FuncOp>(op.getCallee());
|
||||
Monomorphization monomorphization;
|
||||
monomorphization.func = func;
|
||||
for (auto operand : llvm::enumerate(op->getOperands())) {
|
||||
|
@ -372,7 +372,7 @@ public:
|
|||
: module(module), symbolTable(module) {}
|
||||
LogicalResult
|
||||
initialize(DenseMap<ClassTypeOp, std::vector<NnModuleOp>> &instances) {
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
for (auto func : module.getOps<func::FuncOp>()) {
|
||||
Monomorphization monomorphization;
|
||||
monomorphization.func = func;
|
||||
bool canTriviallyMonomorphize = true;
|
||||
|
@ -455,7 +455,7 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
|
|||
|
||||
// Verify that `func` conforms to the subset of allowable method bodies
|
||||
// that we can convert.
|
||||
static LogicalResult verifyFuncConformsToSubset(FuncOp func) {
|
||||
static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) {
|
||||
// TODO: Investingate why WalkResult::interrupt() doesn't propagate properly.
|
||||
LogicalResult ret = success();
|
||||
func.walk([&](Block *block) {
|
||||
|
@ -481,7 +481,7 @@ static LogicalResult verifyFuncConformsToSubset(FuncOp func) {
|
|||
static LogicalResult
|
||||
verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
|
||||
MonomorphizationTracker &tracker) {
|
||||
DenseMap<FuncOp, int> numMonomorphizations;
|
||||
DenseMap<func::FuncOp, int> numMonomorphizations;
|
||||
for (auto &monomorphization : tracker.getMonomorphizations()) {
|
||||
numMonomorphizations[monomorphization.func] += 1;
|
||||
}
|
||||
|
@ -489,7 +489,7 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
|
|||
for (auto classType : module.getOps<ClassTypeOp>()) {
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
if (!method.isPrivate()) {
|
||||
if (numMonomorphizations[symbolTable.lookup<FuncOp>(
|
||||
if (numMonomorphizations[symbolTable.lookup<func::FuncOp>(
|
||||
method.function())] > 1) {
|
||||
method.emitError()
|
||||
<< "public function with multiple monomorphizations";
|
||||
|
@ -503,11 +503,10 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
|
|||
|
||||
// Rewrite `func`, given that all values of `NnModuleType` have been mapped in
|
||||
// `mapping` to corresponding global instances.
|
||||
static LogicalResult
|
||||
rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
|
||||
SymbolTable &symbolTable,
|
||||
DenseMap<Monomorphization, FuncOp> &newFuncs,
|
||||
ObjectGraphInfo &objectGraphInfo) {
|
||||
static LogicalResult rewriteMonomorphizedFuncClone(
|
||||
func::FuncOp func, BlockAndValueMapping mapping, SymbolTable &symbolTable,
|
||||
DenseMap<Monomorphization, func::FuncOp> &newFuncs,
|
||||
ObjectGraphInfo &objectGraphInfo) {
|
||||
|
||||
SmallVector<Operation *> toErase;
|
||||
auto handlePrimSetAttr = [&](PrimSetAttrOp op) {
|
||||
|
@ -605,7 +604,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
|||
// static analysis to discover how to monomorphize th eprogram, including
|
||||
// tracking instances through control flow, through get/set attr, etc. We
|
||||
// implement a very simple subset of cases.
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
for (auto func : module.getOps<func::FuncOp>()) {
|
||||
if (failed(verifyFuncConformsToSubset(func)))
|
||||
return failure();
|
||||
}
|
||||
|
@ -637,10 +636,10 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
|||
|
||||
// Step 4: Clone/rewrite functions to implement the necessary
|
||||
// monomorphizations.
|
||||
DenseMap<Monomorphization, FuncOp> newFuncs;
|
||||
DenseMap<Monomorphization, func::FuncOp> newFuncs;
|
||||
int uniquifier = 0;
|
||||
for (auto &monomorphization : tracker.getMonomorphizations()) {
|
||||
auto newFunc = cast<FuncOp>(monomorphization.func->clone());
|
||||
auto newFunc = cast<func::FuncOp>(monomorphization.func->clone());
|
||||
newFuncs[monomorphization] = newFunc;
|
||||
Optional<LinkageInfo> linkageInfo = None;
|
||||
// If it is potentially a method, check its linkage info.
|
||||
|
@ -675,14 +674,14 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
|
|||
}
|
||||
|
||||
// Step 5: Clean up object graph.
|
||||
DenseSet<FuncOp> liveFuncs;
|
||||
DenseSet<func::FuncOp> liveFuncs;
|
||||
for (auto &kv : newFuncs) {
|
||||
liveFuncs.insert(kv.second);
|
||||
}
|
||||
for (auto &op : llvm::make_early_inc_range(module.getOps())) {
|
||||
if (isa<GlobalSlotOp>(&op))
|
||||
continue;
|
||||
if (auto func = dyn_cast<FuncOp>(op)) {
|
||||
if (auto func = dyn_cast<func::FuncOp>(op)) {
|
||||
if (liveFuncs.contains(func))
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -355,7 +355,7 @@ class MaximizeValueSemanticsPass
|
|||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createMaximizeValueSemanticsPass() {
|
||||
return std::make_unique<MaximizeValueSemanticsPass>();
|
||||
}
|
||||
|
|
|
@ -10,9 +10,11 @@
|
|||
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
namespace torch {
|
||||
namespace Torch {
|
||||
|
||||
|
|
|
@ -94,7 +94,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
|||
if (options.optimize) {
|
||||
// Eliminate the PrimTupleIndexOp generated from the
|
||||
// adjustCallingConventions
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// Inline global slots, which for most inference scenarios deletes them.
|
||||
// This also exposes more information to intraprocedural transformations
|
||||
// below like MaximizeValueSemantics and RefineTypes.
|
||||
|
@ -105,14 +105,14 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
|||
}
|
||||
|
||||
// Reduce variants of ops to a smaller set of primitives.
|
||||
pm.addNestedPass<FuncOp>(createReduceOpVariantsPass());
|
||||
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
|
||||
|
||||
if (options.optimize) {
|
||||
// OPT-ONLY: Right now we rely on this to eliminate certain branches that
|
||||
// guard unreachable code that backends can't handle yet, such as lists,
|
||||
// RaiseException, unimplemented tensor ops, and only-used-in-training
|
||||
// operations on `torch.global_slot`'s.
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// OPT-ONLY: We may have deleted some `torch.global_slot.get` /
|
||||
// `torch.global_slot.get` ops, which may have left more
|
||||
// `torch.global_slot`'s unused.
|
||||
|
@ -124,7 +124,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
|||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
|
||||
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
|
||||
|
||||
// Do shape refinement.
|
||||
// This must be run before RefineTypes (which primarily does dtype inference),
|
||||
|
@ -132,7 +132,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
|||
// operand.
|
||||
createTorchShapeRefinementPipeline(pm, options);
|
||||
// Refine types in the program, which mainly means inferring dtypes of ops.
|
||||
pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
|
||||
|
||||
// Propagate to ABI return types the shape/dtype information discovered by
|
||||
// the previous pass. Doing this is ABI-compatible for our backends.
|
||||
|
@ -142,7 +142,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
|||
// This can fold away some branches given the information got from
|
||||
// RefineTypes before doing maximize value sematics which only works with
|
||||
// basic blocks.
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
}
|
||||
|
||||
if (options.optimize) {
|
||||
|
@ -152,9 +152,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
|
|||
// branches that guard unreachable code that backends can't handle yet, such
|
||||
// as lists, RaiseException, unimplemented aten ops, and
|
||||
// only-used-in-training operations on `torch.global_slot`'s.
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
}
|
||||
pm.addNestedPass<FuncOp>(Torch::createDecomposeComplexOpsPass());
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createDecomposeComplexOpsPass());
|
||||
|
||||
// TODO: VerifyTorchBackendContractPass.
|
||||
}
|
||||
|
@ -172,11 +172,11 @@ void mlir::torch::Torch::createTorchShapeRefinementPipeline(
|
|||
// as hard as possible" kind of thing, so it's inherently somewhat brittle.
|
||||
// The idea is to keep strengthening what we do here to support the shape
|
||||
// library. We don't need to support arbitrary programs, thankfully.
|
||||
pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass());
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
|
||||
// Run CSE, then see if we can simplify further.
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
|
||||
|
||||
// Drop shape calculations, leaving behind the shape-refined program.
|
||||
pm.addNestedPass<FuncOp>(Torch::createDropShapeCalculationsPass());
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createDropShapeCalculationsPass());
|
||||
}
|
||||
|
|
|
@ -33,10 +33,10 @@ public:
|
|||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||
op.receiver().getType().cast<NnModuleType>().getClassName());
|
||||
assert(classType && "malformed module -- missing ClassTypeOp");
|
||||
FuncOp func;
|
||||
func::FuncOp func;
|
||||
for (auto method : classType.getOps<MethodOp>()) {
|
||||
if (method.name() == op.name()) {
|
||||
func = symbolTable.lookup<FuncOp>(method.function());
|
||||
func = symbolTable.lookup<func::FuncOp>(method.function());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -207,7 +207,7 @@ public:
|
|||
assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 &&
|
||||
"Torch JIT operators shouldn't have regions or successors");
|
||||
|
||||
Operation *newOp = rewriter.createOperation(state);
|
||||
Operation *newOp = rewriter.create(state);
|
||||
auto tensor =
|
||||
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
|
||||
createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
|
||||
|
@ -276,7 +276,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createReduceOpVariantsPass() {
|
||||
return std::make_unique<ReduceOpVariantsPass>();
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ class RefinePublicReturnPass
|
|||
: public RefinePublicReturnBase<RefinePublicReturnPass> {
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
module.walk([&](FuncOp func) {
|
||||
module.walk([&](func::FuncOp func) {
|
||||
if (func.getVisibility() != SymbolTable::Visibility::Public)
|
||||
return;
|
||||
if (func.isExternal())
|
||||
|
@ -40,7 +40,7 @@ class RefinePublicReturnPass
|
|||
});
|
||||
}
|
||||
|
||||
void rewriteSignature(FuncOp func) {
|
||||
void rewriteSignature(func::FuncOp func) {
|
||||
// Find the unique return op.
|
||||
func::ReturnOp returnOp;
|
||||
WalkResult walkResult = func.walk([&](func::ReturnOp op) {
|
||||
|
@ -90,7 +90,7 @@ class RefinePublicReturnPass
|
|||
returnOp->setOperands(newOperands);
|
||||
|
||||
// Update the function type.
|
||||
auto funcType = func.getType();
|
||||
auto funcType = func.getFunctionType();
|
||||
func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(),
|
||||
ValueRange(newOperands).getTypes()));
|
||||
}
|
||||
|
|
|
@ -1191,7 +1191,7 @@ static bool isSafeToRefineOperandInPlace(OpOperand *use, Type newOperandType) {
|
|||
return operationIsValidWithRefinedType(use, newOperandType);
|
||||
}
|
||||
|
||||
void optimize(FuncOp func, TypeAnalyzer &analyzer) {
|
||||
void optimize(func::FuncOp func, TypeAnalyzer &analyzer) {
|
||||
func.walk([&](Operation *op) {
|
||||
auto convertValuesToMostRefinedType = [&](ValueRange values, OpBuilder &b) {
|
||||
for (Value v : values) {
|
||||
|
@ -1336,7 +1336,7 @@ class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createRefineTypesPass() {
|
||||
return std::make_unique<RefineTypesPass>();
|
||||
}
|
||||
|
|
|
@ -169,7 +169,7 @@ static Value adjustShapeFunctionArg(Value operand, Type desiredType,
|
|||
// from the shape library.
|
||||
static LogicalResult
|
||||
populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands,
|
||||
mlir::FuncOp shapeFunction) {
|
||||
mlir::func::FuncOp shapeFunction) {
|
||||
// Create a call to the shape function in the `shapeCalculation` region.
|
||||
// We will import the callee from the shape library later.
|
||||
OpBuilder b(op.getContext());
|
||||
|
@ -241,7 +241,7 @@ class ReifyShapeCalculationsPass
|
|||
name = name.drop_front(strlen("valsem."));
|
||||
auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
|
||||
auto shapeFunction =
|
||||
shapeLibrary->lookupSymbol<FuncOp>(shapeFunctionName);
|
||||
shapeLibrary->lookupSymbol<func::FuncOp>(shapeFunctionName);
|
||||
if (!shapeFunction)
|
||||
return;
|
||||
neededShapeFunctions.push_back(shapeFunctionName);
|
||||
|
@ -276,7 +276,7 @@ class ReifyShapeCalculationsPass
|
|||
auto symName = worklist.pop_back_val();
|
||||
if (importedFunctions.count(symName))
|
||||
continue;
|
||||
auto func = shapeLibrary->lookupSymbol<mlir::FuncOp>(symName);
|
||||
auto func = shapeLibrary->lookupSymbol<mlir::func::FuncOp>(symName);
|
||||
assert(func && "broken shape library");
|
||||
// Move the shape function from the library to the module this pass
|
||||
// is running on. (this mutates the library, but we re-parse it each time
|
||||
|
|
|
@ -415,7 +415,7 @@ class SimplifyShapeCalculationsPass
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createSimplifyShapeCalculationsPass() {
|
||||
return std::make_unique<SimplifyShapeCalculationsPass>();
|
||||
}
|
||||
|
|
|
@ -45,9 +45,10 @@ struct FuncBackendTypeConversionPass
|
|||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, typeConverter);
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getType()) &&
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
|
||||
patterns, typeConverter);
|
||||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
|
||||
typeConverter.isLegal(&op.getBody());
|
||||
});
|
||||
populateCallOpTypeConversionPattern(patterns, typeConverter);
|
||||
|
@ -155,7 +156,7 @@ struct FinalizingBackendTypeConversionPass
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
|
||||
return std::make_unique<FinalizingBackendTypeConversionPass>();
|
||||
}
|
||||
|
|
|
@ -10,9 +10,12 @@
|
|||
#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
|
||||
#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
|
||||
namespace torch {
|
||||
namespace TorchConversion {
|
||||
|
||||
|
|
|
@ -59,26 +59,27 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
|||
// We do this first as it tends to involve pattern-matching against constants,
|
||||
// (e.g. dimensions which must be constant in a ranked programming model)
|
||||
// and those constants get somewhat obscured by TorchToStd.
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToTMTensorPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
|
||||
pm.addNestedPass<FuncOp>(memref::createExpandOpsPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToStdPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
|
||||
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
|
||||
|
||||
if (options.optimize) {
|
||||
// Clean up any non-canonical code introduced above..
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// Resolve `dim` ops on tensors (which currently live in the `memref`
|
||||
// dialect for some reason -- we don't have memrefs at this level).
|
||||
pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass());
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
memref::createResolveShapedTypeResultDimsPass());
|
||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// Finish the type conversion from `torch` types to the types of the
|
||||
// linalg-on-tensors backend contract.
|
||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
pm.addNestedPass<FuncOp>(
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
|
||||
// Verify that we have lowered to the form that linalg on tensors backends
|
||||
|
@ -93,21 +94,21 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
|||
pm.addPass(
|
||||
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
|
||||
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass());
|
||||
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
|
||||
// Perform rank broadcasting so TosaToLinalg pass works
|
||||
pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
|
||||
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
|
||||
|
||||
if (options.optimize) {
|
||||
// Clean up any non-canonical code introduced above..
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
// The resolution of `dim` ops tends to create identical ops. CSE them.
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// Finish the type conversion from `torch` types to the types of the
|
||||
// TOSA backend contract.
|
||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
pm.addNestedPass<FuncOp>(
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
|
||||
// Verify that we have lowered to the form that TOSA backends
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
|
@ -60,7 +62,7 @@ class VerifyLinalgOnTensorsBackendContractPass
|
|||
ConversionTarget target(*context);
|
||||
|
||||
// Structural operations.
|
||||
target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
|
||||
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
|
||||
opHasLegalTypes);
|
||||
|
||||
target.addDynamicallyLegalOp<GetNextSeedOp>(opHasLegalTypes);
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
|
@ -39,7 +40,7 @@ class VerifyTosaBackendContractPass
|
|||
ConversionTarget target(*context);
|
||||
|
||||
// Structural operations.
|
||||
target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
|
||||
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
|
||||
opHasLegalTypes);
|
||||
// Basic scalar operations.
|
||||
target.addLegalDialect<tosa::TosaDialect>();
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#ifndef REFBACKEND_PASSDETAIL_H
|
||||
#define REFBACKEND_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
|
@ -68,7 +68,7 @@ static bool isArgMemRefTypeValid(Type type) {
|
|||
return false;
|
||||
}
|
||||
|
||||
static void addEmitCInterfaceAttr(FuncOp func) {
|
||||
static void addEmitCInterfaceAttr(func::FuncOp func) {
|
||||
func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
|
||||
}
|
||||
|
||||
|
@ -115,7 +115,7 @@ static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op,
|
|||
}
|
||||
|
||||
static LogicalResult mungeFunction(
|
||||
FuncOp func,
|
||||
func::FuncOp func,
|
||||
std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
|
||||
// Only need to call mungeFunction for functions callable from outside of the
|
||||
// module.
|
||||
|
@ -188,17 +188,17 @@ class MungeCallingConventions
|
|||
auto module = getOperation();
|
||||
OpBuilder b(module.getBodyRegion());
|
||||
std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
for (auto func : module.getOps<func::FuncOp>()) {
|
||||
if (failed(mungeFunction(func, invokedConsumeFuncReturnFuncs)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
// Create FuncOp for consumeFuncReturnFuncs that are used.
|
||||
for (auto &p : invokedConsumeFuncReturnFuncs) {
|
||||
auto consumeFuncReturnFunc =
|
||||
b.create<FuncOp>(module.getLoc(), p.first,
|
||||
FunctionType::get(module.getContext(), p.second, {}),
|
||||
b.getStringAttr("private"));
|
||||
auto consumeFuncReturnFunc = b.create<func::FuncOp>(
|
||||
module.getLoc(), p.first,
|
||||
FunctionType::get(module.getContext(), p.second, {}),
|
||||
b.getStringAttr("private"));
|
||||
addEmitCInterfaceAttr(consumeFuncReturnFunc);
|
||||
}
|
||||
}
|
||||
|
@ -309,7 +309,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
|
||||
return std::make_unique<ExpandOpsForLLVM>();
|
||||
}
|
||||
|
@ -366,7 +366,7 @@ class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::RefBackend::createMungeMemrefCopyPass() {
|
||||
return std::make_unique<MungeMemrefCopy>();
|
||||
}
|
||||
|
@ -390,7 +390,7 @@ class GeneralizeTensorPad
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::RefBackend::createGeneralizeTensorPadPass() {
|
||||
return std::make_unique<GeneralizeTensorPad>();
|
||||
}
|
||||
|
|
|
@ -36,8 +36,8 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
|||
MlirAttribute symNameAttr = mlirStringAttrGet(
|
||||
context, toMlirStringRef(function->qualname().qualifiedName()));
|
||||
MlirOperation func = createMlirOperation(
|
||||
"builtin.func", loc, mlirRegionCreate(),
|
||||
toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)),
|
||||
"func.func", loc, mlirRegionCreate(),
|
||||
toMlirNamedAttribute("function_type", mlirTypeAttrGet(functionType)),
|
||||
toMlirNamedAttribute("sym_name", symNameAttr));
|
||||
std::vector<MlirAttribute> argAttrDicts;
|
||||
for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) {
|
||||
|
|
|
@ -29,7 +29,7 @@ import torch
|
|||
from torch.jit import ScriptFunction
|
||||
|
||||
from torch_mlir import ir
|
||||
from torch_mlir.dialects.builtin import FuncOp
|
||||
from torch_mlir.dialects.func import FuncOp
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
|
||||
|
||||
|
|
|
@ -115,15 +115,15 @@ class RefBackendInvoker:
|
|||
|
||||
|
||||
LOWERING_PIPELINE = ",".join([
|
||||
"builtin.func(refback-generalize-tensor-pad)",
|
||||
"func.func(refback-generalize-tensor-pad)",
|
||||
# Bufferize.
|
||||
"builtin.func(scf-bufferize)",
|
||||
"builtin.func(tm-tensor-bufferize)",
|
||||
"builtin.func(linalg-bufferize)",
|
||||
"func.func(scf-bufferize)",
|
||||
"func.func(tm-tensor-bufferize)",
|
||||
"func.func(linalg-bufferize)",
|
||||
"func-bufferize",
|
||||
"arith-bufferize",
|
||||
"builtin.func(tensor-bufferize)",
|
||||
"builtin.func(finalizing-bufferize)",
|
||||
"func.func(tensor-bufferize)",
|
||||
"func.func(finalizing-bufferize)",
|
||||
# Munge to make it ExecutionEngine compatible.
|
||||
# Specifically, we rewrite calling convention boundaries to be in terms
|
||||
# of unranked memref, and we rewrite the return to actually be a
|
||||
|
@ -135,17 +135,17 @@ LOWERING_PIPELINE = ",".join([
|
|||
# global seed used in stateful rng.
|
||||
"refback-insert-rng-globals",
|
||||
# Lower to LLVM
|
||||
"builtin.func(tm-tensor-to-loops)",
|
||||
"builtin.func(refback-munge-memref-copy)",
|
||||
"builtin.func(convert-linalg-to-loops)",
|
||||
"builtin.func(lower-affine)",
|
||||
"func.func(tm-tensor-to-loops)",
|
||||
"func.func(refback-munge-memref-copy)",
|
||||
"func.func(convert-linalg-to-loops)",
|
||||
"func.func(lower-affine)",
|
||||
"convert-scf-to-cf",
|
||||
"builtin.func(refback-expand-ops-for-llvm)",
|
||||
"builtin.func(arith-expand)",
|
||||
"builtin.func(convert-math-to-llvm)",
|
||||
"func.func(refback-expand-ops-for-llvm)",
|
||||
"func.func(arith-expand)",
|
||||
"func.func(convert-math-to-llvm)",
|
||||
"convert-linalg-to-llvm",
|
||||
"convert-memref-to-llvm",
|
||||
"builtin.func(convert-arith-to-llvm)",
|
||||
"func.func(convert-arith-to-llvm)",
|
||||
"convert-func-to-llvm",
|
||||
"convert-cf-to-llvm",
|
||||
"reconcile-unrealized-casts",
|
||||
|
|
|
@ -38,25 +38,25 @@ class LinalgOnTensorsTosaBackend(TosaBackend):
|
|||
"""
|
||||
|
||||
# TOSA legalization may emit tosa.const() ops. These are legalized
|
||||
# by tosa-to-standard to arith.constants. This mechanical transformation
|
||||
# by tosa-to-arith to arith.constants. This mechanical transformation
|
||||
# must be done prior to TOSA-to-LinAlg so that the latter does not fail.
|
||||
# This is an artifact of legalizations spread across a collection of simple
|
||||
# ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg,
|
||||
# that depend on TOSA as well as TOSA-to-Standard.
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
"builtin.func(tosa-to-standard)",
|
||||
"Lowering TOSA to Standard")
|
||||
"func.func(tosa-to-arith)",
|
||||
"Lowering TOSA to Arith")
|
||||
|
||||
# Named ops must be legalized prior to general tosa-to-linalg
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
"builtin.func(tosa-to-linalg-named)",
|
||||
"func.func(tosa-to-linalg-named)",
|
||||
"Lowering TOSA to Linalg-on-Tensors for Named Ops")
|
||||
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
"builtin.func(tosa-to-linalg)",
|
||||
"func.func(tosa-to-linalg)",
|
||||
"Lowering TOSA to Linalg-on-Tensors")
|
||||
return self.refbackend.compile(imported_module)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @forward
|
||||
builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
|
|
|
@ -42,7 +42,7 @@ torch.nn_module {
|
|||
torch.slot "t1", %t : !torch.tensor
|
||||
torch.slot "t2", %t : !torch.tensor
|
||||
} : !torch.nn.Module<"c">
|
||||
builtin.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor {
|
||||
func.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor {
|
||||
%t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor
|
||||
%t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor
|
||||
%cst = torch.constant.int 1
|
||||
|
@ -63,7 +63,7 @@ torch.nn_module {
|
|||
torch.slot "t1", %t : !torch.tensor
|
||||
torch.slot "t2", %t : !torch.tensor
|
||||
} : !torch.nn.Module<"c">
|
||||
builtin.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) {
|
||||
func.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) {
|
||||
torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
|
||||
torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
|
||||
return
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
|
||||
torch.class_type @c {}
|
||||
%0 = torch.nn_module {
|
||||
// expected-error @+1 {{'builtin.func' op is not allowed inside 'torch.nn_module'}}
|
||||
builtin.func @f()
|
||||
// expected-error @+1 {{'func.func' op is not allowed inside 'torch.nn_module'}}
|
||||
func.func @f()
|
||||
} : !torch.nn.Module<"c">
|
||||
|
||||
// -----
|
||||
|
@ -32,8 +32,8 @@ torch.class_type @c {
|
|||
// -----
|
||||
|
||||
torch.class_type @c {
|
||||
// expected-error @+1 {{'builtin.func' op is not allowed inside `torch.class_type`}}
|
||||
builtin.func @f()
|
||||
// expected-error @+1 {{'func.func' op is not allowed inside `torch.class_type`}}
|
||||
func.func @f()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -60,7 +60,7 @@ torch.class_type @c {
|
|||
torch.method "f", @f
|
||||
}
|
||||
|
||||
builtin.func @f(%arg0: !torch.nn.Module<"c">) {
|
||||
func.func @f(%arg0: !torch.nn.Module<"c">) {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -71,11 +71,11 @@ torch.class_type @c {
|
|||
torch.method "f", @f
|
||||
}
|
||||
|
||||
builtin.func private @f(%arg0: !torch.nn.Module<"c">)
|
||||
func.func private @f(%arg0: !torch.nn.Module<"c">)
|
||||
|
||||
// -----
|
||||
|
||||
builtin.func private @f() {
|
||||
func.func private @f() {
|
||||
return
|
||||
}
|
||||
torch.class_type @c {
|
||||
|
@ -85,7 +85,7 @@ torch.class_type @c {
|
|||
|
||||
// -----
|
||||
|
||||
builtin.func private @f(!torch.nn.Module<"other_c">) {
|
||||
func.func private @f(%arg0: !torch.nn.Module<"other_c">) {
|
||||
return
|
||||
}
|
||||
torch.class_type @c {
|
||||
|
@ -101,21 +101,21 @@ torch.class_type @c {
|
|||
// -----
|
||||
|
||||
// expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}}
|
||||
builtin.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
|
||||
func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{'torch.type_bound' must be TypeAttr}}
|
||||
builtin.func @f(%arg0: i32 {torch.type_bound = 1})
|
||||
func.func @f(%arg0: i32 {torch.type_bound = 1})
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}}
|
||||
builtin.func @f(%arg0: i32 {torch.type_bound = i32})
|
||||
func.func @f(%arg0: i32 {torch.type_bound = i32})
|
||||
|
||||
// -----
|
||||
|
||||
builtin.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
|
||||
func.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
|
||||
// expected-error @+1 {{operand type '!torch.optional<tensor>' and result type '!torch.tensor' are cast incompatible}}
|
||||
%0 = torch.derefine %arg0 : !torch.optional<tensor> to !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
|
@ -123,7 +123,7 @@ builtin.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
|
|||
|
||||
// -----
|
||||
|
||||
builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<tensor> {
|
||||
func.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<tensor> {
|
||||
// expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<tensor>' are cast incompatible}}
|
||||
%0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<tensor>
|
||||
return %0 : !torch.optional<tensor>
|
||||
|
@ -132,11 +132,11 @@ builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !
|
|||
// -----
|
||||
|
||||
// expected-error @+1 {{invalid dtype 'tuple<>' for !torch.tensor type}}
|
||||
builtin.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
|
||||
func.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
|
||||
|
||||
// -----
|
||||
|
||||
builtin.func @torch.tensor() {
|
||||
func.func @torch.tensor() {
|
||||
// Incompatible shape.
|
||||
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
|
||||
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
|
||||
|
@ -145,7 +145,7 @@ builtin.func @torch.tensor() {
|
|||
|
||||
// -----
|
||||
|
||||
builtin.func @torch.tensor() {
|
||||
func.func @torch.tensor() {
|
||||
// Incompatible dtype.
|
||||
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
|
||||
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
|
||||
|
@ -154,7 +154,7 @@ builtin.func @torch.tensor() {
|
|||
|
||||
// -----
|
||||
|
||||
builtin.func @torch.tensor() {
|
||||
func.func @torch.tensor() {
|
||||
// Incompatible type.
|
||||
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
|
||||
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1
|
||||
|
@ -163,7 +163,7 @@ builtin.func @torch.tensor() {
|
|||
|
||||
// -----
|
||||
|
||||
builtin.func @torch.prim.ListConstruct() {
|
||||
func.func @torch.prim.ListConstruct() {
|
||||
%int2 = torch.constant.int 2
|
||||
// expected-error@+1 {{operand types should have the same type as the list contained type}}
|
||||
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<tensor>
|
||||
|
@ -172,7 +172,7 @@ builtin.func @torch.prim.ListConstruct() {
|
|||
|
||||
// -----
|
||||
|
||||
builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
|
||||
func.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
|
||||
%0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
|
||||
// expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}}
|
||||
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32>
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
||||
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
|
||||
// CHECK: return
|
||||
builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>,
|
||||
func.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>,
|
||||
%t1: !torch.vtensor<[1],f64>,
|
||||
%alpha: !torch.float) {
|
||||
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
|
||||
|
@ -25,7 +25,7 @@ builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1
|
|||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
||||
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
|
||||
// CHECK: return
|
||||
builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
|
||||
func.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
|
||||
%t1: !torch.vtensor<[1],f64>,
|
||||
%alpha: !torch.float) {
|
||||
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
|
||||
|
@ -41,7 +41,7 @@ builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
|
|||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
||||
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return
|
||||
builtin.func @tensor_tensor$same_category_zero_rank_wider(
|
||||
func.func @tensor_tensor$same_category_zero_rank_wider(
|
||||
%t0: !torch.vtensor<[1],f32>,
|
||||
%t1: !torch.vtensor<[],f64>,
|
||||
%alpha: !torch.int) {
|
||||
|
@ -58,7 +58,7 @@ builtin.func @tensor_tensor$same_category_zero_rank_wider(
|
|||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
||||
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return
|
||||
builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>,
|
||||
func.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>,
|
||||
%t1: !torch.vtensor<[],f32>,
|
||||
%alpha: !torch.int) {
|
||||
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk>
|
||||
|
@ -73,7 +73,7 @@ builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si
|
|||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
||||
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return
|
||||
builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
|
||||
func.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
|
||||
%t1: !torch.vtensor<[1],f32>,
|
||||
%alpha: !torch.float) {
|
||||
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk>
|
||||
|
@ -89,7 +89,7 @@ builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],
|
|||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
||||
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],f32>
|
||||
// CHECK: return
|
||||
builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
|
||||
func.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
|
||||
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk>
|
||||
return
|
||||
}
|
||||
|
@ -103,7 +103,7 @@ builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>
|
|||
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
|
||||
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
|
||||
// CHECK: return
|
||||
builtin.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
|
||||
func.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
|
||||
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue