mirror of https://github.com/llvm/torch-mlir
Bump LLVM at 8361c5da30588d3d4a48eae648f53be1feb5cfad
parent 218b4875d5
commit 63fb1e5aad
@@ -23,7 +23,7 @@ from itertools import chain
 
 from torch_mlir import ir
 import torch_mlir.dialects.torch as torch_d
-from torch_mlir.dialects import builtin, std
+from torch_mlir.dialects import builtin, std, func
 
 import torch.fx
 from torch.fx.experimental.fx_acc import acc_ops
@@ -288,7 +288,7 @@ class _ForwardFunctionBuilder(_Builder):
 result = self._insert_function_call(node)
 self.env[node] = result
 elif node.op == 'output':
-std.ReturnOp([self.env[node_arg] for node_arg in node.args],
+func.ReturnOp([self.env[node_arg] for node_arg in node.args],
 loc=self.loc, ip=self.func_ip)
 elif node.op == 'placeholder':
 continue
@@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTMTensorDialect
 MLIRSideEffectInterfaces
 MLIRSupport
 MLIRSCF
-MLIRStandard
+MLIRFunc
 MLIRTensor
 MLIRViewLikeInterface
 )
@@ -10,8 +10,8 @@
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "llvm/Support/Debug.h"
@@ -10,11 +10,11 @@
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Attributes.h"
@@ -10,9 +10,10 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinDialect.h"

@@ -132,8 +133,8 @@ struct TMTensorBufferizePass
 bufferization::BufferizeTypeConverter typeConverter;
 
 // Mark all Standard operations legal.
-target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
-StandardOpsDialect, tensor::TensorDialect>();
+target.addLegalDialect<arith::ArithmeticDialect, func::FuncDialect,
+memref::MemRefDialect, tensor::TensorDialect>();
 
 // Mark all TMTensor operations illegal as long as they work on tensors.
 auto isLegalOperation = [&](Operation *op) {
@@ -16,7 +16,7 @@ add_mlir_library(TorchMLIRTMTensorPasses
 MLIRMemRef
 MLIRPass
 MLIRSCF
-MLIRStandard
+MLIRFunc
 MLIRSupport
 MLIRTensor
 MLIRTransforms
@@ -7,11 +7,11 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"

@@ -93,7 +93,7 @@ struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern {
 namespace {
 struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
 void getDependentDialects(DialectRegistry &registry) const override {
-registry.insert<linalg::LinalgDialect, StandardOpsDialect,
+registry.insert<linalg::LinalgDialect, func::FuncDialect,
 mlir::arith::ArithmeticDialect, math::MathDialect,
 memref::MemRefDialect, scf::SCFDialect>();
 }
@@ -64,18 +64,18 @@ func @scatter_mixed_tensor_memref(
 
 // -----
 
-func @scatter_mixed_tensor_memref(
+func @scatter_output_type_mismatch(
 %update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
-%original : tensor<?x?xf32>) -> memref<?x?xf32> {
-// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'memref<?x?xf32>'}}
+%original : tensor<?x?xf32>) -> tensor<4x?xf32> {
+// expected-error @+1 {{expected type of `outs` operand #0 'tensor<?x?xf32>' to be same as result type 'tensor<4x?xf32>'}}
 %0 = tm_tensor.scatter unique_indices(true)
 ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
 outs(%original : tensor<?x?xf32>) {
 ^bb0(%arg1: f32, %arg2: f32):
 %1 = arith.addf %arg1, %arg2 : f32
 tm_tensor.yield %1 : f32
-} -> memref<?x?xf32>
-return %0 : memref<?x?xf32>
+} -> tensor<4x?xf32>
+return %0 : tensor<4x?xf32>
 }
 
 // -----
@@ -6,7 +6,7 @@ set(LIBS
 MLIROptLib
 MLIRSCF
 MLIRSCFTransforms
-MLIRStandard
+MLIRFunc
 MLIRTensor
 MLIRTransforms
 TorchMLIRTMTensorDialect
@@ -8,15 +8,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Passes.h"
 #include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/Support/MlirOptMain.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
 #include "mlir/Transforms/Passes.h"
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/ScalarLoopOpInterface.h"
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"

@@ -40,7 +40,7 @@ int main(int argc, char **argv) {
 mlir::torch::TMTensor::TMTensorDialect,
 // Upstream dialects
 mlir::arith::ArithmeticDialect, mlir::linalg::LinalgDialect,
-mlir::memref::MemRefDialect, mlir::StandardOpsDialect,
+mlir::func::FuncDialect, mlir::memref::MemRefDialect,
 mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();
 
 return mlir::asMainReturnCode(
@@ -1 +1 @@
-Subproject commit bfc6fbfb65f6d490e6a1ba4eb6c734d2b4494dd1
+Subproject commit 8361c5da30588d3d4a48eae648f53be1feb5cfad
@@ -36,7 +36,7 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
 (ATen) and ops where such mismatches are undefined behavior (linalg).
 
 To model the termination of the program for implementing error guards,
-we use the `std.assert` op.
+we use the `cf.assert` op.
 This is a design decision that is at variance from other passes in the
 ecosystem, which use the
 `shape` dialect's witness system (`shape.cstr_*` family of ops feeding into
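For context on the op named in this description, a minimal sketch of such an error guard after the change, modeled on the `aten.mm` test updated later in this diff (the %lhs_dim/%rhs_dim names are illustrative, not taken from the change itself):

    // Compare the two contracting dimensions, then abort at runtime on mismatch.
    %ok = arith.cmpi eq, %lhs_dim, %rhs_dim : index
    cf.assert %ok, "mismatching contracting dimension for aten.mm"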
@@ -474,7 +474,6 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
 $maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
 attr-dict `:` functional-type(operands, results)
 }];
-let hasVerifier = 1;
 let extraClassDeclaration = [{
 /// Returns true if this loop is "for-like". Otherwise it is "while-like"
 /// and this function returns false.
@@ -528,7 +527,6 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
 let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion);
 // Indicate that the operation has a custom parser and printer method.
 let hasCustomAssemblyFormat = 1;
-let hasVerifier = 1;
 let hasCanonicalizer = 1;
 }
 
@@ -1047,13 +1045,13 @@ let results = (outs
 // "torch.prim.If"(%cond) ({
 // "torch.prim.If.yield"() : () -> ()
 // }, {
-// "torch.prim.RaiseException"(%msg) : (!torch.str) -> ()
-// "torch.prim.If.yield"() : () -> ()
+// "torch.prim.RaiseException"(%msg) : (!torch.str) -> ()
+// "torch.prim.If.yield"() : () -> ()
 // }) : (!torch.bool) -> ()
-//
+//
 // This new operation `torch.runtime.assert` is added to simplify the IR control
-// flow by avoiding unnecessary branches. It also makes insertion of the runtime
-// assert in the source code easier.
+// flow by avoiding unnecessary branches. It also makes insertion of the runtime
+// assert in the source code easier.
 def Torch_RuntimeAssertOp: Torch_Op<"runtime.assert", [
 AllowsTypeRefinement,
 HasValueSemantics,
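As a rough sketch of the simplification this comment describes, the branch-and-raise pattern above collapses into a single assertion op; the exact assembly format below is assumed for illustration and is not taken from this diff:

    // %cond is a !torch.bool guard computed earlier, e.g. an integer equality check.
    %cond = torch.aten.eq.int %dim0, %dim1 : !torch.int, !torch.int -> !torch.bool
    torch.runtime.assert %cond, "Runtime assertion failed"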
@@ -1104,7 +1102,6 @@ def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [
 let assemblyFormat = [{
 $body `shapes` $shapeCalculation attr-dict `:` type($results)
 }];
-let hasVerifier = 1;
 }
 
 def Torch_ShapeCalculateYieldOp : Torch_Op<"shape.calculate.yield", [
@@ -18,7 +18,7 @@ include "torch-mlir/Dialect/Torch/IR/TorchTypes.td"
 include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionBase.td"
 
 class TorchConversion_Op<string mnemonic, list<Trait> traits = []>
-: Op<TorchConversion_Dialect, mnemonic, traits> {
+: Op<TorchConversion_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
 }
 
 //===----------------------------------------------------------------------===//

@@ -206,7 +206,11 @@ def TorchConversion_GeneratorToI64Op : TorchConversion_Op<"generator_to_i64", [
 }];
 }
 
-def TorchConversion_GetNextSeedOp: TorchConversion_Op<"get_next_seed", [
+class TorchConversionWithSideEffect_Op<string mnemonic, list<Trait> traits = []>
+: Op<TorchConversion_Dialect, mnemonic, traits> {
+}
+
+def TorchConversion_GetNextSeedOp: TorchConversionWithSideEffect_Op<"get_next_seed", [
 DeclareOpInterfaceMethods<InferTypeOpInterface>,
 ]> {
 let summary = "Get the next global seed";
@@ -27,6 +27,8 @@ std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass();
 std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
 
 std::unique_ptr<OperationPass<FuncOp>> createMungeMemrefCopyPass();
+
+std::unique_ptr<OperationPass<FuncOp>> createGeneralizeTensorPadPass();
 } // namespace RefBackend
 } // namespace torch
 } // namespace mlir
@@ -35,4 +35,9 @@ def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "FuncOp"> {
 let dependentDialects = ["memref::MemRefDialect"];
 }
 
+def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "FuncOp"> {
+let summary = "Convert tensor.pad to linalg ops";
+let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()";
+}
+
 #endif // TORCHMLIR_REFBACKEND_PASSES
@@ -13,9 +13,9 @@
 #include "PopulatePatterns.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

@@ -41,7 +41,7 @@ public:
 void getDependentDialects(DialectRegistry &registry) const override {
 registry.insert<linalg::LinalgDialect>();
 registry.insert<math::MathDialect>();
-registry.insert<StandardOpsDialect>();
+registry.insert<func::FuncDialect>();
 registry.insert<tensor::TensorDialect>();
 registry.insert<arith::ArithmeticDialect>();
 registry.insert<cf::ControlFlowDialect>();

@@ -51,7 +51,7 @@ public:
 void runOnOperation() override {
 MLIRContext *context = &getContext();
 ConversionTarget target(*context);
-target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
+target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
 cf::ControlFlowDialect, math::MathDialect,
 tensor::TensorDialect, arith::ArithmeticDialect>();
 target.addLegalOp<TorchConversion::GetNextSeedOp>();
@@ -14,7 +14,7 @@ add_mlir_conversion_library(TorchMLIRTorchToSCF
 MLIRIR
 MLIRPass
 MLIRSCF
-MLIRStandard
+MLIRFunc
 TorchMLIRTorchDialect
 TorchMLIRTorchConversionDialect
 )
@@ -13,7 +13,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStd
 LINK_LIBS PUBLIC
 MLIRIR
 MLIRPass
-MLIRStandard
+MLIRFunc
 TorchMLIRTorchDialect
 )
 
@@ -11,7 +11,7 @@
 
 #include "../PassDetail.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/Transforms/DialectConversion.h"

@@ -158,7 +158,7 @@ namespace {
 class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
 public:
 void getDependentDialects(DialectRegistry &registry) const override {
-registry.insert<StandardOpsDialect>();
+registry.insert<func::FuncDialect>();
 registry.insert<arith::ArithmeticDialect>();
 registry.insert<tensor::TensorDialect>();
 registry.insert<cf::ControlFlowDialect>();

@@ -168,7 +168,7 @@ public:
 void runOnOperation() override {
 MLIRContext *context = &getContext();
 ConversionTarget target(*context);
-target.addLegalDialect<Torch::TorchDialect, StandardOpsDialect,
+target.addLegalDialect<Torch::TorchDialect, func::FuncDialect,
 arith::ArithmeticDialect, tensor::TensorDialect,
 cf::ControlFlowDialect>();
 
@@ -10,8 +10,8 @@
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
 
 #include "../PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/MLIRContext.h"
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
 #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"

@@ -312,7 +312,7 @@ class ConvertTorchToTMTensor
 public:
 void getDependentDialects(DialectRegistry &registry) const override {
 registry.insert<linalg::LinalgDialect>();
-registry.insert<StandardOpsDialect>();
+registry.insert<func::FuncDialect>();
 registry.insert<tensor::TensorDialect>();
 registry.insert<arith::ArithmeticDialect>();
 registry.insert<TMTensorDialect>();

@@ -322,7 +322,7 @@ public:
 void runOnOperation() override {
 MLIRContext *context = &getContext();
 ConversionTarget target(*context);
-target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
+target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
 tensor::TensorDialect, arith::ArithmeticDialect,
 Torch::TorchDialect, TMTensorDialect>();
 
@@ -232,10 +232,6 @@ LogicalResult ClassTypeOp::verify() {
 // PrimLoopOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult PrimLoopOp::verify() {
-return RegionBranchOpInterface::verifyTypes(*this);
-}
-
 OperandRange PrimLoopOp::getSuccessorEntryOperands(unsigned index) {
 assert(index == 0);
 return iterArgsInit();

@@ -275,10 +271,6 @@ PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
 // PrimIfOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult PrimIfOp::verify() {
-return RegionBranchOpInterface::verifyTypes(*this);
-}
-
 ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
 // Create the regions.
 result.regions.reserve(2);

@@ -1551,10 +1543,6 @@ OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
 // ShapeCalculateOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult ShapeCalculateOp::verify() {
-return RegionBranchOpInterface::verifyTypes(*this);
-}
-
 void ShapeCalculateOp::getSuccessorRegions(
 Optional<unsigned> index, ArrayRef<Attribute> operands,
 SmallVectorImpl<RegionSuccessor> &regions) {
@@ -9,7 +9,7 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"

@@ -93,14 +93,15 @@ public:
 } // namespace
 
 namespace {
-class AdjustCallingConventionForCall : public OpConversionPattern<CallOp> {
+class AdjustCallingConventionForCall
+: public OpConversionPattern<func::CallOp> {
 public:
 AdjustCallingConventionForCall(TypeConverter &converter, MLIRContext *context,
 TypeBoundMap &typeBoundMap)
-: OpConversionPattern<CallOp>(converter, context),
+: OpConversionPattern<func::CallOp>(converter, context),
 typeBoundMap(typeBoundMap) {}
 LogicalResult
-matchAndRewrite(CallOp call, OpAdaptor adaptor,
+matchAndRewrite(func::CallOp call, OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter) const override {
 SmallVector<Type> convertedResults;
 if (failed(typeConverter->convertTypes(call.getResultTypes(),

@@ -126,8 +127,8 @@ public:
 newOperands.push_back(operand.value());
 }
 
-CallOp newCall = rewriter.create<CallOp>(call.getLoc(), call.getCallee(),
-convertedResults, newOperands);
+func::CallOp newCall = rewriter.create<func::CallOp>(
+call.getLoc(), call.getCallee(), convertedResults, newOperands);
 int newOpResultIdx = 0;
 SmallVector<Value> newResults;
 for (auto type : call.getResultTypes()) {

@@ -153,11 +154,12 @@ private:
 } // namespace
 
 namespace {
-class AdjustCallingConventionForReturn : public OpConversionPattern<ReturnOp> {
+class AdjustCallingConventionForReturn
+: public OpConversionPattern<func::ReturnOp> {
 public:
 using OpConversionPattern::OpConversionPattern;
 LogicalResult
-matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter) const override {
 
 SmallVector<Value> newOperands;

@@ -178,7 +180,7 @@ public:
 }
 newOperands.push_back(operand);
 }
-rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
+rewriter.replaceOpWithNewOp<func::ReturnOp>(op, newOperands);
 return success();
 }
 };

@@ -233,13 +235,15 @@ static LogicalResult adjustCallingConventions(FuncOp func,
 //
 // Bug for doing this better https://bugs.llvm.org/show_bug.cgi?id=49812
 DenseSet<Operation *> opsInOriginalProgram;
-func.walk([&](CallOp op) { opsInOriginalProgram.insert(op.getOperation()); });
 func.walk(
-[&](ReturnOp op) { opsInOriginalProgram.insert(op.getOperation()); });
-target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
+[&](func::CallOp op) { opsInOriginalProgram.insert(op.getOperation()); });
+func.walk([&](func::ReturnOp op) {
+opsInOriginalProgram.insert(op.getOperation());
+});
+target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
 return !opsInOriginalProgram.contains(op.getOperation());
 });
-target.addDynamicallyLegalOp<ReturnOp>([&](ReturnOp op) {
+target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
 return !opsInOriginalProgram.contains(op.getOperation());
 });
 target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();

@@ -249,7 +253,7 @@ static LogicalResult adjustCallingConventions(FuncOp func,
 target.addLegalOp<PrimTupleIndexOp>();
 target.addLegalOp<PrimTupleConstructOp>();
 // We don't know how to rewrite it, so mark it as illegal.
-target.addIllegalOp<CallIndirectOp>();
+target.addIllegalOp<func::CallIndirectOp>();
 if (failed(applyPartialConversion(func.getOperation(), target,
 std::move(patterns))))
 return failure();
@@ -9,10 +9,10 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/Parser.h"
+#include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
@@ -9,7 +9,7 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"

@@ -349,7 +349,7 @@ static LogicalResult analyzeInstances(FuncOp func,
 }
 
 static FailureOr<Monomorphization>
-createMonomorphizationForCall(CallOp op, BlockAndValueMapping &mapping,
+createMonomorphizationForCall(func::CallOp op, BlockAndValueMapping &mapping,
 SymbolTable &symbolTable) {
 auto func = symbolTable.lookup<FuncOp>(op.getCallee());
 Monomorphization monomorphization;

@@ -413,7 +413,7 @@ private:
 BlockAndValueMapping mapping;
 if (failed(analyzeInstances(func, m.argInstances, mapping)))
 return failure();
-auto walkResult = func.walk([&](CallOp op) {
+auto walkResult = func.walk([&](func::CallOp op) {
 FailureOr<Monomorphization> maybeMonomorphization =
 createMonomorphizationForCall(op, mapping, symbolTable);
 if (failed(maybeMonomorphization))

@@ -439,7 +439,7 @@ static LogicalResult verifyNnModuleValueUses(Value value) {
 if (!value.getType().isa<NnModuleType>())
 return success();
 for (Operation *op : value.getUsers()) {
-if (isa<CallOp, PrimGetAttrOp>(op))
+if (isa<func::CallOp, PrimGetAttrOp>(op))
 continue;
 // Only allow `value` as the receiver.
 if (isa<PrimSetAttrOp>(op) && cast<PrimSetAttrOp>(op).value() != value)

@@ -539,7 +539,7 @@ rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
 toErase.push_back(op);
 return WalkResult::advance();
 };
-auto handleCall = [&](CallOp op) {
+auto handleCall = [&](func::CallOp op) {
 FailureOr<Monomorphization> maybeMonomorphization =
 createMonomorphizationForCall(op, mapping, symbolTable);
 if (failed(maybeMonomorphization))

@@ -550,7 +550,7 @@ rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
 return !v.getType().isa<NnModuleType>();
 }));
 assert(newFuncs.find(monomorphization) != newFuncs.end());
-auto newOp = OpBuilder(op).create<CallOp>(
+auto newOp = OpBuilder(op).create<func::CallOp>(
 op.getLoc(), newFuncs[monomorphization], newArguments);
 op.replaceAllUsesWith(newOp);
 toErase.push_back(op);

@@ -561,7 +561,7 @@ rewriteMonomorphizedFuncClone(FuncOp func, BlockAndValueMapping mapping,
 return handlePrimSetAttr(primSetAttr);
 if (auto primGetAttr = dyn_cast<PrimGetAttrOp>(op))
 return handlePrimGetAttr(primGetAttr);
-if (auto call = dyn_cast<CallOp>(op))
+if (auto call = dyn_cast<func::CallOp>(op))
 return handleCall(call);
 return WalkResult::advance();
 });
@@ -9,7 +9,7 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"

@@ -51,7 +51,7 @@ public:
 SmallVector<Operation *> copyLikeOps;
 SmallVector<Operation *> viewLikeOps;
 SmallVector<OverwriteTensorContentsOp> overwriteTensorContentsOps;
-Optional<mlir::ReturnOp> returnOp;
+Optional<mlir::func::ReturnOp> returnOp;
 };
 
 // Check that graph rewriting is possible by doing an abstract

@@ -103,7 +103,7 @@ public:
 availableAliases.clear();
 availableAliases.insert(assertNonValueTensor(overwrite.overwritten()));
 result.overwriteTensorContentsOps.push_back(overwrite);
-} else if (auto returnOp = dyn_cast<mlir::ReturnOp>(user)) {
+} else if (auto returnOp = dyn_cast<mlir::func::ReturnOp>(user)) {
 result.returnOp = returnOp;
 } else {
 return rewriter.notifyMatchFailure(

@@ -234,7 +234,7 @@ public:
 // Nothing to do if there is just a ReturnOp -- we know that we won't be
 // rewriting anything, since we must preserve the ReturnOp's original type.
 if (llvm::hasSingleElement(nonValueTensorsUsedByOp) &&
-isa<mlir::ReturnOp>(nonValueTensorsUsedByOp.begin()->first)) {
+isa<mlir::func::ReturnOp>(nonValueTensorsUsedByOp.begin()->first)) {
 return failure();
 }
 

@@ -278,7 +278,7 @@ public:
 // trivially feeds into CopyToValueTensorOp's.
 SmallVector<Operation *> viewLikeOps;
 SmallVector<CopyToValueTensorOp> copyToValueTensorOps;
-SmallVector<mlir::ReturnOp> returnOps;
+SmallVector<mlir::func::ReturnOp> returnOps;
 auto workList = llvm::to_vector<6>(copy.getResult().getUsers());
 // We currently only support view-like ops with one tensor input and one
 // tensor output, meaning that the tensor use-def chains form a tree.

@@ -288,7 +288,7 @@ public:
 Operation *op = workList.pop_back_val();
 if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
 copyToValueTensorOps.push_back(copyToValueTensor);
-} else if (auto returnOp = dyn_cast<mlir::ReturnOp>(op)) {
+} else if (auto returnOp = dyn_cast<mlir::func::ReturnOp>(op)) {
 returnOps.push_back(returnOp);
 } else if (isViewLikeOp(op)) {
 viewLikeOps.push_back(op);

@@ -309,7 +309,7 @@ public:
 for (CopyToValueTensorOp op : copyToValueTensorOps)
 rewriter.replaceOp(op, op.getOperand());
 // Keep track of the original types of any view-like ops, so that we can
-// correctly copy them back to their mlir::ReturnOp's expected types.
+// correctly copy them back to their mlir::func::ReturnOp's expected types.
 DenseMap<Value, Type> originalTypes;
 for (Operation *op : viewLikeOps) {
 rewriter.updateRootInPlace(op, [&]() {

@@ -321,7 +321,7 @@ public:
 });
 }
 // For ReturnOp's, we need to update the operands to their original types.
-for (mlir::ReturnOp op : returnOps) {
+for (mlir::func::ReturnOp op : returnOps) {
 for (int i = 0, e = op->getNumOperands(); i < e; i++) {
 OpOperand &operand = op->getOpOperand(i);
 auto it = originalTypes.find(operand.get());
@@ -9,7 +9,7 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"

@@ -41,7 +41,7 @@ public:
 }
 }
 assert(func);
-rewriter.replaceOpWithNewOp<CallOp>(op, func, op->getOperands());
+rewriter.replaceOpWithNewOp<func::CallOp>(op, func, op->getOperands());
 return success();
 }
 

@@ -51,10 +51,10 @@ private:
 } // namespace
 
 namespace {
-class EraseUnusedConstantOp : public OpRewritePattern<ConstantOp> {
+class EraseUnusedConstantOp : public OpRewritePattern<func::ConstantOp> {
 public:
 using OpRewritePattern::OpRewritePattern;
-LogicalResult matchAndRewrite(ConstantOp op,
+LogicalResult matchAndRewrite(func::ConstantOp op,
 PatternRewriter &rewriter) const override {
 if (op.use_empty()) {
 rewriter.eraseOp(op);

@@ -76,7 +76,7 @@ class PrepareForGlobalizeObjectGraphPass
 MLIRContext *context = &getContext();
 RewritePatternSet patterns(context);
 patterns.add<ConvertPrimCallMethodToCall>(context, symbolTable);
-CallIndirectOp::getCanonicalizationPatterns(patterns, context);
+func::CallIndirectOp::getCanonicalizationPatterns(patterns, context);
 patterns.add<EraseUnusedConstantOp>(context);
 
 // Use applyPatternsAndFoldGreedily because the CallIndirectOp folding

@@ -94,9 +94,9 @@ class PrepareForGlobalizeObjectGraphPass
 // to the form we want.
 ConversionTarget target(*context);
 target.addIllegalOp<PrimCallMethodOp>();
-target.addDynamicallyLegalOp<ConstantOp>(
-[](ConstantOp op) { return !op.getType().isa<FunctionType>(); });
-target.addIllegalOp<CallIndirectOp>();
+target.addDynamicallyLegalOp<func::ConstantOp>(
+[](func::ConstantOp op) { return !op.getType().isa<FunctionType>(); });
+target.addIllegalOp<func::CallIndirectOp>();
 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
 RewritePatternSet dummyPatterns(context);
@@ -9,7 +9,7 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

@@ -42,8 +42,8 @@ class RefinePublicReturnPass
 
 void rewriteSignature(FuncOp func) {
 // Find the unique return op.
-ReturnOp returnOp;
-WalkResult walkResult = func.walk([&](ReturnOp op) {
+func::ReturnOp returnOp;
+WalkResult walkResult = func.walk([&](func::ReturnOp op) {
 if (returnOp)
 return WalkResult::interrupt();
 returnOp = op;
@@ -9,10 +9,10 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/Parser.h"
+#include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"

@@ -192,7 +192,8 @@ populateShapeCalculationRegion(ShapeCalculateOp op, ValueRange originalOperands,
 }
 
 // Create the call to the shape function!
-auto call = b.create<mlir::CallOp>(loc, shapeFunction, shapeFunctionArgs);
+auto call =
+b.create<mlir::func::CallOp>(loc, shapeFunction, shapeFunctionArgs);
 
 // Python models multiple results with a tuple, so we need to unpack it
 // if the op has multiple results.

@@ -223,7 +224,7 @@ class ReifyShapeCalculationsPass
 // TODO: Find a way to not have to parse this every time.
 // The shape library is O(#ops we know about), and this pass should be
 // O(#ops in the program) ideally.
-auto shapeLibrary = parseSourceString(getShapeLibrary(), context);
+auto shapeLibrary = parseSourceString<ModuleOp>(getShapeLibrary(), context);
 
 // Walk all the operations, and if we have a shape function, wrap the op
 // in a `torch.shape.calculate` op.

@@ -286,7 +287,8 @@ class ReifyShapeCalculationsPass
 func.setVisibility(SymbolTable::Visibility::Private);
 // Continue the DFS.
 importedFunctions.insert(symName);
-func.walk([&](CallOp op) { worklist.push_back(op.getCallee().str()); });
+func.walk(
+[&](func::CallOp op) { worklist.push_back(op.getCallee().str()); });
 }
 }
 };
@@ -9,10 +9,10 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/Parser.h"
+#include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -9,8 +9,8 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"

@@ -204,8 +204,8 @@ struct FuncBackendTypeConversionPass
 typeConverter.isLegal(&op.getBody());
 });
 populateCallOpTypeConversionPattern(patterns, typeConverter);
-target.addDynamicallyLegalOp<CallOp>(
-[&](CallOp op) { return typeConverter.isLegal(op); });
+target.addDynamicallyLegalOp<func::CallOp>(
+[&](func::CallOp op) { return typeConverter.isLegal(op); });
 
 populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
 populateReturnOpTypeConversionPattern(patterns, typeConverter);
@@ -18,7 +18,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses
 LINK_LIBS PUBLIC
 MLIRIR
 MLIRPass
-MLIRStandardOpsTransforms
+MLIRFuncTransforms
 TorchMLIRTorchConversionDialect
 TorchMLIRTorchDialect
 TorchMLIRTorchPasses
@@ -8,18 +8,18 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
-#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Conversion/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
 #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
 #include "torch-mlir/Conversion/TorchToStd/TorchToStd.h"
-#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
 #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
+#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 
 using namespace mlir;
@@ -10,9 +10,9 @@
 #include "PassDetail.h"
 
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Transforms/DialectConversion.h"

@@ -60,12 +60,13 @@ class VerifyLinalgOnTensorsBackendContractPass
 ConversionTarget target(*context);
 
 // Structural operations.
-target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
+target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
+opHasLegalTypes);
 
 target.addDynamicallyLegalOp<GetNextSeedOp>(opHasLegalTypes);
 
 // Basic scalar operations.
-target.addDynamicallyLegalDialect<StandardOpsDialect>(isLegalScalarOp);
+target.addDynamicallyLegalDialect<func::FuncDialect>(isLegalScalarOp);
 target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
 target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
 isLegalScalarOp);
@@ -9,7 +9,7 @@
 
 #include "PassDetail.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/BuiltinOps.h"

@@ -39,7 +39,8 @@ class VerifyTosaBackendContractPass
 ConversionTarget target(*context);
 
 // Structural operations.
-target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
+target.addDynamicallyLegalOp<ModuleOp, FuncOp, func::ReturnOp>(
+opHasLegalTypes);
 // Basic scalar operations.
 target.addLegalDialect<tosa::TosaDialect>();
 target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);
@@ -16,11 +16,12 @@
 
 #include "PassDetail.h"
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Approximation.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

@@ -104,12 +105,12 @@ static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) {
 
 // Replace the original returnOp with a call to consumeFuncReturnFunc and add
 // the op to the `toErase` vector.
-static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName,
-TypeRange retTypes,
+static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op,
+StringRef funcName, TypeRange retTypes,
 SmallVectorImpl<Value> &vals,
 SmallVectorImpl<Operation *> &toErase) {
-b.create<mlir::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
-b.create<mlir::ReturnOp>(op.getLoc());
+b.create<mlir::func::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
+b.create<mlir::func::ReturnOp>(op.getLoc());
 toErase.push_back(op);
 }
 

@@ -147,7 +148,7 @@ static LogicalResult mungeFunction(
 
 SmallVector<Operation *> toErase;
 bool isSupported = true;
-func.walk([&](ReturnOp op) {
+func.walk([&](func::ReturnOp op) {
 auto types = op.getOperandTypes();
 b.setInsertionPoint(op);
 // Memref Types.

@@ -341,7 +342,7 @@ class ExpandOpsForLLVM : public ExpandOpsForLLVMBase<ExpandOpsForLLVM> {
 populateExpandTanhPattern(patterns);
 patterns.add<math::ErfPolynomialApproximation>(patterns.getContext());
 ConversionTarget target(*context);
-target.addLegalDialect<StandardOpsDialect>();
+target.addLegalDialect<func::FuncDialect>();
 target.addLegalDialect<math::MathDialect>();
 target.addLegalDialect<arith::ArithmeticDialect>();
 target.addIllegalOp<math::TanhOp>();

@@ -398,10 +399,6 @@ class MemrefCopyOpToLinalg : public OpRewritePattern<memref::CopyOp> {
 };
 
 class MungeMemrefCopy : public MungeMemrefCopyBase<MungeMemrefCopy> {
-void getDependentDialects(DialectRegistry &registry) const override {
-registry.insert<linalg::LinalgDialect>();
-}
-
 void runOnOperation() override {
 MLIRContext *context = &getContext();
 RewritePatternSet patterns(&getContext());

@@ -418,3 +415,27 @@ std::unique_ptr<OperationPass<FuncOp>>
 mlir::torch::RefBackend::createMungeMemrefCopyPass() {
 return std::make_unique<MungeMemrefCopy>();
 }
+
+namespace {
+class GeneralizeTensorPad
+: public GeneralizeTensorPadBase<GeneralizeTensorPad> {
+void getDependentDialects(DialectRegistry &registry) const override {
+registry.insert<linalg::LinalgDialect>();
+}
+
+void runOnOperation() override {
+MLIRContext *context = &getContext();
+RewritePatternSet patterns(&getContext());
+patterns.insert<linalg::GeneralizePadOpPattern>(context);
+if (failed(applyPatternsAndFoldGreedily(getOperation(),
+std::move(patterns)))) {
+return signalPassFailure();
+}
+}
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::torch::RefBackend::createGeneralizeTensorPadPass() {
+return std::make_unique<GeneralizeTensorPad>();
+}
@@ -60,7 +60,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
 auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
 MlirBlock appendToBlock) {
 createMlirOperationAtEnd(
-appendToBlock, "std.return", loc,
+appendToBlock, "func.return", loc,
 derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
 };
 MlirBlock block = importBlock(

@@ -178,7 +178,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
 torch::jit::Function *function = functionType->function();
 const std::string &symName = function->qualname().qualifiedName();
 op = createMlirOperation(
-"std.constant", loc,
+"func.constant", loc,
 getFunctionTypeFromSchema(context, function->getSchema()),
 toMlirNamedAttribute(
 "value",

@@ -280,7 +280,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
 return getMlirTypeFromTorchType(loc, v->type());
 });
 MlirOperation operation = createMlirOperationAtEnd(
-appendToBlock, "std.call_indirect", loc,
+appendToBlock, "func.call_indirect", loc,
 getMlirTypesFromValues(loc, node->outputs()),
 lookupMappedValue(node->input(0)),
 derefineValues(lookupMappedValues(node->inputs().slice(1)),
@@ -164,11 +164,11 @@ class RefBackendInvoker:
 
 
 LOWERING_PIPELINE = ",".join([
+"builtin.func(refback-generalize-tensor-pad)",
 # Bufferize.
 "builtin.func(scf-bufferize)",
 "builtin.func(tm-tensor-bufferize)",
 "builtin.func(linalg-bufferize)",
-"builtin.func(refback-munge-memref-copy)",
 "func-bufferize",
 "arith-bufferize",
 "builtin.func(tensor-bufferize)",

@@ -185,6 +185,7 @@ LOWERING_PIPELINE = ",".join([
 "refback-insert-rng-globals",
 # Lower to LLVM
 "builtin.func(tm-tensor-to-loops)",
+"builtin.func(refback-munge-memref-copy)",
 "builtin.func(convert-linalg-to-loops)",
 "builtin.func(lower-affine)",
 "convert-scf-to-cf",

@@ -194,7 +195,7 @@ LOWERING_PIPELINE = ",".join([
 "convert-linalg-to-llvm",
 "convert-memref-to-llvm",
 "builtin.func(convert-arith-to-llvm)",
-"convert-std-to-llvm",
+"convert-func-to-llvm",
 "convert-cf-to-llvm",
 "reconcile-unrealized-casts",
 ])
@@ -17,7 +17,7 @@
 // CHECK: assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm"
 // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[LHS_DIM_0]], %[[RHS_DIM_1]]] : tensor<?x?xf32>
 // CHECK: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[CF0]], %[[INIT_TENSOR]]) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+// CHECK: %[[ZEROFILL:.*]] = linalg.fill ins(%[[CF0]] : f32) outs(%[[INIT_TENSOR]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
 // CHECK: %[[RESULT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32>

@@ -91,7 +91,7 @@ func @torch.aten.Int.Tensor$zero_rank(%arg0: !torch.vtensor<[],si64>) -> !torch.
 func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
 %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[?,?],si64> -> !torch.int
 return %0 : !torch.int
-}
+}
 
 // -----
 

@@ -104,7 +104,7 @@ func @torch.aten.Int.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],si64>) ->
 func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch.float {
 %0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[],f64> -> !torch.float
 return %0 : !torch.float
-}
+}
 
 // -----
 

@@ -129,14 +129,14 @@ func @torch.aten.Float.Tensor$zero_rank(%arg0: !torch.vtensor<[],f64>) -> !torch
 func @torch.aten.Float.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.float {
 %0 = torch.aten.Float.Tensor %arg0 : !torch.vtensor<[?,?],f64> -> !torch.float
 return %0 : !torch.float
-}
+}
 
 // -----
 
 // CHECK-LABEL: func @torch.aten.Bool.Tensor$zero_rank
 // CHECK-SAME: (%[[ARG:.*]]: !torch.vtensor<[],i1>) -> !torch.bool {
 // CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],i1> -> tensor<i1>
-// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
+// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i1>
 // CHECK: %[[RES:.*]] = torch_c.from_i1 %[[EXT]]
 // CHECK: return %[[RES]] : !torch.bool
 func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.bool {

@@ -167,14 +167,14 @@ func @torch.aten.Bool.Tensor$zero_rank(%arg0: !torch.vtensor<[],i1>) -> !torch.b
 func @torch.aten.Bool.Tensor$non_zero_rank(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.bool {
 %0 = torch.aten.Bool.Tensor %arg0 : !torch.vtensor<[?,?],i1> -> !torch.bool
 return %0 : !torch.bool
-}
+}
 
 // -----
 
 // CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
 // CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
 // CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
-// CHECK: %[[FILLVEC:.*]] = linalg.fill(%[[INI64]], %[[NEWVEC]]) : i64, tensor<i64> -> tensor<i64>
+// CHECK: %[[FILLVEC:.*]] = linalg.fill ins(%[[INI64]] : i64) outs(%[[NEWVEC]] : tensor<i64>) -> tensor<i64>
 // CHECK: %[[OUTVEC:.*]] = torch_c.from_builtin_tensor %[[FILLVEC]] : tensor<i64> -> !torch.vtensor<[],si64>
 // CHECK: return %[[OUTVEC]] : !torch.vtensor<[],si64>
 func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
@@ -13,7 +13,7 @@ builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?
 %false = torch.constant.bool false
 // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
 // CHECK: %[[NEUTRAL:.*]] = arith.constant -1.401300e-45 : f32
-// CHECK: %[[OUT:.*]] = linalg.fill(%[[NEUTRAL]], %{{.*}}) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[INIT:.*]] = linalg.init_tensor [%[[C1]], %[[C2]]] : tensor<?x?xf32>
@@ -25,7 +25,7 @@ torch.class_type @c {
 }
 %c0 = torch.constant.int 0
 %0 = torch.nn_module {
-// expected-error @+1 {{'torch.slot' op is expected to match type and name of 'torch.attr "g" : !torch.int'}}
+// expected-error @+1 {{'torch.slot' op is expected to match type and name of '"torch.attr"() {name = "g", type = !torch.int} : () -> ()}}
 torch.slot "f", %c0 : !torch.int
 } : !torch.nn.Module<"c">
 
@@ -138,7 +138,7 @@ builtin.func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
 
 builtin.func @torch.tensor() {
 // Incompatible shape.
-// expected-error@+1 {{incompatible}}
+// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
 %0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
 return
 }

@@ -147,7 +147,7 @@ builtin.func @torch.tensor() {
 
 builtin.func @torch.tensor() {
 // Incompatible dtype.
-// expected-error@+1 {{incompatible}}
+// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
 %0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
 return
 }

@@ -156,7 +156,7 @@ builtin.func @torch.tensor() {
 
 builtin.func @torch.tensor() {
 // Incompatible type.
-// expected-error@+1 {{incompatible}}
+// expected-error@+1 {{must be Multi-dimensional array modeling Torch's Tensor type, but got}}
 %0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1
 return
 }
@@ -19,7 +19,7 @@ func private @test_call_method(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float
 // CHECK-LABEL: func private @test_call_indirect(
 // CHECK-SAME: %[[RECEIVER:.*]]: !torch.nn.Module<"c">,
 // CHECK-SAME: %[[F:.*]]: !torch.float) -> !torch.float {
-// Ensure no std.constant.
+// Ensure no func.constant.
 // CHECK-NEXT: %[[VAL_2:.*]] = call @test_call_method(%[[RECEIVER]], %[[F]]) : (!torch.nn.Module<"c">, !torch.float) -> !torch.float
 // CHECK-NEXT: return %[[VAL_2]] : !torch.float
 func private @test_call_indirect(%arg0: !torch.nn.Module<"c">, %arg1: !torch.float) -> !torch.float {
@@ -12,7 +12,7 @@ func @mm(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
 %4 = arith.cmpi eq, %1, %2 : index
 cf.assert %4, "mismatching contracting dimension for aten.mm"
 %5 = linalg.init_tensor [%0, %3] : tensor<?x?xf32>
-%6 = linalg.fill(%cst, %5) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
 %7 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%6 : tensor<?x?xf32>) -> tensor<?x?xf32>
 return %7 : tensor<?x?xf32>
 }

@@ -47,7 +47,7 @@ module {
 // expected-error@+1 {{Module does not conform to the linalg-on-tensors backend contract.}}
 module {
 func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
-// expected-error@+1 {{failed to legalize operation 'std.return'}}
+// expected-error@+1 {{failed to legalize operation 'func.return'}}
 return %arg0 : !torch.tensor
 }
 }
@@ -36,7 +36,7 @@ module {
 // expected-error@+1 {{Module does not conform to the TOSA backend contract.}}
 module {
 func @disallowed(%arg0: !torch.tensor) -> !torch.tensor {
-// expected-error@+1 {{failed to legalize operation 'std.return'}}
+// expected-error@+1 {{failed to legalize operation 'func.return'}}
 return %arg0 : !torch.tensor
 }
 }
@@ -9,7 +9,7 @@
 
 #include "mlir/InitAllDialects.h"
 #include "mlir/InitAllPasses.h"
-#include "mlir/Support/MlirOptMain.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
 #include "torch-mlir/InitAll.h"
 
 using namespace mlir;