mirror of https://github.com/llvm/torch-mlir

llvm: bump tag to e1318078 (#781)

The updated LLVM code includes a patch to create bfloat16 array attributes,
thus enabling a different patch to torch-mlir to flesh out support for the
bfloat16 type.

parent 9ec4712516
commit 9208bf0eb6
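To make the commit message concrete: "bfloat16 array attributes" refers to building dense element attributes whose element type is bf16 through the MLIR C++ API. The sketch below is purely illustrative and is not part of this diff; the function name and the values are invented, and it assumes an `MLIRContext` supplied by the caller.

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

// Illustrative only (not part of this change): build a dense 2x2 bf16
// attribute, the kind of attribute the bumped LLVM revision can now
// construct from bfloat16 element values.
static mlir::DenseElementsAttr makeBf16AttrExample(mlir::MLIRContext *context) {
  mlir::FloatType bf16 = mlir::FloatType::getBF16(context);
  auto tensorType = mlir::RankedTensorType::get({2, 2}, bf16);
  // Four identical elements; DenseElementsAttr handles the bf16 conversion.
  llvm::SmallVector<mlir::Attribute> values(4, mlir::FloatAttr::get(bf16, 1.0));
  return mlir::DenseElementsAttr::get(tensorType, values);
}
```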
@@ -463,7 +463,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
             $_op->getAttrs());
         for (Region &r : $_op->getRegions())
           r.cloneInto(state.addRegion(), bvm);
-        return b.createOperation(state);
+        return b.create(state);
       }]
     >
   ];

@@ -31,9 +31,7 @@ class TMTensor_Op<string mnemonic, list<Trait> traits = []> :
         TMTensorInterface,
         SingleBlockImplicitTerminator<"::mlir::torch::TMTensor::YieldOp">
   ])> {
-  let verifier = [{ return verify$cppClass(*this); }];
-  let printer = [{ return print$cppClass(p, *this); }];
-  let parser = [{ return parse$cppClass(parser, result); }];
+  let hasVerifier = 1;
   code extraTMTensorOpClassDeclaration = [{
     SmallVector<Value> getDestinationOperands(OpBuilder &b) {
       SmallVector<Value> dest(outputs().begin(), outputs().end());
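The three removed `let verifier/printer/parser` lines reflect the ODS change picked up with this LLVM bump: inline verifier bodies are replaced by `let hasVerifier = 1;`, which makes TableGen declare a `verify()` member on the op class that is implemented in C++ (the ScanOp and ScatterOp hunks below are exactly that move). A condensed sketch of the resulting pattern, abbreviated from the ScanOp hunk in this diff:

```cpp
// With `let hasVerifier = 1;` the verifier becomes a member function of the
// generated op class instead of a free function taking the op, so the `op.`
// qualifier on accessors and on emitOpError() disappears. Abbreviated from
// the ScanOp change below; the remaining checks are elided here.
LogicalResult ScanOp::verify() {
  if (getNumInputs() != 1) {
    return emitOpError("expected one input operands");
  }
  return success();
}
```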
@@ -10,6 +10,7 @@
 #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
 #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"

 namespace mlir {

@@ -10,14 +10,15 @@
 #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
 #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"

 namespace mlir {
 namespace torch {
 namespace TMTensor {

-std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();
-std::unique_ptr<OperationPass<FuncOp>> createTMTensorBufferizePass();
+std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorToLoopsPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorBufferizePass();

 void registerPasses();

@@ -13,12 +13,12 @@
 include "mlir/Pass/PassBase.td"

 def TMTensorToLoops :
-    Pass<"tm-tensor-to-loops", "FuncOp"> {
+    Pass<"tm-tensor-to-loops", "func::FuncOp"> {
   let summary = "Convert TMTensor ops to loops and Linalg ops.";
   let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
 }

-def TMTensorBufferize : Pass<"tm-tensor-bufferize", "FuncOp"> {
+def TMTensorBufferize : Pass<"tm-tensor-bufferize", "func::FuncOp"> {
   let summary = "Bufferize the TMTensor dialect";
   let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()";
 }
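Most of the remaining hunks are mechanical fallout of the upstream split of `FuncOp` into the `func` dialect: headers gain an include of `mlir/Dialect/Func/IR/FuncOps.h`, C++ signatures say `func::FuncOp`, and TableGen `Pass<...>` definitions use the qualified `"func::FuncOp"` string. A hedged sketch of a function-scoped pass after the rename (the pass and its name are made up; only the `func::FuncOp` anchoring is the point):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

// Sketch only: a function-scoped pass after the func-dialect split.
// `ExamplePass` is hypothetical; real passes in this repo are generated from
// the TableGen definitions shown in these hunks.
namespace {
struct ExamplePass
    : public mlir::PassWrapper<ExamplePass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  void runOnOperation() override {
    mlir::func::FuncOp func = getOperation();
    (void)func; // a real pass would rewrite the function body here
  }
};
} // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> createExamplePass() {
  return std::make_unique<ExamplePass>();
}
```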
@@ -88,34 +88,34 @@ OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
 // ScanOp
 //===----------------------------------------------------------------------===//

-static LogicalResult verifyScanOp(ScanOp op) {
-  if (op.getNumInputs() != 1) {
-    return op.emitOpError("expected one input operands");
+LogicalResult ScanOp::verify() {
+  if (getNumInputs() != 1) {
+    return emitOpError("expected one input operands");
   }
-  if (op.getNumOutputs() != 2) {
-    return op.emitOpError("expected two output operands");
+  if (getNumOutputs() != 2) {
+    return emitOpError("expected two output operands");
   }
-  if (!op.input().getType().isa<ShapedType>()) {
-    return op.emitOpError("expected first input element type to be shaped");
+  if (!input().getType().isa<ShapedType>()) {
+    return emitOpError("expected first input element type to be shaped");
   }
-  auto accumulatorType = op.accumulator().getType().cast<ShapedType>();
-  auto inputType = op.input().getType().cast<ShapedType>();
-  auto outputType = op.output().getType().cast<ShapedType>();
+  auto accumulatorType = accumulator().getType().cast<ShapedType>();
+  auto inputType = input().getType().cast<ShapedType>();
+  auto outputType = output().getType().cast<ShapedType>();
   ArrayRef<int64_t> inputShapes = inputType.getShape();
   ArrayRef<int64_t> outputShapes = outputType.getShape();
   if (accumulatorType.getElementType() != inputType.getElementType()) {
-    return op.emitOpError(
+    return emitOpError(
         "expected input/accumulator element types to be identical");
   }
   ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
   int64_t accumulatorRank = accumulatorType.getRank();
   if (accumulatorRank != inputType.getRank() - 1) {
-    return op.emitOpError(
+    return emitOpError(
         "expected accumulator rank to be equal to input rank - 1");
   }
   SmallVector<int64_t> expectedAccumulatorShape;
   for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
-    if (i != op.dimension())
+    if (i != dimension())
       expectedAccumulatorShape.push_back(inputShapes[i]);
   }
   if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),

@@ -124,14 +124,13 @@ static LogicalResult verifyScanOp(ScanOp op) {
                          std::get<1>(s) != ShapedType::kDynamicSize &&
                          std::get<0>(s) != std::get<1>(s);
                    })) {
-    return op.emitOpError("incompatible input/accumulator shapes");
+    return emitOpError("incompatible input/accumulator shapes");
   }
   if (inputType.getElementType() != outputType.getElementType()) {
-    return op.emitOpError(
-        "expected input/output element types to be identical");
+    return emitOpError("expected input/output element types to be identical");
   }
   if (inputShapes.size() != outputShapes.size()) {
-    return op.emitOpError("expected input/output to have identical ranks");
+    return emitOpError("expected input/output to have identical ranks");
   }
   if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
                    [](std::tuple<int64_t, int64_t> s) {

@@ -139,7 +138,7 @@ static LogicalResult verifyScanOp(ScanOp op) {
                          std::get<1>(s) != ShapedType::kDynamicSize &&
                          std::get<0>(s) != std::get<1>(s);
                    })) {
-    return op.emitOpError("incompatible input/output shapes");
+    return emitOpError("incompatible input/output shapes");
   }
   return success();
 }

@@ -232,11 +231,11 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
       });

   auto &srcBlock = region().front();
-  Region &region = scfIf.getElseRegion();
+  Region &thisRegion = scfIf.getElseRegion();
   BlockAndValueMapping bvm;
   {
     OpBuilder::InsertionGuard guard(b);
-    auto &block = region.front();
+    auto &block = thisRegion.front();
     b.setInsertionPointToEnd(&block);
     for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
       bvm.map(std::get<0>(it), std::get<1>(it));

@@ -275,48 +274,47 @@ LogicalResult ScanOp::fold(ArrayRef<Attribute>,
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
-static LogicalResult verifyScatterOp(ScatterOp op) {
-  if (op.inputs().size() != 2) {
-    return op.emitOpError("expected two input operands");
+LogicalResult ScatterOp::verify() {
+  if (inputs().size() != 2) {
+    return emitOpError("expected two input operands");
   }
-  if (op.outputs().size() != 1) {
-    return op.emitOpError("expected one output operand");
+  if (outputs().size() != 1) {
+    return emitOpError("expected one output operand");
   }
   auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
     return t1.getShape()[dim] == t2.getShape()[dim];
   };

-  auto indicesType = op.getIndicesType();
+  auto indicesType = getIndicesType();
   if (indicesType.getRank() != 2 ||
       !indicesType.getElementType().isInteger(32)) {
-    return op.emitOpError(
-        "expected indices to be of rank 2 of i32 element type");
+    return emitOpError("expected indices to be of rank 2 of i32 element type");
   }
-  auto indexDepth = op.getIndexDepth();
+  auto indexDepth = getIndexDepth();
   if (indexDepth == ShapedType::kDynamicSize) {
-    return op.emitOpError("expected index depth is static");
+    return emitOpError("expected index depth is static");
   }

   // The first dimension of the indices should match the first dimension of the
   // output. They indicate to the number of updates.
-  auto updateType = op.getUpdateType();
+  auto updateType = getUpdateType();
   if (updateType.getRank() < 1) {
-    return op.emitOpError("expected update value to be at least rank 1");
+    return emitOpError("expected update value to be at least rank 1");
   }
   if (!checkDimensionsMatch(indicesType, updateType, 0)) {
-    return op.emitOpError(
+    return emitOpError(
         "mismatch in shape of indices and update value at dim#0");
   }
-  auto originalType = op.getOriginalType();
+  auto originalType = getOriginalType();
   if (updateType.getRank() - 1 > originalType.getRank()) {
-    return op.emitOpError(
+    return emitOpError(
         "update value rank exceeds the rank of the original value");
   }

   // indexDepth + update dims should cover the original dims. The first dim of
   // update is the number of updates.
   if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
-    return op.emitOpError(
+    return emitOpError(
         "index depth and update value does not cover rank of original value");
   }

@@ -331,7 +329,7 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
     int64_t updateDim = std::get<1>(it);
     if (updateType.getDimSize(updateDim) !=
         originalType.getDimSize(originalDim)) {
-      return op.emitOpError("mismatch in shape of update value dim#")
+      return emitOpError("mismatch in shape of update value dim#")
             << updateDim << " and original value at dim#" << originalDim;
     }
   }

@@ -345,36 +343,36 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
     int64_t updateDim = std::get<1>(it);
     if (updateType.getDimSize(updateDim) >
         originalType.getDimSize(originalDim)) {
-      return op.emitOpError("indexed shape of update value dim#")
+      return emitOpError("indexed shape of update value dim#")
            << updateDim << " exceeds original value at dim#" << originalDim
            << " " << updateType.getDimSize(updateDim) << " "
            << originalType.getDimSize(originalDim);
     }
   }

-  Region &region = op.region();
-  Block *body = &region.front();
+  Region &thisRegion = region();
+  Block *body = &thisRegion.front();
   if (body->getNumArguments() != 2) {
-    return op.emitOpError("expected region to have two arguments");
+    return emitOpError("expected region to have two arguments");
   }
   Type arg0Type = body->getArgument(0).getType();
   Type arg1Type = body->getArgument(1).getType();
   if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
-    return op.emitOpError(
+    return emitOpError(
        "expected region to have scalar argument of integer or float types");
   }
   if (arg0Type != updateType.getElementType()) {
-    return op.emitOpError("mismatch in argument 0 of region ")
+    return emitOpError("mismatch in argument 0 of region ")
          << arg0Type << " and element type of update value "
          << updateType.getElementType();
   }
   if (arg1Type != originalType.getElementType()) {
-    return op.emitOpError("mismatch in argument 1 of region ")
+    return emitOpError("mismatch in argument 1 of region ")
          << arg1Type << " and element type of original value "
          << originalType.getElementType();
   }
   if (arg0Type != arg1Type) {
-    return op.emitOpError("mismatch in region argument types ")
+    return emitOpError("mismatch in region argument types ")
          << arg0Type << " and " << arg1Type;
   }
   auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator());

@@ -455,7 +453,8 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
       Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
       Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);

-      if (starts[i]) cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
+      if (starts[i])
+        cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
       starts[i] = cast;
     }

@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Func/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinDialect.h"

@@ -150,7 +151,7 @@ struct TMTensorBufferizePass
 };
 } // namespace

-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 torch::TMTensor::createTMTensorBufferizePass() {
   return std::make_unique<TMTensorBufferizePass>();
 }

@@ -7,6 +7,7 @@
 //
 //===----------------------------------------------------------------------===//

+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"

@@ -111,7 +112,7 @@ struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
 };
 } // namespace

-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 torch::TMTensor::createTMTensorToLoopsPass() {
   return std::make_unique<TMTensorToLoopsPass>();
 }

@@ -1 +1 @@
-Subproject commit 8361c5da30588d3d4a48eae648f53be1feb5cfad
+Subproject commit e1318078a4e160eb723bcbcfcdcc9a1b618f7067
@@ -16,17 +16,17 @@ include "mlir/Pass/PassBase.td"
 // Torch conversions
 //===----------------------------------------------------------------------===//

-def ConvertTorchToStd : Pass<"convert-torch-to-std", "FuncOp"> {
+def ConvertTorchToStd : Pass<"convert-torch-to-std", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to Std ops";
   let constructor = "mlir::torch::createConvertTorchToStdPass()";
 }

-def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "FuncOp"> {
+def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to SCF ops";
   let constructor = "mlir::torch::createConvertTorchToSCFPass()";
 }

-def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
+def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to Linalg ops";
   let description = [{
     Convert ATen ops to linalg ops.

@@ -105,7 +105,7 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
   let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
 }

-def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
+def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
   let summary = "Convert Torch ops to TOSA ops";
   let description = [{
     This pass assumes that TOSA ops are responsible for emitting error

@@ -114,7 +114,7 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
   let constructor = "mlir::torch::createConvertTorchToTosaPass()";
 }

-def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "FuncOp"> {
+def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
   let description = [{
     Convert ATen ops to tmtensor/linalg ops.
@@ -10,12 +10,13 @@
 #ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
 #define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 #include <memory>

 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToLinalgPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
 }
 } // namespace mlir

@@ -10,11 +10,12 @@
 #ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
 #define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"

 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToSCFPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();
 }
 } // namespace mlir

@@ -10,12 +10,13 @@
 #ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
 #define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 #include <memory>

 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToStdPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToStdPass();
 }
 } // namespace mlir

@@ -10,11 +10,12 @@
 #ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
 #define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"

 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTMTensorPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
 }
 } // namespace mlir

@@ -10,12 +10,13 @@
 #ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
 #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 #include <memory>

 namespace mlir {
 namespace torch {
-std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTosaPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
 }
 } // namespace mlir
@@ -10,6 +10,8 @@
 #ifndef TORCH_TYPES
 #define TORCH_TYPES

+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/DialectBase.td"
 include "torch-mlir/Dialect/Torch/IR/TorchBase.td"

 //===----------------------------------------------------------------------===//

@@ -24,28 +26,8 @@ class Torch_Type<string name, string typeMnemonic,

 class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type<name, typeMnemonic> {
   let parameters = (ins "::mlir::Type":$containedType);
+  let hasCustomAssemblyFormat = 1;

-  let printer = [{
-    $_printer << "<";
-    // Print the contained type without the `!torch.` prefix.
-    printTorchDialectType(getImpl()->containedType, $_printer);
-    $_printer << ">";
-  }];
-
-  let parser = [{
-    if ($_parser.parseLess())
-      return Type();
-
-    // Parse the contained type, but forward directly to our internal parsing
-    // of `torch` dialect types, so that we can parse nested types without
-    // the `!torch.` prefix.
-    Type containedType = parseTorchDialectType($_parser);
-    if (!containedType)
-      return Type();
-    if ($_parser.parseGreater())
-      return Type();
-    return get($_ctxt, containedType);
-  }];
   let builders = [
     TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
       return Base::get(containedType.getContext(), containedType);

@@ -59,23 +41,7 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
     Represents an instance of a `torch.nn.Module` with the given `className`.
   }];
   let parameters = (ins StringRefParameter<"class name">:$className);
-
-  let printer = [{
-    $_printer << "<\"";
-    llvm::printEscapedString(getImpl()->className, $_printer.getStream());
-    $_printer << "\">";
-  }];
-
-  let parser = [{
-    if ($_parser.parseLess())
-      return Type();
-    std::string className;
-    if ($_parser.parseOptionalString(&className))
-      return Type();
-    if ($_parser.parseGreater())
-      return Type();
-    return get($_ctxt, className);
-  }];
+  let hasCustomAssemblyFormat = 1;
 }

 // For standard ArrayRefs, which require allocation.

@@ -186,6 +152,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
     "::mlir::Type":$optionalDtype
   );
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
   string extraBaseClassDeclaration = [{
   }];
 }
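All of the TorchTypes.td hunks above and below are the same migration: inline `let printer = [{...}]` / `let parser = [{...}]` blocks are dropped in favor of `let hasCustomAssemblyFormat = 1;`, which declares `print`/`parse` hooks on the generated type class; the C++ bodies land in the TorchDialect.cpp hunk near the end of this diff. For reference, this is the shape of the hook (the body is the OptionalType printer added later in this same commit):

```cpp
// Hook declared by `let hasCustomAssemblyFormat = 1;` and implemented in
// TorchDialect.cpp; printTorchDialectType is the existing helper there that
// prints nested torch types without the `!torch.` prefix.
void OptionalType::print(AsmPrinter &printer) const {
  printer << "<";
  // Print the contained type without the `!torch.` prefix.
  printTorchDialectType(getImpl()->containedType, printer);
  printer << ">";
}
```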
@@ -243,6 +210,7 @@ def Torch_TupleType : Torch_Type<"Tuple", "tuple"> {
   let parameters = (ins
     ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
   );
+  let hasCustomAssemblyFormat = 1;
 }

 def Torch_UnionType : Torch_Type<"Union", "union"> {

@@ -261,6 +229,7 @@ def Torch_UnionType : Torch_Type<"Union", "union"> {
   let parameters = (ins
     ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
   );
+  let hasCustomAssemblyFormat = 1;
 }

 def Torch_DeviceType : Torch_Type<"Device", "Device"> {

@@ -367,30 +336,7 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
   let description = [{
     Torch Dict type with key and value type.
   }];
-
-  let printer = [{
-    $_printer << "<";
-    printTorchDialectType(getImpl()->keyType, $_printer);
-    $_printer << ", ";
-    printTorchDialectType(getImpl()->valueType, $_printer);
-    $_printer << ">";
-  }];
-
-  let parser = [{
-    if ($_parser.parseLess())
-      return Type();
-    Type keyType = parseTorchDialectType($_parser);
-    if (!keyType)
-      return Type();
-    if ($_parser.parseComma())
-      return Type();
-    Type valueType = parseTorchDialectType($_parser);
-    if (!valueType)
-      return Type();
-    if ($_parser.parseGreater())
-      return Type();
-    return get($_ctxt, keyType, valueType);
-  }];
+  let hasCustomAssemblyFormat = 1;
   let builders = [
     TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
                                         "::mlir::Type":$valueType), [{

@@ -10,11 +10,14 @@
 #ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
 #define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"

 #include <memory>

 namespace mlir {
+class ModuleOp;
+
 namespace torch {
 namespace Torch {

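Besides the new include, this header (and the TorchConversion and RefBackend pass headers further down) now forward-declares `ModuleOp`. A plausible reading, not spelled out in the commit, is that after the func-dialect split these headers no longer pick up `ModuleOp` transitively, and a forward declaration is enough for the `OperationPass<ModuleOp>` factory signatures they declare. Condensed from the hunk above, using two declarations that appear in this diff:

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
class ModuleOp; // forward declaration suffices for the signatures below

namespace torch {
namespace Torch {
// Module-scoped and function-scoped pass factories side by side, as declared
// in this header; the definitions live in .cpp files that include BuiltinOps.h.
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
} // namespace Torch
} // namespace torch
} // namespace mlir
```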
@@ -48,25 +51,26 @@ void createTorchShapeRefinementPipeline(

 std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();

-std::unique_ptr<OperationPass<FuncOp>> createRefineTypesPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();

 std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();

-std::unique_ptr<OperationPass<FuncOp>> createReduceOpVariantsPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createReduceOpVariantsPass();

-std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();

 std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();

-std::unique_ptr<OperationPass<FuncOp>> createDecomposeComplexOpsPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeComplexOpsPass();

 std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();

 std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();

-std::unique_ptr<OperationPass<FuncOp>> createSimplifyShapeCalculationsPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createSimplifyShapeCalculationsPass();

-std::unique_ptr<OperationPass<FuncOp>> createDropShapeCalculationsPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();

 StringRef getShapeLibrary();


@@ -126,7 +126,7 @@ def AdjustCallingConventions
   }];
 }

-def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
+def RefineTypes : Pass<"torch-refine-types", "func::FuncOp"> {
   let summary = "Refine types";
   let constructor = "mlir::torch::Torch::createRefineTypesPass()";
   let description = [{

@@ -149,7 +149,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
   }];
 }

-def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
+def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
   let summary = "Reduces variants of ops to a smaller set of ops.";
   let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
   let description = [{

@@ -165,7 +165,7 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
   }];
 }

-def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> {
+def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "func::FuncOp"> {
   let summary = "Use value-semantic tensors where possible.";
   let description = [{
     Use value-semantic tensors where possible to make the program more

@@ -215,7 +215,7 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
   }];
 }

-def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> {
+def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
   let summary = "Decompose complicated torch operations";
   let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
   let description = [{

@@ -238,7 +238,7 @@ def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp">
   }];
 }

-def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncOp"> {
+def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func::FuncOp"> {
   let summary = "Simplify reified shape calculations.";
   let constructor = "mlir::torch::Torch::createSimplifyShapeCalculationsPass()";
   let description = [{

@@ -246,7 +246,7 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncO
   }];
 }

-def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "FuncOp"> {
+def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"> {
   let summary = "Drop reified shape calculations.";
   let constructor = "mlir::torch::Torch::createDropShapeCalculationsPass()";
   let description = [{
@@ -10,12 +10,15 @@
 #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
 #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

 #include <memory>

 namespace mlir {
+class ModuleOp;
+
 namespace torch {
 namespace TorchConversion {

@@ -36,7 +39,7 @@ createVerifyInvariantsBeforeBackendLoweringPass();

 std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 createFinalizingBackendTypeConversionPass();

 std::unique_ptr<OperationPass<ModuleOp>>

@@ -43,7 +43,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu
 }

 def FinalizingBackendTypeConversion
-    : Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
+    : Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> {
   let summary = "Finalizes a partial conversion to builtin tensors";
   let constructor =
       "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()";

@@ -10,10 +10,13 @@
 #ifndef TORCHMLIR_REFBACKEND_PASSES_H
 #define TORCHMLIR_REFBACKEND_PASSES_H

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"

 namespace mlir {
+class ModuleOp;
+
 namespace torch {
 namespace RefBackend {

@@ -22,13 +25,13 @@ void registerRefBackendPasses();

 std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();

-std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();

 std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();

-std::unique_ptr<OperationPass<FuncOp>> createMungeMemrefCopyPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();

-std::unique_ptr<OperationPass<FuncOp>> createGeneralizeTensorPadPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
 } // namespace RefBackend
 } // namespace torch
 } // namespace mlir

@@ -24,18 +24,18 @@ def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> {
   let dependentDialects = ["memref::MemRefDialect"];
 }

-def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "FuncOp"> {
+def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> {
   let summary = "Expand ops into more primitive ops before LLVM lowering.";
   let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
 }

-def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "FuncOp"> {
+def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
   let summary = "Munge memref.copy to linalg.copy";
   let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
   let dependentDialects = ["memref::MemRefDialect"];
 }

-def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "FuncOp"> {
+def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
   let summary = "Convert tensor.pad to linalg ops";
   let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
 }
@@ -10,6 +10,7 @@
 #ifndef TORCHMLIR_CONVERSION_PASSDETAIL_H
 #define TORCHMLIR_CONVERSION_PASSDETAIL_H

+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"

 namespace mlir {

@@ -88,7 +88,7 @@ public:
 };
 } // namespace

-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToLinalgPass() {
   return std::make_unique<ConvertTorchToLinalg>();
 }

@@ -91,7 +91,7 @@ public:
 };
 } // namespace

-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToSCFPass() {
   return std::make_unique<ConvertTorchToSCF>();
 }

@@ -10,6 +10,7 @@
 #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"

 #include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"

@@ -221,7 +222,7 @@ public:
 };
 } // namespace

-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToStdPass() {
   return std::make_unique<ConvertTorchToStd>();
 }

@@ -10,8 +10,10 @@
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"

 #include "../PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/MLIRContext.h"
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"

@@ -566,7 +568,7 @@ public:
 };
 } // namespace

-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToTMTensorPass() {
   return std::make_unique<ConvertTorchToTMTensor>();
 }

@@ -3170,7 +3170,7 @@ public:
 };
 } // namespace

-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToTosaPass() {
   return std::make_unique<ConvertTorchToTosa>();
 }

@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

@ -8,6 +8,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "mlir/Transforms/InliningUtils.h"
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
|
@ -124,7 +125,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
||||||
unsigned argIndex,
|
unsigned argIndex,
|
||||||
NamedAttribute namedAttr) {
|
NamedAttribute namedAttr) {
|
||||||
if (namedAttr.getName().getValue() == "torch.type_bound") {
|
if (namedAttr.getName().getValue() == "torch.type_bound") {
|
||||||
auto func = dyn_cast<FuncOp>(op);
|
auto func = dyn_cast<func::FuncOp>(op);
|
||||||
if (!func)
|
if (!func)
|
||||||
return op->emitError() << "'torch.type_bound' must be attached to a func";
|
return op->emitError() << "'torch.type_bound' must be attached to a func";
|
||||||
TypeAttr attr = namedAttr.getValue().dyn_cast<TypeAttr>();
|
TypeAttr attr = namedAttr.getValue().dyn_cast<TypeAttr>();
|
||||||
|
@ -134,7 +135,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
||||||
if (!type)
|
if (!type)
|
||||||
return op->emitError() << "'torch.type_bound' must be of "
|
return op->emitError() << "'torch.type_bound' must be of "
|
||||||
"!torch.tensor/!torch.vtensor type";
|
"!torch.tensor/!torch.vtensor type";
|
||||||
if (!func.getType().getInput(argIndex).isa<BaseTensorType>())
|
if (!func.getFunctionType().getInput(argIndex).isa<BaseTensorType>())
|
||||||
return op->emitError() << "'torch.type_bound' must be attached to an "
|
return op->emitError() << "'torch.type_bound' must be attached to an "
|
||||||
"argument of !torch.tensor/!torch.vtensor type";
|
"argument of !torch.tensor/!torch.vtensor type";
|
||||||
return success();
|
return success();
|
||||||
|
@ -177,3 +178,100 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// OptionalType and ListType
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void OptionalType::print(AsmPrinter &printer) const {
|
||||||
|
printer << "<";
|
||||||
|
  // Print the contained type without the `!torch.` prefix.
  printTorchDialectType(getImpl()->containedType, printer);
  printer << ">";
}

void ListType::print(AsmPrinter &printer) const {
  printer << "<";
  // Print the contained type without the `!torch.` prefix.
  printTorchDialectType(getImpl()->containedType, printer);
  printer << ">";
}

Type OptionalType::parse(AsmParser &odsParser) {
  if (odsParser.parseLess())
    return Type();

  // Parse the contained type, but forward directly to our internal parsing
  // of `torch` dialect types, so that we can parse nested types without
  // the `!torch.` prefix.
  Type containedType = parseTorchDialectType(odsParser);
  if (!containedType)
    return Type();
  if (odsParser.parseGreater())
    return Type();
  return get(odsParser.getContext(), containedType);
}

Type ListType::parse(AsmParser &odsParser) {
  if (odsParser.parseLess())
    return Type();

  // Parse the contained type, but forward directly to our internal parsing
  // of `torch` dialect types, so that we can parse nested types without
  // the `!torch.` prefix.
  Type containedType = parseTorchDialectType(odsParser);
  if (!containedType)
    return Type();
  if (odsParser.parseGreater())
    return Type();
  return get(odsParser.getContext(), containedType);
}

//===----------------------------------------------------------------------===//
// DictType
//===----------------------------------------------------------------------===//

void DictType::print(AsmPrinter &printer) const {
  printer << "<";
  printTorchDialectType(getImpl()->keyType, printer);
  printer << ", ";
  printTorchDialectType(getImpl()->valueType, printer);
  printer << ">";
}

Type DictType::parse(AsmParser &odsParser) {
  if (odsParser.parseLess())
    return Type();
  Type keyType = parseTorchDialectType(odsParser);
  if (!keyType)
    return Type();
  if (odsParser.parseComma())
    return Type();
  Type valueType = parseTorchDialectType(odsParser);
  if (!valueType)
    return Type();
  if (odsParser.parseGreater())
    return Type();
  return get(odsParser.getContext(), keyType, valueType);
}

//===----------------------------------------------------------------------===//
// NnModuleType
//===----------------------------------------------------------------------===//

void NnModuleType::print(AsmPrinter &printer) const {
  printer << "<\"";
  llvm::printEscapedString(getImpl()->className, printer.getStream());
  printer << "\">";
}

Type NnModuleType::parse(AsmParser &odsParser) {
  if (odsParser.parseLess())
    return Type();
  std::string className;
  if (odsParser.parseOptionalString(&className))
    return Type();
  if (odsParser.parseGreater())
    return Type();
  return get(odsParser.getContext(), className);
}
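The parse methods above all funnel through `parseTorchDialectType`, so nested torch types round-trip without repeating the `!torch.` prefix. A minimal sketch of that round trip, assuming the usual generated `get(MLIRContext *, ...)` builders visible in the parse code above and the standard torch-mlir header paths (the function itself is illustrative, not part of the patch):

```cpp
#include "mlir/IR/MLIRContext.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch::Torch;

// Build !torch.list<optional<bool>> programmatically and print it back.
// The nested type prints as `optional<bool>`, not `!torch.optional<bool>`,
// because ListType::print forwards to printTorchDialectType.
void roundTripExample(MLIRContext *context) {
  context->getOrLoadDialect<TorchDialect>();
  Type contained = OptionalType::get(context, BoolType::get(context));
  Type list = ListType::get(context, contained);
  list.dump(); // expected to print: !torch.list<optional<bool>>
}
```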
@@ -9,6 +9,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -99,7 +100,7 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
 
 LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto func =
-      symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, functionAttr());
+      symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, functionAttr());
   if (!func)
     return emitError() << "'@" << function()
                        << "' does not reference a valid function";
@@ -112,8 +113,8 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
                           "merely declared)";
   auto expectedReceiverArgType = NnModuleType::get(
       getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
-  if (func.getType().getNumInputs() == 0 ||
-      func.getType().getInput(0) != expectedReceiverArgType) {
+  if (func.getFunctionType().getNumInputs() == 0 ||
+      func.getFunctionType().getInput(0) != expectedReceiverArgType) {
     return emitError() << "the referenced function '" << function()
                        << "' must have a first argument of type "
                        << expectedReceiverArgType;
@@ -278,7 +279,7 @@ ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
   Region *elseRegion = result.addRegion();
 
   auto &builder = parser.getBuilder();
-  OpAsmParser::OperandType cond;
+  OpAsmParser::UnresolvedOperand cond;
   Type boolType = builder.getType<Torch::BoolType>();
   if (parser.parseOperand(cond) ||
       parser.resolveOperand(cond, boolType, result.operands))
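Most of the remaining hunks repeat the same two mechanical changes: symbol lookups and pass anchors now name `func::FuncOp`, and function signatures are read through `getFunctionType()` instead of the old `getType()` accessor. A minimal sketch of the post-migration idiom, using only calls that appear in the hunks above (the helper name and diagnostic text are illustrative):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/SymbolTable.h"

using namespace mlir;

// Resolve a callee symbol to a func::FuncOp and inspect its signature.
// Returns failure if the symbol is missing or the function takes no arguments.
static LogicalResult checkCallee(Operation *fromOp,
                                 SymbolTableCollection &symbolTable,
                                 SymbolRefAttr callee) {
  auto func =
      symbolTable.lookupNearestSymbolFrom<func::FuncOp>(fromOp, callee);
  if (!func)
    return fromOp->emitError() << "callee does not reference a valid function";
  // The signature is now reached via getFunctionType() (FunctionOpInterface).
  FunctionType type = func.getFunctionType();
  if (type.getNumInputs() == 0)
    return fromOp->emitError() << "expected at least one argument";
  return success();
}
```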
@@ -35,7 +35,7 @@ ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
                                        OperationState &result, int numOperands,
                                        int numResults) {
   llvm::SMLoc loc = parser.getCurrentLocation();
-  SmallVector<OpAsmParser::OperandType> operands;
+  SmallVector<OpAsmParser::UnresolvedOperand> operands;
   if (parser.parseOperandList(operands, /*requiredOperandCount=*/numOperands))
     return failure();
   if (parser.parseOptionalAttrDict(result.attributes))
@@ -31,11 +31,12 @@ using namespace mlir::torch::Torch;
 using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>;
 
 namespace {
-class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
+class AdjustCallingConventionForFunc
+    : public OpConversionPattern<func::FuncOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(FuncOp func, OpAdaptor adaptor,
+  matchAndRewrite(func::FuncOp func, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *context = func.getContext();
     auto typeBoundIdent = StringAttr::get(context, "torch.type_bound");
@@ -70,7 +71,7 @@ public:
                                             typeConverter);
 
     SmallVector<Type> newResultTypes;
-    for (auto type : func.getType().getResults()) {
+    for (auto type : func.getFunctionType().getResults()) {
       if (auto none = type.dyn_cast<Torch::NoneType>()) {
         continue;
       }
@@ -186,7 +187,7 @@ public:
 };
 } // namespace
 
-static LogicalResult adjustCallingConventions(FuncOp func,
+static LogicalResult adjustCallingConventions(func::FuncOp func,
                                               TypeBoundMap &typeBoundMap) {
   MLIRContext *context = func.getContext();
   RewritePatternSet patterns(context);
@@ -217,7 +218,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
   patterns.add<AdjustCallingConventionForReturn>(typeConverter, context);
 
   ConversionTarget target(*context);
-  target.addDynamicallyLegalOp<FuncOp>([](FuncOp func) {
+  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp func) {
     for (int i = 0, e = func.getNumArguments(); i != e; i++) {
       if (func.getArgAttr(i, "torch.type_bound"))
         return false;
@@ -225,7 +226,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
         return false;
     }
     for (int i = 0, e = func.getNumResults(); i != e; i++) {
-      if (func.getType().getResults()[i].isa<Torch::NoneType>())
+      if (func.getFunctionType().getResults()[i].isa<Torch::NoneType>())
        return false;
     }
     return true;
@@ -266,7 +267,7 @@ class AdjustCallingConventionsPass
   void runOnOperation() override {
     auto module = getOperation();
     TypeBoundMap typeBoundMap;
-    for (auto func : module.getOps<FuncOp>()) {
+    for (auto func : module.getOps<func::FuncOp>()) {
       for (int i = 0, e = func.getNumArguments(); i != e; i++) {
         auto typeBoundAttr =
             func.getArgAttrOfType<TypeAttr>(i, "torch.type_bound");
@@ -275,7 +276,7 @@ class AdjustCallingConventionsPass
           typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue();
       }
     }
-    for (auto func : module.getOps<FuncOp>()) {
+    for (auto func : module.getOps<func::FuncOp>()) {
       if (failed(adjustCallingConventions(func, typeBoundMap)))
         return signalPassFailure();
     }
@@ -1781,7 +1781,7 @@ class DecomposeComplexOpsPass
   }
 };
 } // namespace
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::Torch::createDecomposeComplexOpsPass() {
   return std::make_unique<DecomposeComplexOpsPass>();
 }
@@ -54,7 +54,7 @@ class DropShapeCalculationsPass
     patterns.insert<DropShapeCalculateOp>(context);
     ConversionTarget target(*context);
     target.addIllegalOp<ShapeCalculateOp>();
-    target.addLegalOp<FuncOp>();
+    target.addLegalOp<func::FuncOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
@@ -64,7 +64,7 @@ class DropShapeCalculationsPass
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::Torch::createDropShapeCalculationsPass() {
   return std::make_unique<DropShapeCalculationsPass>();
 }
@@ -88,7 +88,7 @@ public:
     return it->second;
   }
   Optional<LinkageInfo> getFuncLinkageInfo(NnModuleOp instance,
-                                           FuncOp methodFunc) {
+                                           func::FuncOp methodFunc) {
     auto it = funcLinkageInfo.find({instance, methodFunc});
     if (it == funcLinkageInfo.end())
       return None;
@@ -185,7 +185,7 @@ private:
     for (auto method : classType.getOps<MethodOp>()) {
       nameStack.push_back(method.name().str());
       funcLinkageInfo[{nnModule,
-                       symbolTable.lookup<FuncOp>(method.function())}] =
+                       symbolTable.lookup<func::FuncOp>(method.function())}] =
           LinkageInfo{llvm::join(nameStack, "."), method.isPrivate()};
       nameStack.pop_back();
     }
@@ -251,7 +251,7 @@ private:
   // Linkage info for each method in the program. Since we are going to be
   // monomorphizing all the functions, we also need to key this off of the
   // instance (NnModuleOp) that the func is monomorphized for.
-  DenseMap<std::pair<NnModuleOp, FuncOp>, LinkageInfo> funcLinkageInfo;
+  DenseMap<std::pair<NnModuleOp, func::FuncOp>, LinkageInfo> funcLinkageInfo;
   // The corresponding GlobalSlotOp for each SlotOp in the program.
   DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot;
   // A set of values that we have copied into torch.global_slot initializers,
@@ -298,7 +298,7 @@ namespace {
 // any notion of "type" that we have in the IR, but still fits the formal
 // definition.
 struct Monomorphization {
-  FuncOp func;
+  func::FuncOp func;
   std::vector<ArgInstance> argInstances;
 };
 } // namespace
@@ -327,7 +327,7 @@ template <> struct llvm::DenseMapInfo<Monomorphization> {
 //
 // This generalizes to a full abstract interpretation of the function, but
 // currently only analyzes a subset of ops.
-static LogicalResult analyzeInstances(FuncOp func,
+static LogicalResult analyzeInstances(func::FuncOp func,
                                       ArrayRef<ArgInstance> argInstances,
                                       BlockAndValueMapping &mapping) {
   for (auto &argInstance : argInstances)
@@ -351,7 +351,7 @@ static LogicalResult analyzeInstances(FuncOp func,
 static FailureOr<Monomorphization>
 createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping,
                               SymbolTable &symbolTable) {
-  auto func = symbolTable.lookup<FuncOp>(op.getCallee());
+  auto func = symbolTable.lookup<func::FuncOp>(op.getCallee());
   Monomorphization monomorphization;
   monomorphization.func = func;
   for (auto operand : llvm::enumerate(op->getOperands())) {
@@ -372,7 +372,7 @@ public:
       : module(module), symbolTable(module) {}
   LogicalResult
   initialize(DenseMap<ClassTypeOp, std::vector<NnModuleOp>> &instances) {
-    for (auto func : module.getOps<FuncOp>()) {
+    for (auto func : module.getOps<func::FuncOp>()) {
      Monomorphization monomorphization;
      monomorphization.func = func;
      bool canTriviallyMonomorphize = true;
@@ -455,7 +455,7 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
 
 // Verify that `func` conforms to the subset of allowable method bodies
 // that we can convert.
-static LogicalResult verifyFuncConformsToSubset(FuncOp func) {
+static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) {
   // TODO: Investigate why WalkResult::interrupt() doesn't propagate properly.
   LogicalResult ret = success();
   func.walk([&](Block *block) {
@@ -481,7 +481,7 @@ static LogicalResult verifyFuncConformsToSubset(FuncOp func) {
 static LogicalResult
 verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
                               MonomorphizationTracker &tracker) {
-  DenseMap<FuncOp, int> numMonomorphizations;
+  DenseMap<func::FuncOp, int> numMonomorphizations;
   for (auto &monomorphization : tracker.getMonomorphizations()) {
     numMonomorphizations[monomorphization.func] += 1;
   }
@@ -489,7 +489,7 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
   for (auto classType : module.getOps<ClassTypeOp>()) {
     for (auto method : classType.getOps<MethodOp>()) {
       if (!method.isPrivate()) {
-        if (numMonomorphizations[symbolTable.lookup<FuncOp>(
+        if (numMonomorphizations[symbolTable.lookup<func::FuncOp>(
                 method.function())] > 1) {
           method.emitError()
               << "public function with multiple monomorphizations";
@@ -503,11 +503,10 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
 
 // Rewrite `func`, given that all values of `NnModuleType` have been mapped in
 // `mapping` to corresponding global instances.
-static LogicalResult
-rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
-                              SymbolTable &symbolTable,
-                              DenseMap<Monomorphization, FuncOp> &newFuncs,
-                              ObjectGraphInfo &objectGraphInfo) {
+static LogicalResult rewriteMonomorphizedFuncClone(
+    func::FuncOp func, BlockAndValueMapping mapping, SymbolTable &symbolTable,
+    DenseMap<Monomorphization, func::FuncOp> &newFuncs,
+    ObjectGraphInfo &objectGraphInfo) {
 
   SmallVector<Operation *> toErase;
   auto handlePrimSetAttr = [&](PrimSetAttrOp op) {
@@ -605,7 +604,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
   // static analysis to discover how to monomorphize the program, including
   // tracking instances through control flow, through get/set attr, etc. We
   // implement a very simple subset of cases.
-  for (auto func : module.getOps<FuncOp>()) {
+  for (auto func : module.getOps<func::FuncOp>()) {
     if (failed(verifyFuncConformsToSubset(func)))
       return failure();
   }
@@ -637,10 +636,10 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
 
   // Step 4: Clone/rewrite functions to implement the necessary
   // monomorphizations.
-  DenseMap<Monomorphization, FuncOp> newFuncs;
+  DenseMap<Monomorphization, func::FuncOp> newFuncs;
   int uniquifier = 0;
   for (auto &monomorphization : tracker.getMonomorphizations()) {
-    auto newFunc = cast<FuncOp>(monomorphization.func->clone());
+    auto newFunc = cast<func::FuncOp>(monomorphization.func->clone());
     newFuncs[monomorphization] = newFunc;
     Optional<LinkageInfo> linkageInfo = None;
     // If it is potentially a method, check its linkage info.
@@ -675,14 +674,14 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
   }
 
   // Step 5: Clean up object graph.
-  DenseSet<FuncOp> liveFuncs;
+  DenseSet<func::FuncOp> liveFuncs;
   for (auto &kv : newFuncs) {
     liveFuncs.insert(kv.second);
   }
   for (auto &op : llvm::make_early_inc_range(module.getOps())) {
     if (isa<GlobalSlotOp>(&op))
       continue;
-    if (auto func = dyn_cast<FuncOp>(op)) {
+    if (auto func = dyn_cast<func::FuncOp>(op)) {
       if (liveFuncs.contains(func))
         continue;
     }
@@ -355,7 +355,7 @@ class MaximizeValueSemanticsPass
 
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::Torch::createMaximizeValueSemanticsPass() {
   return std::make_unique<MaximizeValueSemanticsPass>();
 }
@@ -10,9 +10,11 @@
 #ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
 #define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class ModuleOp;
 namespace torch {
 namespace Torch {
@@ -94,7 +94,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
   if (options.optimize) {
     // Eliminate the PrimTupleIndexOp generated from the
     // adjustCallingConventions
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
     // Inline global slots, which for most inference scenarios deletes them.
     // This also exposes more information to intraprocedural transformations
     // below like MaximizeValueSemantics and RefineTypes.
@@ -105,14 +105,14 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
   }
 
   // Reduce variants of ops to a smaller set of primitives.
-  pm.addNestedPass<FuncOp>(createReduceOpVariantsPass());
+  pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
 
   if (options.optimize) {
     // OPT-ONLY: Right now we rely on this to eliminate certain branches that
     // guard unreachable code that backends can't handle yet, such as lists,
     // RaiseException, unimplemented tensor ops, and only-used-in-training
     // operations on `torch.global_slot`'s.
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
     // OPT-ONLY: We may have deleted some `torch.global_slot.get` /
     // `torch.global_slot.set` ops, which may have left more
     // `torch.global_slot`'s unused.
@@ -124,7 +124,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
   //===--------------------------------------------------------------------===//
 
   // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
-  pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
 
   // Do shape refinement.
   // This must be run before RefineTypes (which primarily does dtype inference),
@@ -132,7 +132,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
   // operand.
   createTorchShapeRefinementPipeline(pm, options);
   // Refine types in the program, which mainly means inferring dtypes of ops.
-  pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
 
   // Propagate to ABI return types the shape/dtype information discovered by
   // the previous pass. Doing this is ABI-compatible for our backends.
@@ -142,7 +142,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
     // This can fold away some branches given the information got from
    // RefineTypes before doing maximize value semantics which only works with
    // basic blocks.
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   }
 
   if (options.optimize) {
@@ -152,9 +152,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
     // branches that guard unreachable code that backends can't handle yet, such
     // as lists, RaiseException, unimplemented aten ops, and
     // only-used-in-training operations on `torch.global_slot`'s.
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   }
-  pm.addNestedPass<FuncOp>(Torch::createDecomposeComplexOpsPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createDecomposeComplexOpsPass());
 
   // TODO: VerifyTorchBackendContractPass.
 }
@@ -172,11 +172,11 @@ void mlir::torch::Torch::createTorchShapeRefinementPipeline(
   // as hard as possible" kind of thing, so it's inherently somewhat brittle.
   // The idea is to keep strengthening what we do here to support the shape
   // library. We don't need to support arbitrary programs, thankfully.
-  pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
   // Run CSE, then see if we can simplify further.
-  pm.addNestedPass<FuncOp>(createCSEPass());
+  pm.addNestedPass<func::FuncOp>(createCSEPass());
-  pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
 
   // Drop shape calculations, leaving behind the shape-refined program.
-  pm.addNestedPass<FuncOp>(Torch::createDropShapeCalculationsPass());
+  pm.addNestedPass<func::FuncOp>(Torch::createDropShapeCalculationsPass());
 }
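The pipeline hunks above all make the same substitution: per-function passes are now nested under `func::FuncOp`. A minimal sketch of driving such a nested pipeline from C++ (the two upstream cleanup passes are real; the wrapper function itself is illustrative, not part of the patch):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Run a couple of per-function cleanups nested under func.func, mirroring
// how createTorchFunctionToTorchBackendPipeline schedules its passes.
mlir::LogicalResult runNestedCleanups(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
  return pm.run(module);
}
```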
@@ -33,10 +33,10 @@ public:
     auto classType = symbolTable.lookup<ClassTypeOp>(
         op.receiver().getType().cast<NnModuleType>().getClassName());
     assert(classType && "malformed module -- missing ClassTypeOp");
-    FuncOp func;
+    func::FuncOp func;
     for (auto method : classType.getOps<MethodOp>()) {
       if (method.name() == op.name()) {
-        func = symbolTable.lookup<FuncOp>(method.function());
+        func = symbolTable.lookup<func::FuncOp>(method.function());
         break;
       }
     }
@@ -207,7 +207,7 @@ public:
     assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 &&
            "Torch JIT operators shouldn't have regions or successors");
 
-    Operation *newOp = rewriter.createOperation(state);
+    Operation *newOp = rewriter.create(state);
     auto tensor =
         rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
     createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
@@ -276,7 +276,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::Torch::createReduceOpVariantsPass() {
   return std::make_unique<ReduceOpVariantsPass>();
 }
@@ -25,7 +25,7 @@ class RefinePublicReturnPass
     : public RefinePublicReturnBase<RefinePublicReturnPass> {
   void runOnOperation() override {
     auto module = getOperation();
-    module.walk([&](FuncOp func) {
+    module.walk([&](func::FuncOp func) {
       if (func.getVisibility() != SymbolTable::Visibility::Public)
         return;
       if (func.isExternal())
@@ -40,7 +40,7 @@ class RefinePublicReturnPass
     });
   }
 
-  void rewriteSignature(FuncOp func) {
+  void rewriteSignature(func::FuncOp func) {
     // Find the unique return op.
     func::ReturnOp returnOp;
     WalkResult walkResult = func.walk([&](func::ReturnOp op) {
@@ -90,7 +90,7 @@ class RefinePublicReturnPass
     returnOp->setOperands(newOperands);
 
     // Update the function type.
-    auto funcType = func.getType();
+    auto funcType = func.getFunctionType();
     func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(),
                                    ValueRange(newOperands).getTypes()));
   }
@@ -1191,7 +1191,7 @@ static bool isSafeToRefineOperandInPlace(OpOperand *use, Type newOperandType) {
   return operationIsValidWithRefinedType(use, newOperandType);
 }
 
-void optimize(FuncOp func, TypeAnalyzer &analyzer) {
+void optimize(func::FuncOp func, TypeAnalyzer &analyzer) {
   func.walk([&](Operation *op) {
     auto convertValuesToMostRefinedType = [&](ValueRange values, OpBuilder &b) {
       for (Value v : values) {
@@ -1336,7 +1336,7 @@ class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::Torch::createRefineTypesPass() {
   return std::make_unique<RefineTypesPass>();
 }
@@ -169,7 +169,7 @@ static Value adjustShapeFunctionArg(Value operand, Type desiredType,
 // from the shape library.
 static LogicalResult
 populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands,
-                               mlir::FuncOp shapeFunction) {
+                               mlir::func::FuncOp shapeFunction) {
   // Create a call to the shape function in the `shapeCalculation` region.
   // We will import the callee from the shape library later.
   OpBuilder b(op.getContext());
@@ -241,7 +241,7 @@ class ReifyShapeCalculationsPass
       name = name.drop_front(strlen("valsem."));
       auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
       auto shapeFunction =
-          shapeLibrary->lookupSymbol<FuncOp>(shapeFunctionName);
+          shapeLibrary->lookupSymbol<func::FuncOp>(shapeFunctionName);
       if (!shapeFunction)
         return;
       neededShapeFunctions.push_back(shapeFunctionName);
@@ -276,7 +276,7 @@ class ReifyShapeCalculationsPass
       auto symName = worklist.pop_back_val();
       if (importedFunctions.count(symName))
         continue;
-      auto func = shapeLibrary->lookupSymbol<mlir::FuncOp>(symName);
+      auto func = shapeLibrary->lookupSymbol<mlir::func::FuncOp>(symName);
       assert(func && "broken shape library");
       // Move the shape function from the library to the module this pass
       // is running on. (this mutates the library, but we re-parse it each time
@@ -415,7 +415,7 @@ class SimplifyShapeCalculationsPass
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::Torch::createSimplifyShapeCalculationsPass() {
   return std::make_unique<SimplifyShapeCalculationsPass>();
 }
@@ -45,9 +45,10 @@ struct FuncBackendTypeConversionPass
     typeConverter.addConversion([](Type type) { return type; });
     TorchConversion::setupBackendTypeConversion(target, typeConverter);
 
-    populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, typeConverter);
-    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-      return typeConverter.isSignatureLegal(op.getType()) &&
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
              typeConverter.isLegal(&op.getBody());
     });
     populateCallOpTypeConversionPattern(patterns, typeConverter);
@@ -155,7 +156,7 @@ struct FinalizingBackendTypeConversionPass
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
   return std::make_unique<FinalizingBackendTypeConversionPass>();
 }
@@ -10,9 +10,12 @@
 #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
 #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class ModuleOp;
+
 namespace torch {
 namespace TorchConversion {
@@ -59,26 +59,27 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
   // We do this first as it tends to involve pattern-matching against constants,
   // (e.g. dimensions which must be constant in a ranked programming model)
   // and those constants get somewhat obscured by TorchToStd.
-  pm.addNestedPass<FuncOp>(createConvertTorchToTMTensorPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
-  pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
-  pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToStdPass());
-  pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
-  pm.addNestedPass<FuncOp>(memref::createExpandOpsPass());
+  pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
 
   if (options.optimize) {
     // Clean up any non-canonical code introduced above..
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
     // Resolve `dim` ops on tensors (which currently live in the `memref`
     // dialect for some reason -- we don't have memrefs at this level).
-    pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass());
+    pm.addNestedPass<func::FuncOp>(
+        memref::createResolveShapedTypeResultDimsPass());
     // The resolution of `dim` ops tends to create identical ops. CSE them.
-    pm.addNestedPass<FuncOp>(createCSEPass());
+    pm.addNestedPass<func::FuncOp>(createCSEPass());
   }
 
   // Finish the type conversion from `torch` types to the types of the
   // linalg-on-tensors backend contract.
   pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
-  pm.addNestedPass<FuncOp>(
+  pm.addNestedPass<func::FuncOp>(
       TorchConversion::createFinalizingBackendTypeConversionPass());
 
   // Verify that we have lowered to the form that linalg on tensors backends
@@ -93,21 +94,21 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
   pm.addPass(
       TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
 
-  pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass());
+  pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
   // Perform rank broadcasting so TosaToLinalg pass works
-  pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
+  pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
 
   if (options.optimize) {
     // Clean up any non-canonical code introduced above..
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
     // The resolution of `dim` ops tends to create identical ops. CSE them.
-    pm.addNestedPass<FuncOp>(createCSEPass());
+    pm.addNestedPass<func::FuncOp>(createCSEPass());
   }
 
   // Finish the type conversion from `torch` types to the types of the
   // TOSA backend contract.
   pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
-  pm.addNestedPass<FuncOp>(
+  pm.addNestedPass<func::FuncOp>(
       TorchConversion::createFinalizingBackendTypeConversionPass());
 
   // Verify that we have lowered to the form that TOSA backends
@@ -9,6 +9,8 @@
 
 #include "PassDetail.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -60,7 +62,7 @@ class VerifyLinalgOnTensorsBackendContractPass
     ConversionTarget target(*context);
 
     // Structural operations.
-    target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
+    target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
         opHasLegalTypes);
 
     target.addDynamicallyLegalOp<GetNextSeedOp>(opHasLegalTypes);
@@ -9,6 +9,7 @@
 
 #include "PassDetail.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -39,7 +40,7 @@ class VerifyTosaBackendContractPass
     ConversionTarget target(*context);
 
     // Structural operations.
-    target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
+    target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
        opHasLegalTypes);
     // Basic scalar operations.
     target.addLegalDialect<tosa::TosaDialect>();
@@ -10,6 +10,7 @@
 #ifndef REFBACKEND_PASSDETAIL_H
 #define REFBACKEND_PASSDETAIL_H
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
 
@@ -15,7 +15,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
-#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -68,7 +68,7 @@ static bool isArgMemRefTypeValid(Type type) {
   return false;
 }
 
-static void addEmitCInterfaceAttr(FuncOp func) {
+static void addEmitCInterfaceAttr(func::FuncOp func) {
   func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
 }
 
@@ -115,7 +115,7 @@ static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op,
 }
 
 static LogicalResult mungeFunction(
-    FuncOp func,
+    func::FuncOp func,
     std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
   // Only need to call mungeFunction for functions callable from outside of the
   // module.
@@ -188,17 +188,17 @@ class MungeCallingConventions
     auto module = getOperation();
     OpBuilder b(module.getBodyRegion());
     std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
-    for (auto func : module.getOps<FuncOp>()) {
+    for (auto func : module.getOps<func::FuncOp>()) {
       if (failed(mungeFunction(func, invokedConsumeFuncReturnFuncs)))
         return signalPassFailure();
     }
 
     // Create FuncOp for consumeFuncReturnFuncs that are used.
     for (auto &p : invokedConsumeFuncReturnFuncs) {
-      auto consumeFuncReturnFunc =
-          b.create<FuncOp>(module.getLoc(), p.first,
-                           FunctionType::get(module.getContext(), p.second, {}),
-                           b.getStringAttr("private"));
+      auto consumeFuncReturnFunc = b.create<func::FuncOp>(
+          module.getLoc(), p.first,
+          FunctionType::get(module.getContext(), p.second, {}),
+          b.getStringAttr("private"));
       addEmitCInterfaceAttr(consumeFuncReturnFunc);
     }
   }
@@ -309,7 +309,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
   return std::make_unique<ExpandOpsForLLVM>();
 }
@@ -366,7 +366,7 @@ class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::RefBackend::createMungeMemrefCopyPass() {
   return std::make_unique<MungeMemrefCopy>();
 }
@@ -390,7 +390,7 @@ class GeneralizeTensorPad
 };
 } // namespace
 
-std::unique_ptr<OperationPass<FuncOp>>
+std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::RefBackend::createGeneralizeTensorPadPass() {
   return std::make_unique<GeneralizeTensorPad>();
 }
@@ -36,8 +36,8 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
   MlirAttribute symNameAttr = mlirStringAttrGet(
       context, toMlirStringRef(function->qualname().qualifiedName()));
   MlirOperation func = createMlirOperation(
-      "builtin.func", loc, mlirRegionCreate(),
-      toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)),
+      "func.func", loc, mlirRegionCreate(),
+      toMlirNamedAttribute("function_type", mlirTypeAttrGet(functionType)),
       toMlirNamedAttribute("sym_name", symNameAttr));
   std::vector<MlirAttribute> argAttrDicts;
   for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) {
@@ -29,7 +29,7 @@ import torch
 from torch.jit import ScriptFunction
 
 from torch_mlir import ir
-from torch_mlir.dialects.builtin import FuncOp
+from torch_mlir.dialects.func import FuncOp
 from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
 
 
@@ -115,15 +115,15 @@ class RefBackendInvoker:
 
 
 LOWERING_PIPELINE = ",".join([
-    "builtin.func(refback-generalize-tensor-pad)",
+    "func.func(refback-generalize-tensor-pad)",
     # Bufferize.
-    "builtin.func(scf-bufferize)",
+    "func.func(scf-bufferize)",
-    "builtin.func(tm-tensor-bufferize)",
+    "func.func(tm-tensor-bufferize)",
-    "builtin.func(linalg-bufferize)",
+    "func.func(linalg-bufferize)",
     "func-bufferize",
     "arith-bufferize",
-    "builtin.func(tensor-bufferize)",
+    "func.func(tensor-bufferize)",
-    "builtin.func(finalizing-bufferize)",
+    "func.func(finalizing-bufferize)",
     # Munge to make it ExecutionEngine compatible.
     # Specifically, we rewrite calling convention boundaries to be in terms
     # of unranked memref, and we rewrite the return to actually be a
@@ -135,17 +135,17 @@ LOWERING_PIPELINE = ",".join([
     # global seed used in stateful rng.
     "refback-insert-rng-globals",
     # Lower to LLVM
-    "builtin.func(tm-tensor-to-loops)",
+    "func.func(tm-tensor-to-loops)",
-    "builtin.func(refback-munge-memref-copy)",
+    "func.func(refback-munge-memref-copy)",
-    "builtin.func(convert-linalg-to-loops)",
+    "func.func(convert-linalg-to-loops)",
-    "builtin.func(lower-affine)",
+    "func.func(lower-affine)",
     "convert-scf-to-cf",
-    "builtin.func(refback-expand-ops-for-llvm)",
+    "func.func(refback-expand-ops-for-llvm)",
-    "builtin.func(arith-expand)",
+    "func.func(arith-expand)",
-    "builtin.func(convert-math-to-llvm)",
+    "func.func(convert-math-to-llvm)",
     "convert-linalg-to-llvm",
     "convert-memref-to-llvm",
-    "builtin.func(convert-arith-to-llvm)",
+    "func.func(convert-arith-to-llvm)",
     "convert-func-to-llvm",
     "convert-cf-to-llvm",
     "reconcile-unrealized-casts",
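These Python pipeline strings show the other face of the same migration: per-function passes are now anchored on `func.func(...)` rather than `builtin.func(...)`. A hedged sketch of what that anchoring means when an equivalent string is parsed from C++ (pass registration is assumed; the two-pass pipeline is illustrative, only the anchor syntax is the point):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Parse a textual pipeline that nests per-function passes under func.func.
// The old "builtin.func(...)" anchor no longer matches the function ops
// produced by the importer, so pipelines must use the new spelling.
mlir::LogicalResult runTextualPipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  if (mlir::failed(
          mlir::parsePassPipeline("func.func(canonicalize,cse)", pm)))
    return mlir::failure();
  return pm.run(module);
}
```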
@@ -38,25 +38,25 @@ class LinalgOnTensorsTosaBackend(TosaBackend):
 """

 # TOSA legalization may emit tosa.const() ops. These are legalized
-# by tosa-to-standard to arith.constants. This mechanical transformation
+# by tosa-to-arith to arith.constants. This mechanical transformation
 # must be done prior to TOSA-to-LinAlg so that the latter does not fail.
 # This is an artifact of legalizations spread across a collection of simple
 # ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg,
 # that depend on TOSA as well as TOSA-to-Standard.
 run_pipeline_with_repro_report(
 imported_module,
-"builtin.func(tosa-to-standard)",
+"func.func(tosa-to-arith)",
-"Lowering TOSA to Standard")
+"Lowering TOSA to Arith")

 # Named ops must be legalized prior to general tosa-to-linalg
 run_pipeline_with_repro_report(
 imported_module,
-"builtin.func(tosa-to-linalg-named)",
+"func.func(tosa-to-linalg-named)",
 "Lowering TOSA to Linalg-on-Tensors for Named Ops")

 run_pipeline_with_repro_report(
 imported_module,
-"builtin.func(tosa-to-linalg)",
+"func.func(tosa-to-linalg)",
 "Lowering TOSA to Linalg-on-Tensors")
 return self.refbackend.compile(imported_module)

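Stripped of the diff markers, the updated backend runs three TOSA pipelines in a fixed order: constant legalization, then named-op legalization, then the general conversion, before delegating to the reference backend. A minimal sketch, assuming this hunk is the body of the backend's compile method (the surrounding method definition is not part of the hunk):

class LinalgOnTensorsTosaBackend(TosaBackend):
    def compile(self, imported_module):
        # tosa.const ops must become arith.constant before TOSA-to-LinAlg runs.
        run_pipeline_with_repro_report(
            imported_module, "func.func(tosa-to-arith)", "Lowering TOSA to Arith")
        # Named ops (matmul, conv, etc.) are legalized ahead of the general pass.
        run_pipeline_with_repro_report(
            imported_module, "func.func(tosa-to-linalg-named)",
            "Lowering TOSA to Linalg-on-Tensors for Named Ops")
        run_pipeline_with_repro_report(
            imported_module, "func.func(tosa-to-linalg)",
            "Lowering TOSA to Linalg-on-Tensors")
        return self.refbackend.compile(imported_module)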
@@ -1,7 +1,7 @@
 // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

 // CHECK-LABEL: func @forward
-builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
 %int1 = torch.constant.int 1
 %int2 = torch.constant.int 2
 %int3 = torch.constant.int 3
@@ -42,7 +42,7 @@ torch.nn_module {
 torch.slot "t1", %t : !torch.tensor
 torch.slot "t2", %t : !torch.tensor
 } : !torch.nn.Module<"c">
-builtin.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor {
+func.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor {
 %t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor
 %t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor
 %cst = torch.constant.int 1
@@ -63,7 +63,7 @@ torch.nn_module {
 torch.slot "t1", %t : !torch.tensor
 torch.slot "t2", %t : !torch.tensor
 } : !torch.nn.Module<"c">
-builtin.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) {
+func.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) {
 torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
 torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
 return
@@ -4,8 +4,8 @@

 torch.class_type @c {}
 %0 = torch.nn_module {
-// expected-error @+1 {{'builtin.func' op is not allowed inside 'torch.nn_module'}}
+// expected-error @+1 {{'func.func' op is not allowed inside 'torch.nn_module'}}
-builtin.func @f()
+func.func @f()
 } : !torch.nn.Module<"c">

 // -----
@@ -32,8 +32,8 @@ torch.class_type @c {
 // -----

 torch.class_type @c {
-// expected-error @+1 {{'builtin.func' op is not allowed inside `torch.class_type`}}
+// expected-error @+1 {{'func.func' op is not allowed inside `torch.class_type`}}
-builtin.func @f()
+func.func @f()
 }

 // -----
@@ -60,7 +60,7 @@ torch.class_type @c {
 torch.method "f", @f
 }

-builtin.func @f(%arg0: !torch.nn.Module<"c">) {
+func.func @f(%arg0: !torch.nn.Module<"c">) {
 return
 }

@@ -71,11 +71,11 @@ torch.class_type @c {
 torch.method "f", @f
 }

-builtin.func private @f(%arg0: !torch.nn.Module<"c">)
+func.func private @f(%arg0: !torch.nn.Module<"c">)

 // -----

-builtin.func private @f() {
+func.func private @f() {
 return
 }
 torch.class_type @c {
@@ -85,7 +85,7 @@ torch.class_type @c {

 // -----

-builtin.func private @f(!torch.nn.Module<"other_c">) {
+func.func private @f(%arg0: !torch.nn.Module<"other_c">) {
 return
 }
 torch.class_type @c {
@@ -101,21 +101,21 @@ torch.class_type @c {
 // -----

 // expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}}
-builtin.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
+func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})

 // -----

 // expected-error @+1 {{'torch.type_bound' must be TypeAttr}}
-builtin.func @f(%arg0: i32 {torch.type_bound = 1})
+func.func @f(%arg0: i32 {torch.type_bound = 1})

 // -----

 // expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}}
-builtin.func @f(%arg0: i32 {torch.type_bound = i32})
+func.func @f(%arg0: i32 {torch.type_bound = i32})

 // -----

-builtin.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
+func.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
 // expected-error @+1 {{operand type '!torch.optional<tensor>' and result type '!torch.tensor' are cast incompatible}}
 %0 = torch.derefine %arg0 : !torch.optional<tensor> to !torch.tensor
 return %0 : !torch.tensor
@@ -123,7 +123,7 @@ builtin.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {

 // -----

-builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<tensor> {
+func.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<tensor> {
 // expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<tensor>' are cast incompatible}}
 %0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<tensor>
 return %0 : !torch.optional<tensor>
@@ -132,11 +132,11 @@ builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !
 // -----

 // expected-error @+1 {{invalid dtype 'tuple<>' for !torch.tensor type}}
-builtin.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
+func.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>

 // -----

-builtin.func @torch.tensor() {
+func.func @torch.tensor() {
 // Incompatible shape.
 // expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
 %0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
@@ -145,7 +145,7 @@ builtin.func @torch.tensor() {

 // -----

-builtin.func @torch.tensor() {
+func.func @torch.tensor() {
 // Incompatible dtype.
 // expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
 %0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
@@ -154,7 +154,7 @@ builtin.func @torch.tensor() {

 // -----

-builtin.func @torch.tensor() {
+func.func @torch.tensor() {
 // Incompatible type.
 // expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
 %0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1
@@ -163,7 +163,7 @@ builtin.func @torch.tensor() {

 // -----

-builtin.func @torch.prim.ListConstruct() {
+func.func @torch.prim.ListConstruct() {
 %int2 = torch.constant.int 2
 // expected-error@+1 {{operand types should have the same type as the list contained type}}
 torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<tensor>
@@ -172,7 +172,7 @@ builtin.func @torch.prim.ListConstruct() {

 // -----

-builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
+func.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
 %0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
 // expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}}
 torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32>

@@ -9,7 +9,7 @@
 // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
 // CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
 // CHECK: return
-builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>,
+func.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>,
 %t1: !torch.vtensor<[1],f64>,
 %alpha: !torch.float) {
 %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
@@ -25,7 +25,7 @@ builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1
 // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
 // CHECK-SAME: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
 // CHECK: return
-builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
+func.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
 %t1: !torch.vtensor<[1],f64>,
 %alpha: !torch.float) {
 %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
@@ -41,7 +41,7 @@ builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
 // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
 // CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],f32>
 // CHECK: return
-builtin.func @tensor_tensor$same_category_zero_rank_wider(
+func.func @tensor_tensor$same_category_zero_rank_wider(
 %t0: !torch.vtensor<[1],f32>,
 %t1: !torch.vtensor<[],f64>,
 %alpha: !torch.int) {
@@ -58,7 +58,7 @@ builtin.func @tensor_tensor$same_category_zero_rank_wider(
 // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
 // CHECK-SAME: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
 // CHECK: return
-builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>,
+func.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>,
 %t1: !torch.vtensor<[],f32>,
 %alpha: !torch.int) {
 %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk>
@@ -73,7 +73,7 @@ builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si
 // CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
 // CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],f32>
 // CHECK: return
-builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
+func.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
 %t1: !torch.vtensor<[1],f32>,
 %alpha: !torch.float) {
 %1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk>
@@ -89,7 +89,7 @@ builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],
 // CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
 // CHECK-SAME: !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],f32>
 // CHECK: return
-builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
+func.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
 %1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk>
 return
 }
@@ -103,7 +103,7 @@ builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>
 // CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
 // CHECK-SAME: !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
 // CHECK: return
-builtin.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
+func.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
 %1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk>
 return
 }
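These cases track PyTorch's eager-mode promotion rules: operands of the same category promote to the wider width, a floating-point operand beats an integer one, zero-rank operands contribute their category but not their width, and alpha never widens the result. The expectations above can be cross-checked against eager PyTorch (assuming torch is installed):

import torch

# same category, different width: f32 + f64 -> f64
assert (torch.zeros(1, dtype=torch.float32)
        + torch.zeros(1, dtype=torch.float64)).dtype == torch.float64
# different category: si32 + f64 -> f64
assert (torch.zeros(1, dtype=torch.int32)
        + torch.zeros(1, dtype=torch.float64)).dtype == torch.float64
# zero-rank operand of higher category: si64 + 0-d f32 -> f32
assert (torch.zeros(1, dtype=torch.int64)
        + torch.zeros((), dtype=torch.float32)).dtype == torch.float32
# zero-rank operand of the same category does not widen: f32 + 0-d f64 -> f32
assert (torch.zeros(1, dtype=torch.float32)
        + torch.zeros((), dtype=torch.float64)).dtype == torch.float32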