Bump LLVM at 8361c5da30588d3d4a48eae648f53be1feb5cfad

pull/777/head snapshot-20220318.332
Vigilans 2022-03-16 18:44:23 +08:00 committed by Yi Zhang
parent 218b4875d5
commit 63fb1e5aad
46 changed files with 177 additions and 150 deletions

View File

@ -23,7 +23,7 @@ from itertools import chain
from torch_mlir import ir
import torch_mlir.dialects.torch as torch_d
from torch_mlir.dialects import builtin, std
from torch_mlir.dialects import builtin, std, func
import torch.fx
from torch.fx.experimental.fx_acc import acc_ops
@ -288,7 +288,7 @@ class _ForwardFunctionBuilder(_Builder):
result = self._insert_function_call(node)
self.env[node] = result
elif node.op == 'output':
std.ReturnOp([self.env[node_arg] for node_arg in node.args],
func.ReturnOp([self.env[node_arg] for node_arg in node.args],
loc=self.loc, ip=self.func_ip)
elif node.op == 'placeholder':
continue

View File

@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTMTensorDialect
MLIRSideEffectInterfaces
MLIRSupport
MLIRSCF
MLIRStandard
MLIRFunc
MLIRTensor
MLIRViewLikeInterface
)

View File

@ -10,8 +10,8 @@
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/Support/Debug.h"

View File

@ -10,11 +10,11 @@
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"

View File

@ -10,9 +10,10 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
@ -132,8 +133,8 @@ struct TMTensorBufferizePass
bufferization::BufferizeTypeConverter typeConverter;
// Mark all Standard operations legal.
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
StandardOpsDialect, tensor::TensorDialect>();
target.addLegalDialect<arith::ArithmeticDialect, func::FuncDialect,
memref::MemRefDialect, tensor::TensorDialect>();
// Mark all TMTensor operations illegal as long as they work on tensors.
auto isLegalOperation = [&](Operation *op) {

View File

@ -16,7 +16,7 @@ add_mlir_library(TorchMLIRTMTensorPasses
MLIRMemRef
MLIRPass
MLIRSCF
MLIRStandard
MLIRFunc
MLIRSupport
MLIRTensor
MLIRTransforms

View File

@ -7,11 +7,11 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@ -93,7 +93,7 @@ struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern {
namespace {
struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, StandardOpsDialect,
registry.insert<linalg::LinalgDialect, func::FuncDialect,
mlir::arith::ArithmeticDialect, math::MathDialect,
memref::MemRefDialect, scf::SCFDialect>();
}

View File

@ -64,18 +64,18 @@ func @scatter_mixed_tensor_memref(
// -----
func @scatter_mixed_tensor_memref(
func @scatter_output_type_mismatch(
%update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> memref<?x?xf32> {
// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'memref<?x?xf32>'}}
%original : tensor<?x?xf32>) -> tensor<4x?xf32> {
// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'tensor<4x?xf32>'}}
%0 = tm_tensor.scatter unique_indices(true)
ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
tm_tensor.yield %1 : f32
} -> memref<?x?xf32>
return %0 : memref<?x?xf32>
} -> tensor<4x?xf32>
return %0 : tensor<4x?xf32>
}
// -----

View File

@ -6,7 +6,7 @@ set(LIBS
MLIROptLib
MLIRSCF
MLIRSCFTransforms
MLIRStandard
MLIRFunc
MLIRTensor
MLIRTransforms
TorchMLIRTMTensorDialect

View File

@ -8,15 +8,15 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
@ -40,7 +40,7 @@ int main(int argc, char **argv) {
mlir::torch::TMTensor::TMTensorDialect,
// Upstream dialects
mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
mlir::memref::MemRefDialect, mlir::StandardOpsDialect,
mlir::func::FuncDialect, mlir::memref::MemRefDialect,
mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();
return mlir::asMainReturnCode(

@ -1 +1 @@
Subproject commit bfc6fbfb65f6d490e6a1ba4eb6c734d2b4494dd1
Subproject commit 8361c5da30588d3d4a48eae648f53be1feb5cfad

View File

@ -36,7 +36,7 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
(ATen) and ops where such mismatches are undefined behavior (linalg).
To model the termination of the program for implementing error guards,
we use the `std.assert` op.
we use the `cf.assert` op.
This is a design decision that is at variance from other passes in the
ecosystem, which use the
`shape` dialect's witness system (`shape.cstr_*` family of ops feeding into

View File

@ -474,7 +474,6 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
$maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true if this loop is "for-like". Otherwise it is "while-like"
/// and this function returns false.
@ -528,7 +527,6 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion);
// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
@ -1047,13 +1045,13 @@ let results = (outs
// "torch.prim.If"(%cond) ({
// "torch.prim.If.yield"() : () -> ()
// }, {
// "torch.prim.RaiseException"(%msg) : (!torch.str) -> ()
// "torch.prim.If.yield"() : () -> ()
// "torch.prim.RaiseException"(%msg) : (!torch.str) -> ()
// "torch.prim.If.yield"() : () -> ()
// }) : (!torch.bool) -> ()
//
//
// This new operation `torch.runtime.assert` is added to simplify the IR control
// flow by avoiding unnecessary branches. It also makes insertion of the runtime
// assert in the source code easier.
// flow by avoiding unnecessary branches. It also makes insertion of the runtime
// assert in the source code easier.
def Torch_RuntimeAssertOp: Torch_Op<"runtime.assert", [
AllowsTypeRefinement,
HasValueSemantics,
@ -1104,7 +1102,6 @@ def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [
let assemblyFormat = [{
$body `shapes` $shapeCalculation attr-dict `:` type($results)
}];
let hasVerifier = 1;
}
def Torch_ShapeCalculateYieldOp : Torch_Op<"shape.calculate.yield", [

View File

@ -18,7 +18,7 @@ include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td"
class TorchConversion_Op<string mnemonic, list<Trait> traits = []>
: Op<TorchConversion_Dialect, mnemonic, traits> {
: Op<TorchConversion_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
}
//===----------------------------------------------------------------------===//
@ -206,7 +206,11 @@ def TorchConversion_GeneratorToI64Op : TorchConversion_Op<"generator_to_i64", [
}];
}
def TorchConversion_GetNextSeedOp: TorchConversion_Op<"get_next_seed", [
class TorchConversionWithSideEffect_Op<string mnemonic, list<Trait> traits = []>
: Op<TorchConversion_Dialect, mnemonic, traits> {
}
def TorchConversion_GetNextSeedOp: TorchConversionWithSideEffect_Op<"get_next_seed", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
]> {
let summary = "Get the next global seed";

View File

@ -27,6 +27,8 @@ std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
std::unique_ptr<OperationPass<FuncOp>> createMungeMemrefCopyPass();
std::unique_ptr<OperationPass<FuncOp>> createGeneralizeTensorPadPass();
} // namespace RefBackend
} // namespace torch
} // namespace mlir

View File

@ -35,4 +35,9 @@ def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "FuncOp"> {
let dependentDialects = ["memref::MemRefDialect"];
}
def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "FuncOp"> {
let summary = "Convert tensor.pad to linalg ops";
let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
}
#endif // TORCHMLIR_REFBACKEND_PASSES

View File

@ -13,9 +13,9 @@
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -41,7 +41,7 @@ public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
registry.insert<math::MathDialect>();
registry.insert<StandardOpsDialect>();
registry.insert<func::FuncDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<cf::ControlFlowDialect>();
@ -51,7 +51,7 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
cf::ControlFlowDialect, math::MathDialect,
tensor::TensorDialect, arith::ArithmeticDialect>();
target.addLegalOp<TorchConversion::GetNextSeedOp>();

View File

@ -14,7 +14,7 @@ add_mlir_conversion_library(TorchMLIRTorchToSCF
MLIRIR
MLIRPass
MLIRSCF
MLIRStandard
MLIRFunc
TorchMLIRTorchDialect
TorchMLIRTorchConversionDialect
)

View File

@ -13,7 +13,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStd
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRStandard
MLIRFunc
TorchMLIRTorchDialect
)

View File

@ -11,7 +11,7 @@
#include "../PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
@ -158,7 +158,7 @@ namespace {
class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>();
registry.insert<func::FuncDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<cf::ControlFlowDialect>();
@ -168,7 +168,7 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect, StandardOpsDialect,
target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
arith::ArithmeticDialect, tensor::TensorDialect,
cf::ControlFlowDialect>();

View File

@ -10,8 +10,8 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
@ -312,7 +312,7 @@ class ConvertTorchToTMTensor
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
registry.insert<StandardOpsDialect>();
registry.insert<func::FuncDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
registry.insert<TMTensorDialect>();
@ -322,7 +322,7 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
tensor::TensorDialect, arith::ArithmeticDialect,
Torch::TorchDialect, TMTensorDialect>();

View File

@ -232,10 +232,6 @@ LogicalResult ClassTypeOp::verify() {
// PrimLoopOp
//===----------------------------------------------------------------------===//
LogicalResult PrimLoopOp::verify() {
return RegionBranchOpInterface::verifyTypes(*this);
}
OperandRange PrimLoopOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0);
return iterArgsInit();
@ -275,10 +271,6 @@ PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
// PrimIfOp
//===----------------------------------------------------------------------===//
LogicalResult PrimIfOp::verify() {
return RegionBranchOpInterface::verifyTypes(*this);
}
ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions.
result.regions.reserve(2);
@ -1551,10 +1543,6 @@ OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
// ShapeCalculateOp
//===----------------------------------------------------------------------===//
LogicalResult ShapeCalculateOp::verify() {
return RegionBranchOpInterface::verifyTypes(*this);
}
void ShapeCalculateOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {

View File

@ -9,7 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@ -93,14 +93,15 @@ public:
} // namespace
namespace {
class AdjustCallingConventionForCall : public OpConversionPattern<CallOp> {
class AdjustCallingConventionForCall
: public OpConversionPattern<func::CallOp> {
public:
AdjustCallingConventionForCall(TypeConverter &converter, MLIRContext *context,
TypeBoundMap &typeBoundMap)
: OpConversionPattern<CallOp>(converter, context),
: OpConversionPattern<func::CallOp>(converter, context),
typeBoundMap(typeBoundMap) {}
LogicalResult
matchAndRewrite(CallOp call, OpAdaptor adaptor,
matchAndRewrite(func::CallOp call, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> convertedResults;
if (failed(typeConverter->convertTypes(call.getResultTypes(),
@ -126,8 +127,8 @@ public:
newOperands.push_back(operand.value());
}
CallOp newCall = rewriter.create<CallOp>(call.getLoc(), call.getCallee(),
convertedResults, newOperands);
func::CallOp newCall = rewriter.create<func::CallOp>(
call.getLoc(), call.getCallee(), convertedResults, newOperands);
int newOpResultIdx = 0;
SmallVector<Value> newResults;
for (auto type : call.getResultTypes()) {
@ -153,11 +154,12 @@ private:
} // namespace
namespace {
class AdjustCallingConventionForReturn : public OpConversionPattern<ReturnOp> {
class AdjustCallingConventionForReturn
: public OpConversionPattern<func::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> newOperands;
@ -178,7 +180,7 @@ public:
}
newOperands.push_back(operand);
}
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, newOperands);
return success();
}
};
@ -233,13 +235,15 @@ static LogicalResult adjustCallingConventions(FuncOp func,
//
// Bug for doing this better https://bugs.llvm.org/show_bug.cgi?id=49812
DenseSet<Operation *> opsInOriginalProgram;
func.walk([&](CallOp op) { opsInOriginalProgram.insert(op.getOperation()); });
func.walk(
[&](ReturnOp op) { opsInOriginalProgram.insert(op.getOperation()); });
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
[&](func::CallOp op) { opsInOriginalProgram.insert(op.getOperation()); });
func.walk([&](func::ReturnOp op) {
opsInOriginalProgram.insert(op.getOperation());
});
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
return !opsInOriginalProgram.contains(op.getOperation());
});
target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
return !opsInOriginalProgram.contains(op.getOperation());
});
target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
@ -249,7 +253,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
target.addLegalOp<PrimTupleIndexOp>();
target.addLegalOp<PrimTupleConstructOp>();
// We don't know how to rewrite it, so mark it as illegal.
target.addIllegalOp<CallIndirectOp>();
target.addIllegalOp<func::CallIndirectOp>();
if (failed(applyPartialConversion(func.getOperation(), target,
std::move(patterns))))
return failure();

View File

@ -9,10 +9,10 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Parser.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"

View File

@ -9,7 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@ -349,7 +349,7 @@ static LogicalResult analyzeInstances(FuncOp func,
}
static FailureOr<Monomorphization>
createMonomorphizationForCall(CallOp op, BlockAndValueMapping &mapping,
createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping,
SymbolTable &symbolTable) {
auto func = symbolTable.lookup<FuncOp>(op.getCallee());
Monomorphization monomorphization;
@ -413,7 +413,7 @@ private:
BlockAndValueMapping mapping;
if (failed(analyzeInstances(func, m.argInstances, mapping)))
return failure();
auto walkResult = func.walk([&](CallOp op) {
auto walkResult = func.walk([&](func::CallOp op) {
FailureOr<Monomorphization> maybeMonomorphization =
createMonomorphizationForCall(op, mapping, symbolTable);
if (failed(maybeMonomorphization))
@ -439,7 +439,7 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
if (!value.getType().isa<NnModuleType>())
return success();
for (Operation *op : value.getUsers()) {
if (isa<CallOp, PrimGetAttrOp>(op))
if (isa<func::CallOp, PrimGetAttrOp>(op))
continue;
// Only allow `value` as the receiver.
if (isa<PrimSetAttrOp>(op) && cast<PrimSetAttrOp>(op).value() != value)
@ -539,7 +539,7 @@ rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
toErase.push_back(op);
return WalkResult::advance();
};
auto handleCall = [&](CallOp op) {
auto handleCall = [&](func::CallOp op) {
FailureOr<Monomorphization> maybeMonomorphization =
createMonomorphizationForCall(op, mapping, symbolTable);
if (failed(maybeMonomorphization))
@ -550,7 +550,7 @@ rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
return !v.getType().isa<NnModuleType>();
}));
assert(newFuncs.find(monomorphization) != newFuncs.end());
auto newOp = OpBuilder(op).create<CallOp>(
auto newOp = OpBuilder(op).create<func::CallOp>(
op.getLoc(), newFuncs[monomorphization], newArguments);
op.replaceAllUsesWith(newOp);
toErase.push_back(op);
@ -561,7 +561,7 @@ rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
return handlePrimSetAttr(primSetAttr);
if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op))
return handlePrimGetAttr(primGetAttr);
if (auto call = dyn_cast<CallOp>(op))
if (auto call = dyn_cast<func::CallOp>(op))
return handleCall(call);
return WalkResult::advance();
});

View File

@ -9,7 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
@ -51,7 +51,7 @@ public:
SmallVector<Operation *> copyLikeOps;
SmallVector<Operation *> viewLikeOps;
SmallVector<OverwriteTensorContentsOp> overwriteTensorContentsOps;
Optional<mlir::ReturnOp> returnOp;
Optional<mlir::func::ReturnOp> returnOp;
};
// Check that graph rewriting is possible by doing an abstract
@ -103,7 +103,7 @@ public:
availableAliases.clear();
availableAliases.insert(assertNonValueTensor(overwrite.overwritten()));
result.overwriteTensorContentsOps.push_back(overwrite);
} else if (auto returnOp = dyn_cast<mlir::ReturnOp>(user)) {
} else if (auto returnOp = dyn_cast<mlir::func::ReturnOp>(user)) {
result.returnOp = returnOp;
} else {
return rewriter.notifyMatchFailure(
@ -234,7 +234,7 @@ public:
// Nothing to do if there is just a ReturnOp -- we know that we won't be
// rewriting anything, since we must preserve the ReturnOp's original type.
if (llvm::hasSingleElement(nonValueTensorsUsedByOp) &&
isa<mlir::ReturnOp>(nonValueTensorsUsedByOp.begin()->first)) {
isa<mlir::func::ReturnOp>(nonValueTensorsUsedByOp.begin()->first)) {
return failure();
}
@ -278,7 +278,7 @@ public:
// trivially feeds into CopyToValueTensorOp's.
SmallVector<Operation *> viewLikeOps;
SmallVector<CopyToValueTensorOp> copyToValueTensorOps;
SmallVector<mlir::ReturnOp> returnOps;
SmallVector<mlir::func::ReturnOp> returnOps;
auto workList = llvm::to_vector<6>(copy.getResult().getUsers());
// We currently only support view-like ops with one tensor input and one
// tensor output, meaning that the tensor use-def chains form a tree.
@ -288,7 +288,7 @@ public:
Operation *op = workList.pop_back_val();
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
copyToValueTensorOps.push_back(copyToValueTensor);
} else if (auto returnOp = dyn_cast<mlir::ReturnOp>(op)) {
} else if (auto returnOp = dyn_cast<mlir::func::ReturnOp>(op)) {
returnOps.push_back(returnOp);
} else if (isViewLikeOp(op)) {
viewLikeOps.push_back(op);
@ -309,7 +309,7 @@ public:
for (CopyToValueTensorOp op : copyToValueTensorOps)
rewriter.replaceOp(op, op.getOperand());
// Keep track of the original types of any view-like ops, so that we can
// correctly copy them back to their mlir::ReturnOp's expected types.
// correctly copy them back to their mlir::func::ReturnOp's expected types.
DenseMap<Value, Type> originalTypes;
for (Operation *op : viewLikeOps) {
rewriter.updateRootInPlace(op, [&]() {
@ -321,7 +321,7 @@ public:
});
}
// For ReturnOp's, we need to update the operands to their original types.
for (mlir::ReturnOp op : returnOps) {
for (mlir::func::ReturnOp op : returnOps) {
for (int i = 0, e = op->getNumOperands(); i < e; i++) {
OpOperand &operand = op->getOpOperand(i);
auto it = originalTypes.find(operand.get());

View File

@ -9,7 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@ -41,7 +41,7 @@ public:
}
}
assert(func);
rewriter.replaceOpWithNewOp<CallOp>(op, func, op->getOperands());
rewriter.replaceOpWithNewOp<func::CallOp>(op, func, op->getOperands());
return success();
}
@ -51,10 +51,10 @@ private:
} // namespace
namespace {
class EraseUnusedConstantOp : public OpRewritePattern<ConstantOp> {
class EraseUnusedConstantOp : public OpRewritePattern<func::ConstantOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConstantOp op,
LogicalResult matchAndRewrite(func::ConstantOp op,
PatternRewriter &rewriter) const override {
if (op.use_empty()) {
rewriter.eraseOp(op);
@ -76,7 +76,7 @@ class PrepareForGlobalizeObjectGraphPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<ConvertPrimCallMethodToCall>(context, symbolTable);
CallIndirectOp::getCanonicalizationPatterns(patterns, context);
func::CallIndirectOp::getCanonicalizationPatterns(patterns, context);
patterns.add<EraseUnusedConstantOp>(context);
// Use applyPatternsAndFoldGreedily because the CallIndirectOp folding
@ -94,9 +94,9 @@ class PrepareForGlobalizeObjectGraphPass
// to the form we want.
ConversionTarget target(*context);
target.addIllegalOp<PrimCallMethodOp>();
target.addDynamicallyLegalOp<ConstantOp>(
[](ConstantOp op) { return !op.getType().isa<FunctionType>(); });
target.addIllegalOp<CallIndirectOp>();
target.addDynamicallyLegalOp<func::ConstantOp>(
[](func::ConstantOp op) { return !op.getType().isa<FunctionType>(); });
target.addIllegalOp<func::CallIndirectOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
RewritePatternSet dummyPatterns(context);

View File

@ -9,7 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -42,8 +42,8 @@ class RefinePublicReturnPass
void rewriteSignature(FuncOp func) {
// Find the unique return op.
ReturnOp returnOp;
WalkResult walkResult = func.walk([&](ReturnOp op) {
func::ReturnOp returnOp;
WalkResult walkResult = func.walk([&](func::ReturnOp op) {
if (returnOp)
return WalkResult::interrupt();
returnOp = op;

View File

@ -9,10 +9,10 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Parser.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/InliningUtils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@ -192,7 +192,8 @@ populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands,
}
// Create the call to the shape function!
auto call = b.create<mlir::CallOp>(loc, shapeFunction, shapeFunctionArgs);
auto call =
b.create<mlir::func::CallOp>(loc, shapeFunction, shapeFunctionArgs);
// Python models multiple results with a tuple, so we need to unpack it
// if the op has multiple results.
@ -223,7 +224,7 @@ class ReifyShapeCalculationsPass
// TODO: Find a way to not have to parse this every time.
// The shape library is O(#ops we know about), and this pass should be
// O(#ops in the program) ideally.
auto shapeLibrary = parseSourceString(getShapeLibrary(), context);
auto shapeLibrary = parseSourceString<ModuleOp>(getShapeLibrary(), context);
// Walk all the operations, and if we have a shape function, wrap the op
// in a `torch.shape.calculate` op.
@ -286,7 +287,8 @@ class ReifyShapeCalculationsPass
func.setVisibility(SymbolTable::Visibility::Private);
// Continue the DFS.
importedFunctions.insert(symName);
func.walk([&](CallOp op) { worklist.push_back(op.getCallee().str()); });
func.walk(
[&](func::CallOp op) { worklist.push_back(op.getCallee().str()); });
}
}
};

View File

@ -9,10 +9,10 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Parser.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/InliningUtils.h"

View File

@ -9,8 +9,8 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@ -204,8 +204,8 @@ struct FuncBackendTypeConversionPass
typeConverter.isLegal(&op.getBody());
});
populateCallOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<CallOp>(
[&](CallOp op) { return typeConverter.isLegal(op); });
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return typeConverter.isLegal(op); });
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
populateReturnOpTypeConversionPattern(patterns, typeConverter);

View File

@ -18,7 +18,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRStandardOpsTransforms
MLIRFuncTransforms
TorchMLIRTorchConversionDialect
TorchMLIRTorchDialect
TorchMLIRTorchPasses

View File

@ -8,18 +8,18 @@
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;

View File

@ -10,9 +10,9 @@
#include "PassDetail.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
@ -60,12 +60,13 @@ class VerifyLinalgOnTensorsBackendContractPass
ConversionTarget target(*context);
// Structural operations.
target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
opHasLegalTypes);
target.addDynamicallyLegalOp<GetNextSeedOp>(opHasLegalTypes);
// Basic scalar operations.
target.addDynamicallyLegalDialect<StandardOpsDialect>(isLegalScalarOp);
target.addDynamicallyLegalDialect<func::FuncDialect>(isLegalScalarOp);
target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
isLegalScalarOp);

View File

@ -9,7 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinOps.h"
@ -39,7 +39,8 @@ class VerifyTosaBackendContractPass
ConversionTarget target(*context);
// Structural operations.
target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
opHasLegalTypes);
// Basic scalar operations.
target.addLegalDialect<tosa::TosaDialect>();
target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);

View File

@ -16,11 +16,12 @@
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
@ -104,12 +105,12 @@ static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) {
// Replace the original returnOp with a call to consumeFuncReturnFunc and add
// the op to the `toErase` vector.
static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName,
TypeRange retTypes,
static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op,
StringRef funcName, TypeRange retTypes,
SmallVectorImpl<Value> &vals,
SmallVectorImpl<Operation *> &toErase) {
b.create<mlir::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
b.create<mlir::ReturnOp>(op.getLoc());
b.create<mlir::func::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
b.create<mlir::func::ReturnOp>(op.getLoc());
toErase.push_back(op);
}
@ -147,7 +148,7 @@ static LogicalResult mungeFunction(
SmallVector<Operation *> toErase;
bool isSupported = true;
func.walk([&](ReturnOp op) {
func.walk([&](func::ReturnOp op) {
auto types = op.getOperandTypes();
b.setInsertionPoint(op);
// Memref Types.
@ -341,7 +342,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
populateExpandTanhPattern(patterns);
patterns.add<math::ErfPolynomialApproximation>(patterns.getContext());
ConversionTarget target(*context);
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<math::MathDialect>();
target.addLegalDialect<arith::ArithmeticDialect>();
target.addIllegalOp<math::TanhOp>();
@ -398,10 +399,6 @@ class MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> {
};
class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
@ -418,3 +415,27 @@ std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::RefBackend::createMungeMemrefCopyPass() {
return std::make_unique<MungeMemrefCopy>();
}
namespace {
class GeneralizeTensorPad
: public GeneralizeTensorPadBase<GeneralizeTensorPad> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
patterns.insert<linalg::GeneralizePadOpPattern>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::torch::RefBackend::createGeneralizeTensorPadPass() {
return std::make_unique<GeneralizeTensorPad>();
}

View File

@ -60,7 +60,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) {
createMlirOperationAtEnd(
appendToBlock, "std.return", loc,
appendToBlock, "func.return", loc,
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
};
MlirBlock block = importBlock(

View File

@ -178,7 +178,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
torch::jit::Function *function = functionType->function();
const std::string &symName = function->qualname().qualifiedName();
op = createMlirOperation(
"std.constant", loc,
"func.constant", loc,
getFunctionTypeFromSchema(context, function->getSchema()),
toMlirNamedAttribute(
"value",
@ -280,7 +280,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
return getMlirTypeFromTorchType(loc, v->type());
});
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "std.call_indirect", loc,
appendToBlock, "func.call_indirect", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValue(node->input(0)),
derefineValues(lookupMappedValues(node->inputs().slice(1)),

View File

@ -164,11 +164,11 @@ class RefBackendInvoker:
LOWERING_PIPELINE = ",".join([
"builtin.func(refback-generalize-tensor-pad)",
# Bufferize.
"builtin.func(scf-bufferize)",
"builtin.func(tm-tensor-bufferize)",
"builtin.func(linalg-bufferize)",
"builtin.func(refback-munge-memref-copy)",
"func-bufferize",
"arith-bufferize",
"builtin.func(tensor-bufferize)",
@ -185,6 +185,7 @@ LOWERING_PIPELINE = ",".join([
"refback-insert-rng-globals",
# Lower to LLVM
"builtin.func(tm-tensor-to-loops)",
"builtin.func(refback-munge-memref-copy)",
"builtin.func(convert-linalg-to-loops)",
"builtin.func(lower-affine)",
"convert-scf-to-cf",
@ -194,7 +195,7 @@ LOWERING_PIPELINE = ",".join([
"convert-linalg-to-llvm",
"convert-memref-to-llvm",
"builtin.func(convert-arith-to-llvm)",
"convert-std-to-llvm",
"convert-func-to-llvm",
"convert-cf-to-llvm",
"reconcile-unrealized-casts",
])

View File

@ -17,7 +17,7 @@
// CHECK: assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm"
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[LHS_DIM_0]], %[[RHS_DIM_1]]] : tensor<?x?xf32>
// CHECK: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[CF0]], %[[INIT_TENSOR]]) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[ZEROFILL:.*]] = linalg.fill ins(%[[CF0]] : f32) outs(%[[INIT_TENSOR]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
// CHECK: %[[RESULT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32>
@ -91,7 +91,7 @@ func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.
func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[?,?],si64> -> !torch.int
return %0 : !torch.int
}
}
// -----
@ -104,7 +104,7 @@ func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) ->
func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float {
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[],f64> -> !torch.float
return %0 : !torch.float
}
}
// -----
@ -129,14 +129,14 @@ func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch
func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float {
%0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[?,?],f64> -> !torch.float
return %0 : !torch.float
}
}
// -----
// CHECK-LABEL: func @torch.aten.Bool.Tensor$zero_rank
// CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],i1>) -> !torch.bool {
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],i1> -> tensor<i1>
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
// CHECK: %[[RES:.*]] = torch_c.from_i1 %[[EXT]]
// CHECK: return %[[RES]] : !torch.bool
func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool {
@ -167,14 +167,14 @@ func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.b
func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool {
%0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[?,?],i1> -> !torch.bool
return %0 : !torch.bool
}
}
// -----
// CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
// CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
// CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
// CHECK: %[[FILLVEC:.*]] = linalg.fill(%[[INI64]], %[[NEWVEC]]) : i64, tensor<i64> -> tensor<i64>
// CHECK: %[[FILLVEC:.*]] = linalg.fill ins(%[[INI64]] : i64) outs(%[[NEWVEC]] : tensor<i64>) -> tensor<i64>
// CHECK: %[[OUTVEC:.*]] = torch_c.from_builtin_tensor %[[FILLVEC]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[OUTVEC]] : !torch.vtensor<[],si64>
func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> {

View File

@ -13,7 +13,7 @@ builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?
%false = torch.constant.bool false
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
// CHECK: %[[NEUTRAL:.*]] = arith.constant -1.401300e-45 : f32
// CHECK: %[[OUT:.*]] = linalg.fill(%[[NEUTRAL]], %{{.*}}) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[C1]], %[[C2]]] : tensor<?x?xf32>

View File

@ -25,7 +25,7 @@ torch.class_type @c {
}
%c0 = torch.constant.int 0
%0 = torch.nn_module {
// expected-error @+1 {{'torch.slot' op is expected to match type and name of 'torch.attr "g" : !torch.int'}}
// expected-error @+1 {{'torch.slot' op is expected to match type and name of '"torch.attr"() {name = "g", type = !torch.int} : () -> ()}}
torch.slot "f", %c0 : !torch.int
} : !torch.nn.Module<"c">
@ -138,7 +138,7 @@ builtin.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
builtin.func @torch.tensor() {
// Incompatible shape.
// expected-error@+1 {{incompatible}}
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
return
}
@ -147,7 +147,7 @@ builtin.func @torch.tensor() {
builtin.func @torch.tensor() {
// Incompatible dtype.
// expected-error@+1 {{incompatible}}
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
return
}
@ -156,7 +156,7 @@ builtin.func @torch.tensor() {
builtin.func @torch.tensor() {
// Incompatible type.
// expected-error@+1 {{incompatible}}
// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1
return
}

View File

@ -19,7 +19,7 @@ func private @test_call_method(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float
// CHECK-LABEL: func private @test_call_indirect(
// CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
// CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
// Ensure no std.constant.
// Ensure no func.constant.
// CHECK-NEXT: %[[VAL_2:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
// CHECK-NEXT: return %[[VAL_2]] : !torch.float
func private @test_call_indirect(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {

View File

@ -12,7 +12,7 @@ func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%4 = arith.cmpi eq, %1, %2 : index
cf.assert %4, "mismatching contracting dimension for aten.mm"
%5 = linalg.init_tensor [%0, %3] : tensor<?x?xf32>
%6 = linalg.fill(%cst, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
%7 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%6 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %7 : tensor<?x?xf32>
}
@ -47,7 +47,7 @@ module {
// expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
module {
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
// expected-error@+1 {{failed to legalize operation 'std.return'}}
// expected-error@+1 {{failed to legalize operation 'func.return'}}
return %arg0 : !torch.tensor
}
}

View File

@ -36,7 +36,7 @@ module {
// expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
module {
func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
// expected-error@+1 {{failed to legalize operation 'std.return'}}
// expected-error@+1 {{failed to legalize operation 'func.return'}}
return %arg0 : !torch.tensor
}
}

View File

@ -9,7 +9,7 @@
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "torch-mlir/InitAll.h"
using namespace mlir;