llvm: bump tag to e1318078 (#781)

The updated LLVM code includes a patch to create bfloat16 array
attributes, thus enabling a different patch to torch-mlir to flesh out
support for the bfloat16 type.
pull/761/head snapshot-20220426.416
Ashay Rane 2022-04-26 12:27:51 -07:00 committed by GitHub
parent 9ec4712516
commit 9208bf0eb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
60 changed files with 377 additions and 300 deletions

View File

@ -463,7 +463,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
$_op->getAttrs());
for (Region &r : $_op->getRegions())
r.cloneInto(state.addRegion(), bvm);
return b.createOperation(state);
return b.create(state);
}]
>
];

View File

@ -31,9 +31,7 @@ class TMTensor_Op<string mnemonic, list<Trait> traits = []> :
TMTensorInterface,
SingleBlockImplicitTerminator<"::mlir::torch::TMTensor::YieldOp">
])> {
let verifier = [{ return verify$cppClass(*this); }];
let printer = [{ return print$cppClass(p, *this); }];
let parser = [{ return parse$cppClass(parser, result); }];
let hasVerifier = 1;
code extraTMTensorOpClassDeclaration = [{
SmallVector<Value> getDestinationOperands(OpBuilder &b) {
SmallVector<Value> dest(outputs().begin(), outputs().end());

View File

@ -10,6 +10,7 @@
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASS_DETAIL_H_
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {

View File

@ -10,14 +10,15 @@
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_TRANSFORMS_PASSES_H_
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
namespace TMTensor {
std::unique_ptr<OperationPass<FuncOp>> createTMTensorToLoopsPass();
std::unique_ptr<OperationPass<FuncOp>> createTMTensorBufferizePass();
std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorToLoopsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createTMTensorBufferizePass();
void registerPasses();

View File

@ -13,12 +13,12 @@
include "mlir/Pass/PassBase.td"
def TMTensorToLoops :
Pass<"tm-tensor-to-loops", "FuncOp"> {
Pass<"tm-tensor-to-loops", "func::FuncOp"> {
let summary = "Convert TMTensor ops to loops and Linalg ops.";
let constructor = "mlir::torch::TMTensor::createTMTensorToLoopsPass()";
}
def TMTensorBufferize : Pass<"tm-tensor-bufferize", "FuncOp"> {
def TMTensorBufferize : Pass<"tm-tensor-bufferize", "func::FuncOp"> {
let summary = "Bufferize the TMTensor dialect";
let constructor = "mlir::torch::TMTensor::createTMTensorBufferizePass()";
}

View File

@ -88,34 +88,34 @@ OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
// ScanOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyScanOp(ScanOp op) {
if (op.getNumInputs() != 1) {
return op.emitOpError("expected one input operands");
LogicalResult ScanOp::verify() {
if (getNumInputs() != 1) {
return emitOpError("expected one input operands");
}
if (op.getNumOutputs() != 2) {
return op.emitOpError("expected two output operands");
if (getNumOutputs() != 2) {
return emitOpError("expected two output operands");
}
if (!op.input().getType().isa<ShapedType>()) {
return op.emitOpError("expected first input element type to be shaped");
if (!input().getType().isa<ShapedType>()) {
return emitOpError("expected first input element type to be shaped");
}
auto accumulatorType = op.accumulator().getType().cast<ShapedType>();
auto inputType = op.input().getType().cast<ShapedType>();
auto outputType = op.output().getType().cast<ShapedType>();
auto accumulatorType = accumulator().getType().cast<ShapedType>();
auto inputType = input().getType().cast<ShapedType>();
auto outputType = output().getType().cast<ShapedType>();
ArrayRef<int64_t> inputShapes = inputType.getShape();
ArrayRef<int64_t> outputShapes = outputType.getShape();
if (accumulatorType.getElementType() != inputType.getElementType()) {
return op.emitOpError(
return emitOpError(
"expected input/accumulator element types to be identical");
}
ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
int64_t accumulatorRank = accumulatorType.getRank();
if (accumulatorRank != inputType.getRank() - 1) {
return op.emitOpError(
return emitOpError(
"expected accumulator rank to be equal to input rank - 1");
}
SmallVector<int64_t> expectedAccumulatorShape;
for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
if (i != op.dimension())
if (i != dimension())
expectedAccumulatorShape.push_back(inputShapes[i]);
}
if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
@ -124,14 +124,13 @@ static LogicalResult verifyScanOp(ScanOp op) {
std::get<1>(s) != ShapedType::kDynamicSize &&
std::get<0>(s) != std::get<1>(s);
})) {
return op.emitOpError("incompatible input/accumulator shapes");
return emitOpError("incompatible input/accumulator shapes");
}
if (inputType.getElementType() != outputType.getElementType()) {
return op.emitOpError(
"expected input/output element types to be identical");
return emitOpError("expected input/output element types to be identical");
}
if (inputShapes.size() != outputShapes.size()) {
return op.emitOpError("expected input/output to have identical ranks");
return emitOpError("expected input/output to have identical ranks");
}
if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
[](std::tuple<int64_t, int64_t> s) {
@ -139,7 +138,7 @@ static LogicalResult verifyScanOp(ScanOp op) {
std::get<1>(s) != ShapedType::kDynamicSize &&
std::get<0>(s) != std::get<1>(s);
})) {
return op.emitOpError("incompatible input/output shapes");
return emitOpError("incompatible input/output shapes");
}
return success();
}
@ -232,11 +231,11 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
});
auto &srcBlock = region().front();
Region &region = scfIf.getElseRegion();
Region &thisRegion = scfIf.getElseRegion();
BlockAndValueMapping bvm;
{
OpBuilder::InsertionGuard guard(b);
auto &block = region.front();
auto &block = thisRegion.front();
b.setInsertionPointToEnd(&block);
for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
bvm.map(std::get<0>(it), std::get<1>(it));
@ -275,48 +274,47 @@ LogicalResult ScanOp::fold(ArrayRef<Attribute>,
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyScatterOp(ScatterOp op) {
if (op.inputs().size() != 2) {
return op.emitOpError("expected two input operands");
LogicalResult ScatterOp::verify() {
if (inputs().size() != 2) {
return emitOpError("expected two input operands");
}
if (op.outputs().size() != 1) {
return op.emitOpError("expected one output operand");
if (outputs().size() != 1) {
return emitOpError("expected one output operand");
}
auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
return t1.getShape()[dim] == t2.getShape()[dim];
};
auto indicesType = op.getIndicesType();
auto indicesType = getIndicesType();
if (indicesType.getRank() != 2 ||
!indicesType.getElementType().isInteger(32)) {
return op.emitOpError(
"expected indices to be of rank 2 of i32 element type");
return emitOpError("expected indices to be of rank 2 of i32 element type");
}
auto indexDepth = op.getIndexDepth();
auto indexDepth = getIndexDepth();
if (indexDepth == ShapedType::kDynamicSize) {
return op.emitOpError("expected index depth is static");
return emitOpError("expected index depth is static");
}
// The first dimension of the indices should match the first dimension of the
// output. They indicate to the number of updates.
auto updateType = op.getUpdateType();
auto updateType = getUpdateType();
if (updateType.getRank() < 1) {
return op.emitOpError("expected update value to be at least rank 1");
return emitOpError("expected update value to be at least rank 1");
}
if (!checkDimensionsMatch(indicesType, updateType, 0)) {
return op.emitOpError(
return emitOpError(
"mismatch in shape of indices and update value at dim#0");
}
auto originalType = op.getOriginalType();
auto originalType = getOriginalType();
if (updateType.getRank() - 1 > originalType.getRank()) {
return op.emitOpError(
return emitOpError(
"update value rank exceeds the rank of the original value");
}
// indexDepth + update dims should cover the original dims. The first dim of
// update is the number of updates.
if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
return op.emitOpError(
return emitOpError(
"index depth and update value does not cover rank of original value");
}
@ -331,7 +329,7 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
int64_t updateDim = std::get<1>(it);
if (updateType.getDimSize(updateDim) !=
originalType.getDimSize(originalDim)) {
return op.emitOpError("mismatch in shape of update value dim#")
return emitOpError("mismatch in shape of update value dim#")
<< updateDim << " and original value at dim#" << originalDim;
}
}
@ -345,36 +343,36 @@ static LogicalResult verifyScatterOp(ScatterOp op) {
int64_t updateDim = std::get<1>(it);
if (updateType.getDimSize(updateDim) >
originalType.getDimSize(originalDim)) {
return op.emitOpError("indexed shape of update value dim#")
return emitOpError("indexed shape of update value dim#")
<< updateDim << " exceeds original value at dim#" << originalDim
<< " " << updateType.getDimSize(updateDim) << " "
<< originalType.getDimSize(originalDim);
}
}
Region &region = op.region();
Block *body = &region.front();
Region &thisRegion = region();
Block *body = &thisRegion.front();
if (body->getNumArguments() != 2) {
return op.emitOpError("expected region to have two arguments");
return emitOpError("expected region to have two arguments");
}
Type arg0Type = body->getArgument(0).getType();
Type arg1Type = body->getArgument(1).getType();
if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
return op.emitOpError(
return emitOpError(
"expected region to have scalar argument of integer or float types");
}
if (arg0Type != updateType.getElementType()) {
return op.emitOpError("mismatch in argument 0 of region ")
return emitOpError("mismatch in argument 0 of region ")
<< arg0Type << " and element type of update value "
<< updateType.getElementType();
}
if (arg1Type != originalType.getElementType()) {
return op.emitOpError("mismatch in argument 1 of region ")
return emitOpError("mismatch in argument 1 of region ")
<< arg1Type << " and element type of original value "
<< originalType.getElementType();
}
if (arg0Type != arg1Type) {
return op.emitOpError("mismatch in region argument types ")
return emitOpError("mismatch in region argument types ")
<< arg0Type << " and " << arg1Type;
}
auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator());
@ -455,7 +453,8 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
if (starts[i]) cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
if (starts[i])
cast = b.create<arith::AddIOp>(loc, cast, starts[i]);
starts[i] = cast;
}

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
@ -150,7 +151,7 @@ struct TMTensorBufferizePass
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
torch::TMTensor::createTMTensorBufferizePass() {
return std::make_unique<TMTensorBufferizePass>();
}

View File

@ -7,6 +7,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -111,7 +112,7 @@ struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
torch::TMTensor::createTMTensorToLoopsPass() {
return std::make_unique<TMTensorToLoopsPass>();
}

@ -1 +1 @@
Subproject commit 8361c5da30588d3d4a48eae648f53be1feb5cfad
Subproject commit e1318078a4e160eb723bcbcfcdcc9a1b618f7067

View File

@ -16,17 +16,17 @@ include "mlir/Pass/PassBase.td"
// Torch conversions
//===----------------------------------------------------------------------===//
def ConvertTorchToStd : Pass<"convert-torch-to-std", "FuncOp"> {
def ConvertTorchToStd : Pass<"convert-torch-to-std", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Std ops";
let constructor = "mlir::torch::createConvertTorchToStdPass()";
}
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "FuncOp"> {
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to SCF ops";
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
}
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Linalg ops";
let description = [{
Convert ATen ops to linalg ops.
@ -105,7 +105,7 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
}
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
let summary = "Convert Torch ops to TOSA ops";
let description = [{
This pass assumes that TOSA ops are responsible for emitting error
@ -114,7 +114,7 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToTosaPass()";
}
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "FuncOp"> {
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
let description = [{
Convert ATen ops to tmtensor/linalg ops.

View File

@ -10,12 +10,13 @@
#ifndef TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#define TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToLinalgPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
}
} // namespace mlir

View File

@ -10,11 +10,12 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
#define TORCHMLIR_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToSCFPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToSCFPass();
}
} // namespace mlir

View File

@ -10,12 +10,13 @@
#ifndef TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
#define TORCHMLIR_CONVERSION_ATENTOSTD_ATENTOSTD_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToStdPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToStdPass();
}
} // namespace mlir

View File

@ -10,11 +10,12 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
#define TORCHMLIR_CONVERSION_TORCHTOTMTENSOR_TORCHTOTMTENSOR_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTMTensorPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTMTensorPass();
}
} // namespace mlir

View File

@ -10,12 +10,13 @@
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToTosaPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
}
} // namespace mlir

View File

@ -10,6 +10,8 @@
#ifndef TORCH_TYPES
#define TORCH_TYPES
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/DialectBase.td"
include "torch-mlir/Dialect/Torch/IR/TorchBase.td"
//===----------------------------------------------------------------------===//
@ -24,28 +26,8 @@ class Torch_Type<string name, string typeMnemonic,
class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type<name, typeMnemonic> {
let parameters = (ins "::mlir::Type":$containedType);
let hasCustomAssemblyFormat = 1;
let printer = [{
$_printer << "<";
// Print the contained type without the `!torch.` prefix.
printTorchDialectType(getImpl()->containedType, $_printer);
$_printer << ">";
}];
let parser = [{
if ($_parser.parseLess())
return Type();
// Parse the contained type, but forward directly to our internal parsing
// of `torch` dialect types, so that we can parse nested types without
// the `!torch.` prefix.
Type containedType = parseTorchDialectType($_parser);
if (!containedType)
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, containedType);
}];
let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
return Base::get(containedType.getContext(), containedType);
@ -59,23 +41,7 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
Represents an instance of a `torch.nn.Module` with the given `className`.
}];
let parameters = (ins StringRefParameter<"class name">:$className);
let printer = [{
$_printer << "<\"";
llvm::printEscapedString(getImpl()->className, $_printer.getStream());
$_printer << "\">";
}];
let parser = [{
if ($_parser.parseLess())
return Type();
std::string className;
if ($_parser.parseOptionalString(&className))
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, className);
}];
let hasCustomAssemblyFormat = 1;
}
// For standard ArrayRefs, which require allocation.
@ -186,6 +152,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
"::mlir::Type":$optionalDtype
);
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
string extraBaseClassDeclaration = [{
}];
}
@ -243,6 +210,7 @@ def Torch_TupleType : Torch_Type<"Tuple", "tuple"> {
let parameters = (ins
ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
);
let hasCustomAssemblyFormat = 1;
}
def Torch_UnionType : Torch_Type<"Union", "union"> {
@ -261,6 +229,7 @@ def Torch_UnionType : Torch_Type<"Union", "union"> {
let parameters = (ins
ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
);
let hasCustomAssemblyFormat = 1;
}
def Torch_DeviceType : Torch_Type<"Device", "Device"> {
@ -367,30 +336,7 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
let description = [{
Torch Dict type with key and value type.
}];
let printer = [{
$_printer << "<";
printTorchDialectType(getImpl()->keyType, $_printer);
$_printer << ", ";
printTorchDialectType(getImpl()->valueType, $_printer);
$_printer<< ">";
}];
let parser = [{
if ($_parser.parseLess())
return Type();
Type keyType = parseTorchDialectType($_parser);
if (!keyType)
return Type();
if ($_parser.parseComma())
return Type();
Type valueType = parseTorchDialectType($_parser);
if (!valueType)
return Type();
if ($_parser.parseGreater())
return Type();
return get($_ctxt, keyType, valueType);
}];
let hasCustomAssemblyFormat = 1;
let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
"::mlir::Type":$valueType), [{

View File

@ -10,11 +10,14 @@
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
class ModuleOp;
namespace torch {
namespace Torch {
@ -48,25 +51,26 @@ void createTorchShapeRefinementPipeline(
std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
std::unique_ptr<OperationPass<FuncOp>> createRefineTypesPass();
std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
std::unique_ptr<OperationPass<ModuleOp>> createInlineGlobalSlotsPass();
std::unique_ptr<OperationPass<FuncOp>> createReduceOpVariantsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createReduceOpVariantsPass();
std::unique_ptr<OperationPass<FuncOp>> createMaximizeValueSemanticsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createMaximizeValueSemanticsPass();
std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<FuncOp>> createDecomposeComplexOpsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeComplexOpsPass();
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
std::unique_ptr<OperationPass<FuncOp>> createSimplifyShapeCalculationsPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createSimplifyShapeCalculationsPass();
std::unique_ptr<OperationPass<FuncOp>> createDropShapeCalculationsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
StringRef getShapeLibrary();

View File

@ -126,7 +126,7 @@ def AdjustCallingConventions
}];
}
def RefineTypes : Pass<"torch-refine-types", "FuncOp"> {
def RefineTypes : Pass<"torch-refine-types", "func::FuncOp"> {
let summary = "Refine types";
let constructor = "mlir::torch::Torch::createRefineTypesPass()";
let description = [{
@ -149,7 +149,7 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> {
}];
}
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> {
let summary = "Reduces variants of ops to a smaller set of ops.";
let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()";
let description = [{
@ -165,7 +165,7 @@ def ReduceOpVariants : Pass<"torch-reduce-op-variants", "FuncOp"> {
}];
}
def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "FuncOp"> {
def MaximizeValueSemantics : Pass<"torch-maximize-value-semantics", "func::FuncOp"> {
let summary = "Use value-semantic tensors where possible.";
let description = [{
Use value-semantic tensors where possible to make the program more
@ -215,7 +215,7 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {
}];
}
def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> {
def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
let summary = "Decompose complicated torch operations";
let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
let description = [{
@ -238,7 +238,7 @@ def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp">
}];
}
def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncOp"> {
def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func::FuncOp"> {
let summary = "Simplify reified shape calculations.";
let constructor = "mlir::torch::Torch::createSimplifyShapeCalculationsPass()";
let description = [{
@ -246,7 +246,7 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "FuncO
}];
}
def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "FuncOp"> {
def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"> {
let summary = "Drop reified shape calculations.";
let constructor = "mlir::torch::Torch::createDropShapeCalculationsPass()";
let description = [{

View File

@ -10,12 +10,15 @@
#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include <memory>
namespace mlir {
class ModuleOp;
namespace torch {
namespace TorchConversion {
@ -36,7 +39,7 @@ createVerifyInvariantsBeforeBackendLoweringPass();
std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
createFinalizingBackendTypeConversionPass();
std::unique_ptr<OperationPass<ModuleOp>>

View File

@ -43,7 +43,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu
}
def FinalizingBackendTypeConversion
: Pass<"torch-finalizing-backend-type-conversion", "FuncOp"> {
: Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> {
let summary = "Finalizes a partial conversion to builtin tensors";
let constructor =
"mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()";

View File

@ -10,10 +10,13 @@
#ifndef TORCHMLIR_REFBACKEND_PASSES_H
#define TORCHMLIR_REFBACKEND_PASSES_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
namespace mlir {
class ModuleOp;
namespace torch {
namespace RefBackend {
@ -22,13 +25,13 @@ void registerRefBackendPasses();
std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass();
std::unique_ptr<OperationPass<func::FuncOp>> createExpandOpsForLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
std::unique_ptr<OperationPass<FuncOp>> createMungeMemrefCopyPass();
std::unique_ptr<OperationPass<func::FuncOp>> createMungeMemrefCopyPass();
std::unique_ptr<OperationPass<FuncOp>> createGeneralizeTensorPadPass();
std::unique_ptr<OperationPass<func::FuncOp>> createGeneralizeTensorPadPass();
} // namespace RefBackend
} // namespace torch
} // namespace mlir

View File

@ -24,18 +24,18 @@ def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> {
let dependentDialects = ["memref::MemRefDialect"];
}
def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "FuncOp"> {
def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "func::FuncOp"> {
let summary = "Expand ops into more primitive ops before LLVM lowering.";
let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
}
def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "FuncOp"> {
def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> {
let summary = "Munge memref.copy to linalg.copy";
let constructor = "mlir::torch::RefBackend::createMungeMemrefCopyPass();";
let dependentDialects = ["memref::MemRefDialect"];
}
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "FuncOp"> {
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> {
let summary = "Convert tensor.pad to linalg ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
}

View File

@ -10,6 +10,7 @@
#ifndef TORCHMLIR_CONVERSION_PASSDETAIL_H
#define TORCHMLIR_CONVERSION_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {

View File

@ -88,7 +88,7 @@ public:
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToLinalgPass() {
return std::make_unique<ConvertTorchToLinalg>();
}

View File

@ -91,7 +91,7 @@ public:
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToSCFPass() {
return std::make_unique<ConvertTorchToSCF>();
}

View File

@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -221,7 +222,7 @@ public:
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToStdPass() {
return std::make_unique<ConvertTorchToStd>();
}

View File

@ -10,8 +10,10 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/MLIRContext.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
@ -566,7 +568,7 @@ public:
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTMTensorPass() {
return std::make_unique<ConvertTorchToTMTensor>();
}

View File

@ -3170,7 +3170,7 @@ public:
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToTosaPass() {
return std::make_unique<ConvertTorchToTosa>();
}

View File

@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

View File

@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
@ -124,7 +125,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
unsigned argIndex,
NamedAttribute namedAttr) {
if (namedAttr.getName().getValue() == "torch.type_bound") {
auto func = dyn_cast<FuncOp>(op);
auto func = dyn_cast<func::FuncOp>(op);
if (!func)
return op->emitError() << "'torch.type_bound' must be attached to a func";
TypeAttr attr = namedAttr.getValue().dyn_cast<TypeAttr>();
@ -134,7 +135,7 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
if (!type)
return op->emitError() << "'torch.type_bound' must be of "
"!torch.tensor/!torch.vtensor type";
if (!func.getType().getInput(argIndex).isa<BaseTensorType>())
if (!func.getFunctionType().getInput(argIndex).isa<BaseTensorType>())
return op->emitError() << "'torch.type_bound' must be attached to an "
"argument of !torch.tensor/!torch.vtensor type";
return success();
@ -177,3 +178,100 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
return nullptr;
}
//===----------------------------------------------------------------------===//
// OptionalType and ListType
//===----------------------------------------------------------------------===//
void OptionalType::print(AsmPrinter &printer) const {
printer << "<";
// Print the contained type without the `!torch.` prefix.
printTorchDialectType(getImpl()->containedType, printer);
printer << ">";
}
void ListType::print(AsmPrinter &printer) const {
printer << "<";
// Print the contained type without the `!torch.` prefix.
printTorchDialectType(getImpl()->containedType, printer);
printer << ">";
}
Type OptionalType::parse(AsmParser &odsParser) {
if (odsParser.parseLess())
return Type();
// Parse the contained type, but forward directly to our internal parsing
// of `torch` dialect types, so that we can parse nested types without
// the `!torch.` prefix.
Type containedType = parseTorchDialectType(odsParser);
if (!containedType)
return Type();
if (odsParser.parseGreater())
return Type();
return get(odsParser.getContext(), containedType);
}
Type ListType::parse(AsmParser &odsParser) {
if (odsParser.parseLess())
return Type();
// Parse the contained type, but forward directly to our internal parsing
// of `torch` dialect types, so that we can parse nested types without
// the `!torch.` prefix.
Type containedType = parseTorchDialectType(odsParser);
if (!containedType)
return Type();
if (odsParser.parseGreater())
return Type();
return get(odsParser.getContext(), containedType);
}
//===----------------------------------------------------------------------===//
// DictType
//===----------------------------------------------------------------------===//
void DictType::print(AsmPrinter &printer) const {
printer << "<";
printTorchDialectType(getImpl()->keyType, printer);
printer << ", ";
printTorchDialectType(getImpl()->valueType, printer);
printer << ">";
}
Type DictType::parse(AsmParser &odsParser) {
if (odsParser.parseLess())
return Type();
Type keyType = parseTorchDialectType(odsParser);
if (!keyType)
return Type();
if (odsParser.parseComma())
return Type();
Type valueType = parseTorchDialectType(odsParser);
if (!valueType)
return Type();
if (odsParser.parseGreater())
return Type();
return get(odsParser.getContext(), keyType, valueType);
}
//===----------------------------------------------------------------------===//
// NnModuleType
//===----------------------------------------------------------------------===//
void NnModuleType::print(AsmPrinter &printer) const {
printer << "<\"";
llvm::printEscapedString(getImpl()->className, printer.getStream());
printer << "\">";
}
Type NnModuleType::parse(AsmParser &odsParser) {
if (odsParser.parseLess())
return Type();
std::string className;
if (odsParser.parseOptionalString(&className))
return Type();
if (odsParser.parseGreater())
return Type();
return get(odsParser.getContext(), className);
}

View File

@ -9,6 +9,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
@ -99,7 +100,7 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto func =
symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, functionAttr());
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, functionAttr());
if (!func)
return emitError() << "'@" << function()
<< "' does not reference a valid function";
@ -112,8 +113,8 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
"merely declared)";
auto expectedReceiverArgType = NnModuleType::get(
getContext(), getOperation()->getParentOfType<ClassTypeOp>().getName());
if (func.getType().getNumInputs() == 0 ||
func.getType().getInput(0) != expectedReceiverArgType) {
if (func.getFunctionType().getNumInputs() == 0 ||
func.getFunctionType().getInput(0) != expectedReceiverArgType) {
return emitError() << "the referenced function '" << function()
<< "' must have a first argument of type "
<< expectedReceiverArgType;
@ -278,7 +279,7 @@ ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
Region *elseRegion = result.addRegion();
auto &builder = parser.getBuilder();
OpAsmParser::OperandType cond;
OpAsmParser::UnresolvedOperand cond;
Type boolType = builder.getType<Torch::BoolType>();
if (parser.parseOperand(cond) ||
parser.resolveOperand(cond, boolType, result.operands))

View File

@ -35,7 +35,7 @@ ParseResult Torch::parseDefaultTorchOp(OpAsmParser &parser,
OperationState &result, int numOperands,
int numResults) {
llvm::SMLoc loc = parser.getCurrentLocation();
SmallVector<OpAsmParser::OperandType> operands;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/numOperands))
return failure();
if (parser.parseOptionalAttrDict(result.attributes))

View File

@ -31,11 +31,12 @@ using namespace mlir::torch::Torch;
using TypeBoundMap = DenseMap<std::pair<StringRef, int>, Type>;
namespace {
class AdjustCallingConventionForFunc : public OpConversionPattern<FuncOp> {
class AdjustCallingConventionForFunc
: public OpConversionPattern<func::FuncOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(FuncOp func, OpAdaptor adaptor,
matchAndRewrite(func::FuncOp func, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = func.getContext();
auto typeBoundIdent = StringAttr::get(context, "torch.type_bound");
@ -70,7 +71,7 @@ public:
typeConverter);
SmallVector<Type> newResultTypes;
for (auto type : func.getType().getResults()) {
for (auto type : func.getFunctionType().getResults()) {
if (auto none = type.dyn_cast<Torch::NoneType>()) {
continue;
}
@ -186,7 +187,7 @@ public:
};
} // namespace
static LogicalResult adjustCallingConventions(FuncOp func,
static LogicalResult adjustCallingConventions(func::FuncOp func,
TypeBoundMap &typeBoundMap) {
MLIRContext *context = func.getContext();
RewritePatternSet patterns(context);
@ -217,7 +218,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
patterns.add<AdjustCallingConventionForReturn>(typeConverter, context);
ConversionTarget target(*context);
target.addDynamicallyLegalOp<FuncOp>([](FuncOp func) {
target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp func) {
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
if (func.getArgAttr(i, "torch.type_bound"))
return false;
@ -225,7 +226,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
return false;
}
for (int i = 0, e = func.getNumResults(); i != e; i++) {
if (func.getType().getResults()[i].isa<Torch::NoneType>())
if (func.getFunctionType().getResults()[i].isa<Torch::NoneType>())
return false;
}
return true;
@ -266,7 +267,7 @@ class AdjustCallingConventionsPass
void runOnOperation() override {
auto module = getOperation();
TypeBoundMap typeBoundMap;
for (auto func : module.getOps<FuncOp>()) {
for (auto func : module.getOps<func::FuncOp>()) {
for (int i = 0, e = func.getNumArguments(); i != e; i++) {
auto typeBoundAttr =
func.getArgAttrOfType<TypeAttr>(i, "torch.type_bound");
@ -275,7 +276,7 @@ class AdjustCallingConventionsPass
typeBoundMap[{func.getName(), i}] = typeBoundAttr.getValue();
}
}
for (auto func : module.getOps<FuncOp>()) {
for (auto func : module.getOps<func::FuncOp>()) {
if (failed(adjustCallingConventions(func, typeBoundMap)))
return signalPassFailure();
}

View File

@ -1781,7 +1781,7 @@ class DecomposeComplexOpsPass
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createDecomposeComplexOpsPass() {
return std::make_unique<DecomposeComplexOpsPass>();
}

View File

@ -54,7 +54,7 @@ class DropShapeCalculationsPass
patterns.insert<DropShapeCalculateOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<ShapeCalculateOp>();
target.addLegalOp<FuncOp>();
target.addLegalOp<func::FuncOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
@ -64,7 +64,7 @@ class DropShapeCalculationsPass
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createDropShapeCalculationsPass() {
return std::make_unique<DropShapeCalculationsPass>();
}

View File

@ -88,7 +88,7 @@ public:
return it->second;
}
Optional<LinkageInfo> getFuncLinkageInfo(NnModuleOp instance,
FuncOp methodFunc) {
func::FuncOp methodFunc) {
auto it = funcLinkageInfo.find({instance, methodFunc});
if (it == funcLinkageInfo.end())
return None;
@ -185,7 +185,7 @@ private:
for (auto method : classType.getOps<MethodOp>()) {
nameStack.push_back(method.name().str());
funcLinkageInfo[{nnModule,
symbolTable.lookup<FuncOp>(method.function())}] =
symbolTable.lookup<func::FuncOp>(method.function())}] =
LinkageInfo{llvm::join(nameStack, "."), method.isPrivate()};
nameStack.pop_back();
}
@ -251,7 +251,7 @@ private:
// Linkage info for each method in the program. Since we are going to be
// monomorphizing all the functions, we also need to key this off of the
// instance (NnModuleOp) that the func is monomorphized for.
DenseMap<std::pair<NnModuleOp, FuncOp>, LinkageInfo> funcLinkageInfo;
DenseMap<std::pair<NnModuleOp, func::FuncOp>, LinkageInfo> funcLinkageInfo;
// The corresponding GlobalSlotOp for each SlotOp in the program.
DenseMap<SlotOp, GlobalSlotOp> slotToGlobalSlot;
// A set of values that we have copied into torch.global_slot initializers,
@ -298,7 +298,7 @@ namespace {
// any notion of "type" that we have in the IR, but still fits the formal
// definition.
struct Monomorphization {
FuncOp func;
func::FuncOp func;
std::vector<ArgInstance> argInstances;
};
} // namespace
@ -327,7 +327,7 @@ template <> struct llvm::DenseMapInfo<Monomorphization> {
//
// This generalizes to a full abstract interpretation of the function, but
// currently only analyzes a subset of ops.
static LogicalResult analyzeInstances(FuncOp func,
static LogicalResult analyzeInstances(func::FuncOp func,
ArrayRef<ArgInstance> argInstances,
BlockAndValueMapping &mapping) {
for (auto &argInstance : argInstances)
@ -351,7 +351,7 @@ static LogicalResult analyzeInstances(FuncOp func,
static FailureOr<Monomorphization>
createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping,
SymbolTable &symbolTable) {
auto func = symbolTable.lookup<FuncOp>(op.getCallee());
auto func = symbolTable.lookup<func::FuncOp>(op.getCallee());
Monomorphization monomorphization;
monomorphization.func = func;
for (auto operand : llvm::enumerate(op->getOperands())) {
@ -372,7 +372,7 @@ public:
: module(module), symbolTable(module) {}
LogicalResult
initialize(DenseMap<ClassTypeOp, std::vector<NnModuleOp>> &instances) {
for (auto func : module.getOps<FuncOp>()) {
for (auto func : module.getOps<func::FuncOp>()) {
Monomorphization monomorphization;
monomorphization.func = func;
bool canTriviallyMonomorphize = true;
@ -455,7 +455,7 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
// Verify that `func` conforms to the subset of allowable method bodies
// that we can convert.
static LogicalResult verifyFuncConformsToSubset(FuncOp func) {
static LogicalResult verifyFuncConformsToSubset(func::FuncOp func) {
// TODO: Investingate why WalkResult::interrupt() doesn't propagate properly.
LogicalResult ret = success();
func.walk([&](Block *block) {
@ -481,7 +481,7 @@ static LogicalResult verifyFuncConformsToSubset(FuncOp func) {
static LogicalResult
verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
MonomorphizationTracker &tracker) {
DenseMap<FuncOp, int> numMonomorphizations;
DenseMap<func::FuncOp, int> numMonomorphizations;
for (auto &monomorphization : tracker.getMonomorphizations()) {
numMonomorphizations[monomorphization.func] += 1;
}
@ -489,7 +489,7 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
for (auto classType : module.getOps<ClassTypeOp>()) {
for (auto method : classType.getOps<MethodOp>()) {
if (!method.isPrivate()) {
if (numMonomorphizations[symbolTable.lookup<FuncOp>(
if (numMonomorphizations[symbolTable.lookup<func::FuncOp>(
method.function())] > 1) {
method.emitError()
<< "public function with multiple monomorphizations";
@ -503,10 +503,9 @@ verifyPublicMonomorphizations(ModuleOp module, SymbolTable &symbolTable,
// Rewrite `func`, given that all values of `NnModuleType` have been mapped in
// `mapping` to corresponding global instances.
static LogicalResult
rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
SymbolTable &symbolTable,
DenseMap<Monomorphization, FuncOp> &newFuncs,
static LogicalResult rewriteMonomorphizedFuncClone(
func::FuncOp func, BlockAndValueMapping mapping, SymbolTable &symbolTable,
DenseMap<Monomorphization, func::FuncOp> &newFuncs,
ObjectGraphInfo &objectGraphInfo) {
SmallVector<Operation *> toErase;
@ -605,7 +604,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
// static analysis to discover how to monomorphize th eprogram, including
// tracking instances through control flow, through get/set attr, etc. We
// implement a very simple subset of cases.
for (auto func : module.getOps<FuncOp>()) {
for (auto func : module.getOps<func::FuncOp>()) {
if (failed(verifyFuncConformsToSubset(func)))
return failure();
}
@ -637,10 +636,10 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
// Step 4: Clone/rewrite functions to implement the necessary
// monomorphizations.
DenseMap<Monomorphization, FuncOp> newFuncs;
DenseMap<Monomorphization, func::FuncOp> newFuncs;
int uniquifier = 0;
for (auto &monomorphization : tracker.getMonomorphizations()) {
auto newFunc = cast<FuncOp>(monomorphization.func->clone());
auto newFunc = cast<func::FuncOp>(monomorphization.func->clone());
newFuncs[monomorphization] = newFunc;
Optional<LinkageInfo> linkageInfo = None;
// If it is potentially a method, check its linkage info.
@ -675,14 +674,14 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
}
// Step 5: Clean up object graph.
DenseSet<FuncOp> liveFuncs;
DenseSet<func::FuncOp> liveFuncs;
for (auto &kv : newFuncs) {
liveFuncs.insert(kv.second);
}
for (auto &op : llvm::make_early_inc_range(module.getOps())) {
if (isa<GlobalSlotOp>(&op))
continue;
if (auto func = dyn_cast<FuncOp>(op)) {
if (auto func = dyn_cast<func::FuncOp>(op)) {
if (liveFuncs.contains(func))
continue;
}

View File

@ -355,7 +355,7 @@ class MaximizeValueSemanticsPass
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createMaximizeValueSemanticsPass() {
return std::make_unique<MaximizeValueSemanticsPass>();
}

View File

@ -10,9 +10,11 @@
#ifndef TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#define TORCHMLIR_DIALECT_TORCH_TRANSFORMS_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
class ModuleOp;
namespace torch {
namespace Torch {

View File

@ -94,7 +94,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
if (options.optimize) {
// Eliminate the PrimTupleIndexOp generated from the
// adjustCallingConventions
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Inline global slots, which for most inference scenarios deletes them.
// This also exposes more information to intraprocedural transformations
// below like MaximizeValueSemantics and RefineTypes.
@ -105,14 +105,14 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
}
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<FuncOp>(createReduceOpVariantsPass());
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
if (options.optimize) {
// OPT-ONLY: Right now we rely on this to eliminate certain branches that
// guard unreachable code that backends can't handle yet, such as lists,
// RaiseException, unimplemented tensor ops, and only-used-in-training
// operations on `torch.global_slot`'s.
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// OPT-ONLY: We may have deleted some `torch.global_slot.get` /
// `torch.global_slot.get` ops, which may have left more
// `torch.global_slot`'s unused.
@ -124,7 +124,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
//===--------------------------------------------------------------------===//
// Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's.
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
// Do shape refinement.
// This must be run before RefineTypes (which primarily does dtype inference),
@ -132,7 +132,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// operand.
createTorchShapeRefinementPipeline(pm, options);
// Refine types in the program, which mainly means inferring dtypes of ops.
pm.addNestedPass<FuncOp>(Torch::createRefineTypesPass());
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
// Propagate to ABI return types the shape/dtype information discovered by
// the previous pass. Doing this is ABI-compatible for our backends.
@ -142,7 +142,7 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// This can fold away some branches given the information got from
// RefineTypes before doing maximize value sematics which only works with
// basic blocks.
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
if (options.optimize) {
@ -152,9 +152,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
// branches that guard unreachable code that backends can't handle yet, such
// as lists, RaiseException, unimplemented aten ops, and
// only-used-in-training operations on `torch.global_slot`'s.
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
}
pm.addNestedPass<FuncOp>(Torch::createDecomposeComplexOpsPass());
pm.addNestedPass<func::FuncOp>(Torch::createDecomposeComplexOpsPass());
// TODO: VerifyTorchBackendContractPass.
}
@ -172,11 +172,11 @@ void mlir::torch::Torch::createTorchShapeRefinementPipeline(
// as hard as possible" kind of thing, so it's inherently somewhat brittle.
// The idea is to keep strengthening what we do here to support the shape
// library. We don't need to support arbitrary programs, thankfully.
pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass());
pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
// Run CSE, then see if we can simplify further.
pm.addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<FuncOp>(Torch::createSimplifyShapeCalculationsPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(Torch::createSimplifyShapeCalculationsPass());
// Drop shape calculations, leaving behind the shape-refined program.
pm.addNestedPass<FuncOp>(Torch::createDropShapeCalculationsPass());
pm.addNestedPass<func::FuncOp>(Torch::createDropShapeCalculationsPass());
}

View File

@ -33,10 +33,10 @@ public:
auto classType = symbolTable.lookup<ClassTypeOp>(
op.receiver().getType().cast<NnModuleType>().getClassName());
assert(classType && "malformed module -- missing ClassTypeOp");
FuncOp func;
func::FuncOp func;
for (auto method : classType.getOps<MethodOp>()) {
if (method.name() == op.name()) {
func = symbolTable.lookup<FuncOp>(method.function());
func = symbolTable.lookup<func::FuncOp>(method.function());
break;
}
}

View File

@ -207,7 +207,7 @@ public:
assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 &&
"Torch JIT operators shouldn't have regions or successors");
Operation *newOp = rewriter.createOperation(state);
Operation *newOp = rewriter.create(state);
auto tensor =
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
@ -276,7 +276,7 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createReduceOpVariantsPass() {
return std::make_unique<ReduceOpVariantsPass>();
}

View File

@ -25,7 +25,7 @@ class RefinePublicReturnPass
: public RefinePublicReturnBase<RefinePublicReturnPass> {
void runOnOperation() override {
auto module = getOperation();
module.walk([&](FuncOp func) {
module.walk([&](func::FuncOp func) {
if (func.getVisibility() != SymbolTable::Visibility::Public)
return;
if (func.isExternal())
@ -40,7 +40,7 @@ class RefinePublicReturnPass
});
}
void rewriteSignature(FuncOp func) {
void rewriteSignature(func::FuncOp func) {
// Find the unique return op.
func::ReturnOp returnOp;
WalkResult walkResult = func.walk([&](func::ReturnOp op) {
@ -90,7 +90,7 @@ class RefinePublicReturnPass
returnOp->setOperands(newOperands);
// Update the function type.
auto funcType = func.getType();
auto funcType = func.getFunctionType();
func.setType(FunctionType::get(funcType.getContext(), funcType.getInputs(),
ValueRange(newOperands).getTypes()));
}

View File

@ -1191,7 +1191,7 @@ static bool isSafeToRefineOperandInPlace(OpOperand *use, Type newOperandType) {
return operationIsValidWithRefinedType(use, newOperandType);
}
void optimize(FuncOp func, TypeAnalyzer &analyzer) {
void optimize(func::FuncOp func, TypeAnalyzer &analyzer) {
func.walk([&](Operation *op) {
auto convertValuesToMostRefinedType = [&](ValueRange values, OpBuilder &b) {
for (Value v : values) {
@ -1336,7 +1336,7 @@ class RefineTypesPass : public RefineTypesBase<RefineTypesPass> {
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRefineTypesPass() {
return std::make_unique<RefineTypesPass>();
}

View File

@ -169,7 +169,7 @@ static Value adjustShapeFunctionArg(Value operand, Type desiredType,
// from the shape library.
static LogicalResult
populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands,
mlir::FuncOp shapeFunction) {
mlir::func::FuncOp shapeFunction) {
// Create a call to the shape function in the `shapeCalculation` region.
// We will import the callee from the shape library later.
OpBuilder b(op.getContext());
@ -241,7 +241,7 @@ class ReifyShapeCalculationsPass
name = name.drop_front(strlen("valsem."));
auto shapeFunctionName = ("__torch_mlir_shape_fn." + Twine(name)).str();
auto shapeFunction =
shapeLibrary->lookupSymbol<FuncOp>(shapeFunctionName);
shapeLibrary->lookupSymbol<func::FuncOp>(shapeFunctionName);
if (!shapeFunction)
return;
neededShapeFunctions.push_back(shapeFunctionName);
@ -276,7 +276,7 @@ class ReifyShapeCalculationsPass
auto symName = worklist.pop_back_val();
if (importedFunctions.count(symName))
continue;
auto func = shapeLibrary->lookupSymbol<mlir::FuncOp>(symName);
auto func = shapeLibrary->lookupSymbol<mlir::func::FuncOp>(symName);
assert(func && "broken shape library");
// Move the shape function from the library to the module this pass
// is running on. (this mutates the library, but we re-parse it each time

View File

@ -415,7 +415,7 @@ class SimplifyShapeCalculationsPass
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createSimplifyShapeCalculationsPass() {
return std::make_unique<SimplifyShapeCalculationsPass>();
}

View File

@ -45,9 +45,10 @@ struct FuncBackendTypeConversionPass
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, typeConverter);
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter.isSignatureLegal(op.getType()) &&
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody());
});
populateCallOpTypeConversionPattern(patterns, typeConverter);
@ -155,7 +156,7 @@ struct FinalizingBackendTypeConversionPass
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
return std::make_unique<FinalizingBackendTypeConversionPass>();
}

View File

@ -10,9 +10,12 @@
#ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
#define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
class ModuleOp;
namespace torch {
namespace TorchConversion {

View File

@ -59,26 +59,27 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
// We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model)
// and those constants get somewhat obscured by TorchToStd.
pm.addNestedPass<FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<FuncOp>(memref::createExpandOpsPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(memref::createExpandOpsPass());
if (options.optimize) {
// Clean up any non-canonical code introduced above..
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// Resolve `dim` ops on tensors (which currently live in the `memref`
// dialect for some reason -- we don't have memrefs at this level).
pm.addNestedPass<FuncOp>(memref::createResolveShapedTypeResultDimsPass());
pm.addNestedPass<func::FuncOp>(
memref::createResolveShapedTypeResultDimsPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
}
// Finish the type conversion from `torch` types to the types of the
// linalg-on-tensors backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<FuncOp>(
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that linalg on tensors backends
@ -93,21 +94,21 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
pm.addPass(
TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
pm.addNestedPass<FuncOp>(createConvertTorchToTosaPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
// Perform rank broadcasting so TosaToLinalg pass works
pm.addNestedPass<FuncOp>(createTosaMakeBroadcastablePass());
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
if (options.optimize) {
// Clean up any non-canonical code introduced above..
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
}
// Finish the type conversion from `torch` types to the types of the
// TOSA backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<FuncOp>(
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
// Verify that we have lowered to the form that TOSA backends

View File

@ -9,6 +9,8 @@
#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@ -60,7 +62,7 @@ class VerifyLinalgOnTensorsBackendContractPass
ConversionTarget target(*context);
// Structural operations.
target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
opHasLegalTypes);
target.addDynamicallyLegalOp<GetNextSeedOp>(opHasLegalTypes);

View File

@ -9,6 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@ -39,7 +40,7 @@ class VerifyTosaBackendContractPass
ConversionTarget target(*context);
// Structural operations.
target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
opHasLegalTypes);
// Basic scalar operations.
target.addLegalDialect<tosa::TosaDialect>();

View File

@ -10,6 +10,7 @@
#ifndef REFBACKEND_PASSDETAIL_H
#define REFBACKEND_PASSDETAIL_H
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"

View File

@ -15,7 +15,7 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@ -68,7 +68,7 @@ static bool isArgMemRefTypeValid(Type type) {
return false;
}
static void addEmitCInterfaceAttr(FuncOp func) {
static void addEmitCInterfaceAttr(func::FuncOp func) {
func->setAttr("llvm.emit_c_interface", UnitAttr::get(func.getContext()));
}
@ -115,7 +115,7 @@ static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op,
}
static LogicalResult mungeFunction(
FuncOp func,
func::FuncOp func,
std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
// Only need to call mungeFunction for functions callable from outside of the
// module.
@ -188,15 +188,15 @@ class MungeCallingConventions
auto module = getOperation();
OpBuilder b(module.getBodyRegion());
std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
for (auto func : module.getOps<FuncOp>()) {
for (auto func : module.getOps<func::FuncOp>()) {
if (failed(mungeFunction(func, invokedConsumeFuncReturnFuncs)))
return signalPassFailure();
}
// Create FuncOp for consumeFuncReturnFuncs that are used.
for (auto &p : invokedConsumeFuncReturnFuncs) {
auto consumeFuncReturnFunc =
b.create<FuncOp>(module.getLoc(), p.first,
auto consumeFuncReturnFunc = b.create<func::FuncOp>(
module.getLoc(), p.first,
FunctionType::get(module.getContext(), p.second, {}),
b.getStringAttr("private"));
addEmitCInterfaceAttr(consumeFuncReturnFunc);
@ -309,7 +309,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
return std::make_unique<ExpandOpsForLLVM>();
}
@ -366,7 +366,7 @@ class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::RefBackend::createMungeMemrefCopyPass() {
return std::make_unique<MungeMemrefCopy>();
}
@ -390,7 +390,7 @@ class GeneralizeTensorPad
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::RefBackend::createGeneralizeTensorPadPass() {
return std::make_unique<GeneralizeTensorPad>();
}

View File

@ -36,8 +36,8 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
MlirAttribute symNameAttr = mlirStringAttrGet(
context, toMlirStringRef(function->qualname().qualifiedName()));
MlirOperation func = createMlirOperation(
"builtin.func", loc, mlirRegionCreate(),
toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)),
"func.func", loc, mlirRegionCreate(),
toMlirNamedAttribute("function_type", mlirTypeAttrGet(functionType)),
toMlirNamedAttribute("sym_name", symNameAttr));
std::vector<MlirAttribute> argAttrDicts;
for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) {

View File

@ -29,7 +29,7 @@ import torch
from torch.jit import ScriptFunction
from torch_mlir import ir
from torch_mlir.dialects.builtin import FuncOp
from torch_mlir.dialects.func import FuncOp
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder

View File

@ -115,15 +115,15 @@ class RefBackendInvoker:
LOWERING_PIPELINE = ",".join([
"builtin.func(refback-generalize-tensor-pad)",
"func.func(refback-generalize-tensor-pad)",
# Bufferize.
"builtin.func(scf-bufferize)",
"builtin.func(tm-tensor-bufferize)",
"builtin.func(linalg-bufferize)",
"func.func(scf-bufferize)",
"func.func(tm-tensor-bufferize)",
"func.func(linalg-bufferize)",
"func-bufferize",
"arith-bufferize",
"builtin.func(tensor-bufferize)",
"builtin.func(finalizing-bufferize)",
"func.func(tensor-bufferize)",
"func.func(finalizing-bufferize)",
# Munge to make it ExecutionEngine compatible.
# Specifically, we rewrite calling convention boundaries to be in terms
# of unranked memref, and we rewrite the return to actually be a
@ -135,17 +135,17 @@ LOWERING_PIPELINE = ",".join([
# global seed used in stateful rng.
"refback-insert-rng-globals",
# Lower to LLVM
"builtin.func(tm-tensor-to-loops)",
"builtin.func(refback-munge-memref-copy)",
"builtin.func(convert-linalg-to-loops)",
"builtin.func(lower-affine)",
"func.func(tm-tensor-to-loops)",
"func.func(refback-munge-memref-copy)",
"func.func(convert-linalg-to-loops)",
"func.func(lower-affine)",
"convert-scf-to-cf",
"builtin.func(refback-expand-ops-for-llvm)",
"builtin.func(arith-expand)",
"builtin.func(convert-math-to-llvm)",
"func.func(refback-expand-ops-for-llvm)",
"func.func(arith-expand)",
"func.func(convert-math-to-llvm)",
"convert-linalg-to-llvm",
"convert-memref-to-llvm",
"builtin.func(convert-arith-to-llvm)",
"func.func(convert-arith-to-llvm)",
"convert-func-to-llvm",
"convert-cf-to-llvm",
"reconcile-unrealized-casts",

View File

@ -38,25 +38,25 @@ class LinalgOnTensorsTosaBackend(TosaBackend):
"""
# TOSA legalization may emit tosa.const() ops. These are legalized
# by tosa-to-standard to arith.constants. This mechanical transformation
# by tosa-to-arith to arith.constants. This mechanical transformation
# must be done prior to TOSA-to-LinAlg so that the latter does not fail.
# This is an artifact of legalizations spread across a collection of simple
# ones in TOSA-to-Standard and the main conversions TOSA-to-LinAlg,
# that depend on TOSA as well as TOSA-to-Standard.
run_pipeline_with_repro_report(
imported_module,
"builtin.func(tosa-to-standard)",
"Lowering TOSA to Standard")
"func.func(tosa-to-arith)",
"Lowering TOSA to Arith")
# Named ops must be legalized prior to general tosa-to-linalg
run_pipeline_with_repro_report(
imported_module,
"builtin.func(tosa-to-linalg-named)",
"func.func(tosa-to-linalg-named)",
"Lowering TOSA to Linalg-on-Tensors for Named Ops")
run_pipeline_with_repro_report(
imported_module,
"builtin.func(tosa-to-linalg)",
"func.func(tosa-to-linalg)",
"Lowering TOSA to Linalg-on-Tensors")
return self.refbackend.compile(imported_module)

View File

@ -1,7 +1,7 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @forward
builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3

View File

@ -42,7 +42,7 @@ torch.nn_module {
torch.slot "t1", %t : !torch.tensor
torch.slot "t2", %t : !torch.tensor
} : !torch.nn.Module<"c">
builtin.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor {
func.func private @use_slot(%arg0 : !torch.nn.Module<"c">) -> !torch.tensor {
%t1 = torch.prim.GetAttr %arg0["t1"] : !torch.nn.Module<"c"> -> !torch.tensor
%t2 = torch.prim.GetAttr %arg0["t2"] : !torch.nn.Module<"c"> -> !torch.tensor
%cst = torch.constant.int 1
@ -63,7 +63,7 @@ torch.nn_module {
torch.slot "t1", %t : !torch.tensor
torch.slot "t2", %t : !torch.tensor
} : !torch.nn.Module<"c">
builtin.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) {
func.func private @set_slot(%arg0 : !torch.nn.Module<"c">, %arg1 : !torch.tensor) {
torch.prim.SetAttr %arg0["t1"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
torch.prim.SetAttr %arg0["t2"] = %arg1: !torch.nn.Module<"c">, !torch.tensor
return

View File

@ -4,8 +4,8 @@
torch.class_type @c {}
%0 = torch.nn_module {
// expected-error @+1 {{'builtin.func' op is not allowed inside 'torch.nn_module'}}
builtin.func @f()
// expected-error @+1 {{'func.func' op is not allowed inside 'torch.nn_module'}}
func.func @f()
} : !torch.nn.Module<"c">
// -----
@ -32,8 +32,8 @@ torch.class_type @c {
// -----
torch.class_type @c {
// expected-error @+1 {{'builtin.func' op is not allowed inside `torch.class_type`}}
builtin.func @f()
// expected-error @+1 {{'func.func' op is not allowed inside `torch.class_type`}}
func.func @f()
}
// -----
@ -60,7 +60,7 @@ torch.class_type @c {
torch.method "f", @f
}
builtin.func @f(%arg0: !torch.nn.Module<"c">) {
func.func @f(%arg0: !torch.nn.Module<"c">) {
return
}
@ -71,11 +71,11 @@ torch.class_type @c {
torch.method "f", @f
}
builtin.func private @f(%arg0: !torch.nn.Module<"c">)
func.func private @f(%arg0: !torch.nn.Module<"c">)
// -----
builtin.func private @f() {
func.func private @f() {
return
}
torch.class_type @c {
@ -85,7 +85,7 @@ torch.class_type @c {
// -----
builtin.func private @f(!torch.nn.Module<"other_c">) {
func.func private @f(%arg0: !torch.nn.Module<"other_c">) {
return
}
torch.class_type @c {
@ -101,21 +101,21 @@ torch.class_type @c {
// -----
// expected-error @+1 {{'torch.type_bound' must be attached to an argument of !torch.tensor/!torch.vtensor type}}
builtin.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
func.func @f(%arg0: i32 {torch.type_bound = !torch.tensor<*,f32>})
// -----
// expected-error @+1 {{'torch.type_bound' must be TypeAttr}}
builtin.func @f(%arg0: i32 {torch.type_bound = 1})
func.func @f(%arg0: i32 {torch.type_bound = 1})
// -----
// expected-error @+1 {{'torch.type_bound' must be of !torch.tensor/!torch.vtensor type}}
builtin.func @f(%arg0: i32 {torch.type_bound = i32})
func.func @f(%arg0: i32 {torch.type_bound = i32})
// -----
builtin.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
func.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
// expected-error @+1 {{operand type '!torch.optional<tensor>' and result type '!torch.tensor' are cast incompatible}}
%0 = torch.derefine %arg0 : !torch.optional<tensor> to !torch.tensor
return %0 : !torch.tensor
@ -123,7 +123,7 @@ builtin.func @derefine(%arg0: !torch.optional<tensor>) -> !torch.tensor {
// -----
builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<tensor> {
func.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !torch.optional<tensor> {
// expected-error @+1 {{operand type '!torch.tensor' and result type '!torch.optional<tensor>' are cast incompatible}}
%0 = torch.prim.unchecked_cast %arg0 : !torch.tensor -> !torch.optional<tensor>
return %0 : !torch.optional<tensor>
@ -132,11 +132,11 @@ builtin.func @torch.prim.unchecked_cast$invalid_types(%arg0: !torch.tensor) -> !
// -----
// expected-error @+1 {{invalid dtype 'tuple<>' for !torch.tensor type}}
builtin.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
func.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
// -----
builtin.func @torch.tensor() {
func.func @torch.tensor() {
// Incompatible shape.
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
@ -145,7 +145,7 @@ builtin.func @torch.tensor() {
// -----
builtin.func @torch.tensor() {
func.func @torch.tensor() {
// Incompatible dtype.
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
@ -154,7 +154,7 @@ builtin.func @torch.tensor() {
// -----
builtin.func @torch.tensor() {
func.func @torch.tensor() {
// Incompatible type.
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1
@ -163,7 +163,7 @@ builtin.func @torch.tensor() {
// -----
builtin.func @torch.prim.ListConstruct() {
func.func @torch.prim.ListConstruct() {
%int2 = torch.constant.int 2
// expected-error@+1 {{operand types should have the same type as the list contained type}}
torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<tensor>
@ -172,7 +172,7 @@ builtin.func @torch.prim.ListConstruct() {
// -----
builtin.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
func.func @torch.overwrite.tensor.contents(%arg0: !torch.vtensor<[1],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> {
%0 = torch.copy.to_tensor %arg0 : !torch.tensor<[1],f32>
// expected-error@+1 {{'torch.overwrite.tensor.contents' op failed to verify that overwritten tensor type is corresponding !torch.tensor of value tensor type}}
torch.overwrite.tensor.contents %arg1 overwrites %0 : !torch.vtensor<[?],f32>, !torch.tensor<[1],f32>

View File

@ -9,7 +9,7 @@
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
// CHECK: return
builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>,
func.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[1],f64>,
%alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
@ -25,7 +25,7 @@ builtin.func @tensor_tensor$same_category_different_width(%t0: !torch.vtensor<[1
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],f64>
// CHECK: return
builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
func.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
%t1: !torch.vtensor<[1],f64>,
%alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si32>, !torch.vtensor<[1],f64>, !torch.float -> !torch.vtensor<[1],unk>
@ -41,7 +41,7 @@ builtin.func @tensor_tensor$different_category(%t0: !torch.vtensor<[1],si32>,
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return
builtin.func @tensor_tensor$same_category_zero_rank_wider(
func.func @tensor_tensor$same_category_zero_rank_wider(
%t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[],f64>,
%alpha: !torch.int) {
@ -58,7 +58,7 @@ builtin.func @tensor_tensor$same_category_zero_rank_wider(
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return
builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>,
func.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si64>,
%t1: !torch.vtensor<[],f32>,
%alpha: !torch.int) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],si64>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[1],unk>
@ -73,7 +73,7 @@ builtin.func @tensor_tensor$zero_rank_higher_category(%t0: !torch.vtensor<[1],si
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],f32>
// CHECK: return
builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
func.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],f32>,
%t1: !torch.vtensor<[1],f32>,
%alpha: !torch.float) {
%1 = torch.aten.add.Tensor %t0, %t1, %alpha: !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1],unk>
@ -89,7 +89,7 @@ builtin.func @tensor_tensor$alpha_wider_no_contribution(%t0: !torch.vtensor<[1],
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si64>, !torch.float, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: return
builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
func.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>, %scalar: !torch.float, %alpha: !torch.int) {
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si64>, !torch.float, !torch.int -> !torch.vtensor<[1],unk>
return
}
@ -103,7 +103,7 @@ builtin.func @tensor_scalar$scalar_higher_category(%t0: !torch.vtensor<[1],si64>
// CHECK: %[[VAL_3:.*]] = torch.aten.add.Scalar %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] :
// CHECK-SAME: !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
// CHECK: return
builtin.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
func.func @tensor_scalar$scalar_same_category_wider(%t0: !torch.vtensor<[1],si32>, %scalar: !torch.int, %alpha: !torch.int) {
%1 = torch.aten.add.Scalar %t0, %scalar, %alpha: !torch.vtensor<[1], si32>, !torch.int, !torch.int -> !torch.vtensor<[1],unk>
return
}