diff --git a/e2e_testing/torchscript/backprop.py b/e2e_testing/torchscript/backprop.py index 2ed72f22f..6488ade08 100644 --- a/e2e_testing/torchscript/backprop.py +++ b/e2e_testing/torchscript/backprop.py @@ -11,11 +11,11 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export # ============================================================================== + class SoftmaxBackwardModule(torch.nn.Module): def __init__(self): super().__init__() - @export @annotate_args([ None, @@ -33,6 +33,8 @@ class SoftmaxBackwardModule(torch.nn.Module): def SoftmaxBackwardModule_basic(module, tu: TestUtils): module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4)) + +# ============================================================================== class TanhBackwardModule(torch.nn.Module): def __init__(self): super().__init__() @@ -43,10 +45,11 @@ class TanhBackwardModule(torch.nn.Module): ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True), ]) + def forward(self, grad_out, output): + return torch.ops.aten.tanh_backward(grad_out, output) - def forward(self, out_grad, output): - return torch.ops.aten.tanh_backward(out_grad, output) @register_test_case(module_factory=lambda: TanhBackwardModule()) def TanhBackward_basic(module, tu: TestUtils): module.forward(torch.randn(3, 3), torch.randn(3, 3)) + diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index 5fabb4b23..193633086 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -543,7 +543,7 @@ class TensorToInt(torch.nn.Module): @register_test_case(module_factory=lambda: TensorToInt()) def TensorToInt_basic(module, tu: TestUtils): module.forward(torch.randint(10,[])) - + class LogSoftmaxIntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -577,3 +577,24 @@ class NumToTensorModule(torch.nn.Module): @register_test_case(module_factory=lambda: NumToTensorModule()) def NumToTensorModule_basic(module, tu: TestUtils): 
module.forward() + + +# This test can be removed once we have one real op returning 3 float32 tensors +class ReturnThreeTensorFloat32(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b, c): + return a, b, c + +@register_test_case(module_factory=lambda: ReturnThreeTensorFloat32()) +def ReturnThreeTensorFloat32_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3)) + diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 82ebb78a7..2bf9a1eaf 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -29,4 +29,5 @@ TOSA_PASS_SET = { "ElementwiseFloorModule_basic", "ElementwiseLogModule_basic", "TanhBackward_basic", + "ReturnThreeTensorFloat32_basic", } diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td index 3720083f1..d7dc8d80a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td @@ -43,6 +43,7 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [ AnyTorchType:$result ); let assemblyFormat = "$tup `,` $i attr-dict `:` type($tup) `,` type($i) `->` type($result)"; + let hasCanonicalizer = 1; } def Torch_PrimDeviceOp : Torch_Op<"prim.device", [ diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 9d973b1d3..6656f5969 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -121,8 +121,8 @@ def AdjustCallingConventions function arguments, which should be `!numpy.ndarray<...>`'s. 
- Python-isms are rewritten to MLIR-isms - NoneType return is rewritten to the absence of a return value. - - (Not implemented yet) Tuple return is rewritten to multiple return - values + - Tuple return is rewritten to multiple return values. + }]; } @@ -219,14 +219,14 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> { let summary = "Decompose complicated torch operations"; let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()"; let description = [{ - Decompose torch operation that are losslessly represented as combinations of - other operations, modulo appropropriate compiler fusion. Note that this pass - is similar in spirit to ReduceOpVariants, but ReduceOpVariants is about - systematic reductions of a large number of ops at once, guided mostly by + Decompose torch operations that are losslessly represented as combinations of + other operations, modulo appropriate compiler fusion. Note that this pass + is similar in spirit to ReduceOpVariants, but ReduceOpVariants is about + systematic reductions of a large number of ops at once, guided mostly by traits. 
An example of the transformations done in this pass is: - - convert aten.softmax to softmax(x, dim) + - convert aten.softmax to softmax(x, dim) => tmp=exp(x); tmp / sum(tmp, dim, keepdim=True) }]; } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 3d04de4b2..fdfc46625 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -916,6 +916,29 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// PrimTupleIndexOp +//===----------------------------------------------------------------------===// + +void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) { + auto tupleConstruct = op.tup().getDefiningOp(); + if (!tupleConstruct) + return failure(); + + int64_t i; + if (!matchPattern(op.i(), m_TorchConstantInt(&i))) + return failure(); + + if (i >= (int64_t)tupleConstruct.elements().size()) + return failure(); + + rewriter.replaceOp(op, tupleConstruct.elements()[i]); + return success(); + }); +} + //===----------------------------------------------------------------------===// // PrimTupleUnpackOp //===----------------------------------------------------------------------===// @@ -923,9 +946,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) { - auto torchTuple = op.tup(); - auto tupleConstruct = - torchTuple.getDefiningOp(); + auto tupleConstruct = op.tup().getDefiningOp(); if (!tupleConstruct) return failure(); diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 6a2f79ddc..bb11cd5a4 100644 --- 
a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -66,15 +66,20 @@ public: // TODO: add tuple type. conversion.addInputs(type.index(), type.value()); } + rewriter.applySignatureConversion(&func.getBody(), conversion, + typeConverter); + SmallVector newResultTypes; for (auto type : func.getType().getResults()) { if (auto none = type.dyn_cast()) { continue; } + if (auto tuple = type.dyn_cast()) { + llvm::append_range(newResultTypes, tuple.getContainedTypes()); + continue; + } newResultTypes.push_back(type); } - rewriter.applySignatureConversion(&func.getBody(), conversion, - typeConverter); rewriter.updateRootInPlace(func, [&] { func.setType(FunctionType::get( getContext(), conversion.getConvertedTypes(), newResultTypes)); @@ -131,6 +136,11 @@ public: rewriter.create(call.getLoc(), type)); continue; } + if (type.isa()) { + newResults.push_back(rewriter.create( + call.getLoc(), type, newCall.getResults())); + continue; + } newResults.push_back(newCall.getResult(newOpResultIdx++)); } rewriter.replaceOp(call, newResults); @@ -151,12 +161,22 @@ public: ConversionPatternRewriter &rewriter) const override { SmallVector newOperands; - for (auto operand : llvm::enumerate(adaptor.getOperands())) { - if (!operand.value()) + for (auto operand : adaptor.getOperands()) { + if (!operand) continue; - if (operand.value().getType().isa()) + if (operand.getType().isa()) continue; - newOperands.push_back(operand.value()); + if (auto tuple = operand.getType().dyn_cast()) { + Location loc = op.getLoc(); + for (auto en : llvm::enumerate(tuple.getContainedTypes())) { + auto i = rewriter.create( + loc, rewriter.getI64IntegerAttr(en.index())); + newOperands.push_back( + rewriter.create(loc, en.value(), operand, i)); + } + continue; + } + newOperands.push_back(operand); } rewriter.replaceOpWithNewOp(op, newOperands); return success(); @@ -168,9 +188,14 @@ static LogicalResult adjustCallingConventions(FuncOp func, 
TypeBoundMap &typeBoundMap) { MLIRContext *context = func.getContext(); RewritePatternSet patterns(context); - // TODO: TupleTypes TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [](Torch::TupleType type, + SmallVectorImpl &types) -> Optional { + llvm::append_range(types, type.getContainedTypes()); + return success(); + }); typeConverter.addConversion( [](Torch::NoneType type, SmallVectorImpl &types) -> Optional { @@ -220,6 +245,9 @@ static LogicalResult adjustCallingConventions(FuncOp func, target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); // We don't know how to rewrite it, so mark it as illegal. target.addIllegalOp(); if (failed(applyPartialConversion(func.getOperation(), target, diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 164355df4..00b951b36 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -89,6 +89,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( pm.addPass(createAdjustCallingConventionsPass()); if (options.optimize) { + // Eliminate the PrimTupleIndexOp generated from the + // adjustCallingConventions + pm.addNestedPass(createCanonicalizerPass()); // Inline global slots, which for most inference scenarios deletes them. // This also exposes more information to intraprocedural transformations // below like MaximizeValueSemantics and RefineTypes. 
diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index fada23f51..734116e60 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -21,7 +21,9 @@ #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Transforms/DialectConversion.h" +#include "set" #include "torch-mlir/RefBackend/Passes.h" +#include using namespace mlir; using namespace mlir::torch; @@ -67,33 +69,47 @@ static Type getAbiTypeForMemRef(Type type) { return UnrankedMemRefType::get(type.cast().getElementType(), 0); } -// Passes the return op operands `val` to `funOp`. Also, adds the op to the -// `toErase` vector. -static void replaceCallToFunction(OpBuilder b, ReturnOp op, FuncOp funcOp, - Value val, +// Helper function to get the type string for one return value like i32, f64, +// mri32 etc. The strings from multiple return values are concatenated to get +// the consumeFuncReturnFunc name. +static std::string getTypeToken(Type type) { + if (type.isSignlessInteger()) + return ("i" + Twine(type.getIntOrFloatBitWidth())).str(); + else if (type.isa()) + return ("f" + Twine(type.getIntOrFloatBitWidth())).str(); + else if (auto memRefType = type.dyn_cast()) + return "mr" + getTypeToken(memRefType.getElementType()); + + llvm_unreachable( + "Type token should handle all types: memref, float and int type"); +} + +// Systematically derive the consumeFuncReturnFunc name from return value types. +static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) { + SmallVector tokens = {"refbackend_consume_func_return"}; + for (auto type : types) + tokens.push_back(getTypeToken(type)); + + return std::accumulate(tokens.begin(), tokens.end(), std::string(), + [](std::string &a, std::string &b) { + return a.empty() ? b : (a + "_" + b); + }); +} + +// Replace the original returnOp with a call to consumeFuncReturnFunc and add +// the op to the `toErase` vector. 
+static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName, + TypeRange retTypes, + SmallVectorImpl &vals, SmallVectorImpl &toErase) { - b.create(op.getLoc(), funcOp, val); + b.create(op.getLoc(), funcName, TypeRange({}), vals); b.create(op.getLoc()); toErase.push_back(op); } -// Checks whether the return op is munge-compatible and the respective calling -// function is defined. -static bool isReturnOpCompatible(ReturnOp op, - DenseMap &consumeFuncReturnFuncs, - Type returnType) { - auto it = consumeFuncReturnFuncs.find(returnType); - if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) { - op.emitError("must have one return value of Memref type or Elemental types " - "of i64, f64, f32"); - return false; - } - return true; -} - static LogicalResult mungeFunction( - FuncOp func, - DenseMap consumeFuncReturnFuncs) { + FuncOp func, std::set &supportedConsumeFuncReturnFuncs, + std::map> &invokedConsumeFuncReturnFuncs) { // Add `llvm.emit_c_interface`. // This allows ExecutionEngine to resolve the symbol properly. addEmitCInterfaceAttr(func); @@ -120,37 +136,43 @@ static LogicalResult mungeFunction( } SmallVector toErase; - bool isCompatible = false; + bool isSupported = true; func.walk([&](ReturnOp op) { - auto returnType = op.getOperandTypes()[0]; - + auto types = op.getOperandTypes(); b.setInsertionPoint(op); // Memref Types. - if (auto memrefReturnType = returnType.dyn_cast()) { - auto elemType = memrefReturnType.getElementType(); - auto unRankedType = UnrankedMemRefType::get(elemType, 0); - isCompatible = - isReturnOpCompatible(op, consumeFuncReturnFuncs, unRankedType); - if (!isCompatible) - return; - - // Cast to unranked memref type before sending it as a function argument. - auto cast = b.create( - op.getLoc(), op.getOperand(0), - getAbiTypeForMemRef(op.getOperandTypes()[0])); - replaceCallToFunction(b, op, consumeFuncReturnFuncs[unRankedType], - cast.getResult(), toErase); - // Elemental types. 
- } else if (returnType.isa() || returnType.isa()) { - isCompatible = - isReturnOpCompatible(op, consumeFuncReturnFuncs, returnType); - if (!isCompatible) - return; - replaceCallToFunction(b, op, consumeFuncReturnFuncs[returnType], - op->getOperand(0), toErase); + std::vector retTypes; + SmallVector retVals; + for (auto en : llvm::enumerate(types)) { + Type retType = en.value(); + Value retVal = op.getOperand(en.index()); + if (auto memrefReturnType = retType.dyn_cast()) { + auto elemType = memrefReturnType.getElementType(); + retType = UnrankedMemRefType::get(elemType, 0); + // Cast to unranked memref type before sending it as a function + // argument. + retVal = b.create( + op.getLoc(), retVal, getAbiTypeForMemRef(types[en.index()])); + } + retTypes.push_back(retType); + retVals.push_back(retVal); } + + auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end(); + std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes); + if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) { + op.emitError( + "must have one return value of memref types or scalar types " + "of i32, i64, f32, f64 or three return values of memref f32"); + isSupported = false; + } + + auto invokedFuncsEnd = invokedConsumeFuncReturnFuncs.end(); + if (invokedConsumeFuncReturnFuncs.find(funcName) == invokedFuncsEnd) + invokedConsumeFuncReturnFuncs.insert({funcName, retTypes}); + replaceReturnWithCall(b, op, funcName, retTypes, retVals, toErase); }); - if (!isCompatible) + if (!isSupported) return failure(); func.setType(FunctionType::get(func.getContext(), newArgTypes, {})); for (Operation *op : toErase) @@ -158,50 +180,47 @@ static LogicalResult mungeFunction( return success(); } +static std::set getSupportedConsumeFuncReturnFuncs(OpBuilder &b) { + std::set funcNames; + Type mri32 = UnrankedMemRefType::get(b.getI32Type(), 0); + Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0); + Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0); + Type mrf64 = 
UnrankedMemRefType::get(b.getF64Type(), 0); + Type i64 = b.getI64Type(); + Type f32 = b.getF32Type(); + Type f64 = b.getF64Type(); + + SmallVector supportedReturnTypes = { + mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}}; + + llvm::for_each(supportedReturnTypes, [&](TypeRange &types) { + funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types)); + }); + return funcNames; +} + namespace { class MungeCallingConventions : public MungeCallingConventionsBase { void runOnOperation() override { auto module = getOperation(); OpBuilder b(module.getBodyRegion()); - DenseMap consumeFuncReturnFuncs; - DenseSet consumeFuncReturnFuncsSet; - auto createConsumeFuncReturnFunc = [&](Type returnType, - std::string funcName) { - auto consumeFuncReturnFunc = b.create( - module.getLoc(), funcName, - FunctionType::get(module.getContext(), returnType, {}), - b.getStringAttr("private")); - addEmitCInterfaceAttr(consumeFuncReturnFunc); - consumeFuncReturnFuncs[returnType] = consumeFuncReturnFunc; - consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc); - }; - - // Memref return types. - createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0), - "refbackend_consume_memref_int32_func_return"); - createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0), - "refbackend_consume_memref_int64_func_return"); - createConsumeFuncReturnFunc( - UnrankedMemRefType::get(b.getF32Type(), 0), - "refbackend_consume_memref_float32_func_return"); - createConsumeFuncReturnFunc( - UnrankedMemRefType::get(b.getF64Type(), 0), - "refbackend_consume_memref_float64_func_return"); - - // Elemental return types. 
- createConsumeFuncReturnFunc(b.getI64Type(), - "refbackend_consume_int64_func_return"); - createConsumeFuncReturnFunc(b.getF32Type(), - "refbackend_consume_float32_func_return"); - createConsumeFuncReturnFunc(b.getF64Type(), - "refbackend_consume_float64_func_return"); + static std::set supported = + getSupportedConsumeFuncReturnFuncs(b); + std::map> invokedConsumeFuncReturnFuncs; for (auto func : module.getOps()) { - if (consumeFuncReturnFuncsSet.contains(func)) - continue; - if (failed(mungeFunction(func, consumeFuncReturnFuncs))) + if (failed(mungeFunction(func, supported, invokedConsumeFuncReturnFuncs))) return signalPassFailure(); } + + // Create FuncOp for consumeFuncReturnFuncs that are used. + for (auto &p : invokedConsumeFuncReturnFuncs) { + auto consumeFuncReturnFunc = + b.create(module.getLoc(), p.first, + FunctionType::get(module.getContext(), p.second, {}), + b.getStringAttr("private")); + addEmitCInterfaceAttr(consumeFuncReturnFunc); + } } }; } // namespace diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 7ece933f0..f232e4c5f 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -399,7 +399,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry): emit_op(registry[key], f, **kwargs) emit("prim::layout : (Tensor) -> (int)") - emit("prim::TupleIndex : (Any, int) -> (Any)") + emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True) emit("prim::device : (Tensor) -> (Device)") emit("prim::dtype : (Tensor) -> (int)", has_folder=True) emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True) diff --git a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 0409afc4f..c35fe55f4 100644 --- 
a/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -27,59 +27,73 @@ def checkArgTypeIsSupported(ty): SUPPORTED = [np.float32, np.float64, np.int32, np.int64] assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported" + class RefBackendInvoker: def __init__(self, module): self.ee = ExecutionEngine(module) self.result = None @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) - def consume_memref_i32_return(a): + def consume_return_mri32(a): self.result = unranked_memref_to_numpy(a, np.int32) @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) - def consume_memref_i64_return(a): + def consume_return_mri64(a): self.result = unranked_memref_to_numpy(a, np.int64) @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) - def consume_memref_f32_return(a): + def consume_return_mrf32(a): self.result = unranked_memref_to_numpy(a, np.float32) @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor)) - def consume_memref_f64_return(a): + def consume_return_mrf64(a): self.result = unranked_memref_to_numpy(a, np.float64) @ctypes.CFUNCTYPE(None, ctypes.c_int) - def consume_i64_return(a): + def consume_return_i64(a): self.result = a @ctypes.CFUNCTYPE(None, ctypes.c_float) - def consume_f32_return(a): + def consume_return_f32(a): self.result = a @ctypes.CFUNCTYPE(None, ctypes.c_double) - def consume_f64_return(a): + def consume_return_f64(a): self.result = a - self.ee.register_runtime("refbackend_consume_memref_int32_func_return", - consume_memref_i32_return) + @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor), + ctypes.POINTER(UnrankedMemRefDescriptor), + ctypes.POINTER(UnrankedMemRefDescriptor)) + def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2): + self.result = unranked_memref_to_numpy( + arg0, np.float32), unranked_memref_to_numpy( + arg1, + np.float32), unranked_memref_to_numpy(arg2, 
np.float32) - self.ee.register_runtime("refbackend_consume_memref_int64_func_return", - consume_memref_i64_return) + self.ee.register_runtime("refbackend_consume_func_return_mri32", + consume_return_mri32) - self.ee.register_runtime("refbackend_consume_memref_float32_func_return", - consume_memref_f32_return) + self.ee.register_runtime("refbackend_consume_func_return_mri64", + consume_return_mri64) - self.ee.register_runtime("refbackend_consume_memref_float64_func_return", - consume_memref_f64_return) + self.ee.register_runtime("refbackend_consume_func_return_mrf32", + consume_return_mrf32) - self.ee.register_runtime("refbackend_consume_int64_func_return", - consume_i64_return) + self.ee.register_runtime("refbackend_consume_func_return_mrf64", + consume_return_mrf64) - self.ee.register_runtime("refbackend_consume_float32_func_return", - consume_f32_return) + self.ee.register_runtime("refbackend_consume_func_return_i64", + consume_return_i64) - self.ee.register_runtime("refbackend_consume_float64_func_return", - consume_f64_return) + self.ee.register_runtime("refbackend_consume_func_return_f32", + consume_return_f32) + + self.ee.register_runtime("refbackend_consume_func_return_f64", + consume_return_f64) + + self.ee.register_runtime( + "refbackend_consume_func_return_mrf32_mrf32_mrf32", + consume_return_mrf32_mrf32_mrf32) def __getattr__(self, function_name: str): def invoke(*args): diff --git a/test/Dialect/Torch/adjust-calling-conventions.mlir b/test/Dialect/Torch/adjust-calling-conventions.mlir index 9a8152e8b..ae38b5a34 100644 --- a/test/Dialect/Torch/adjust-calling-conventions.mlir +++ b/test/Dialect/Torch/adjust-calling-conventions.mlir @@ -47,3 +47,53 @@ func @none_call_return() { "test.use"(%0) : (!torch.none) -> () return } + +// CHECK-LABEL: func @tuple_return( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { +// CHECK: %[[ARG0_ERASED:.*]] = 
torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor +// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor +// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor +// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor +// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : +// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple +// CHECK: %[[CST0:.*]] = torch.constant.int 0 +// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : +// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : +// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor +// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor +func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, + %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { + %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple + return %1 : !torch.tuple +} + +// CHECK-LABEL: func @call_tuple_return( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { +// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor +// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor +// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor +// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor +// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> +// CHECK: 
%[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> +// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> +// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> +// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : +// CHECK-SAME: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) +// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : +// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple +// CHECK: %[[CST0:.*]] = torch.constant.int 0 +// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : +// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : +// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor +// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor +func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, + %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { + %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple + return %0 : !torch.tuple +} diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 31d61f223..57e97bea1 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -568,3 +568,29 @@ func @torch.tensor_static_info_cast$upcast_first(%t: !torch.tensor<[?,?],f64>) - %downcast = torch.tensor_static_info_cast %upcast : !torch.tensor to !torch.tensor<[?,?],f64> return %downcast: !torch.tensor<[?,?],f64> } + +// CHECK-LABEL: func @torch.prim.TupleIndex( +// CHECK-SAME: %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> 
!torch.tensor { +// CHECK: return %[[T1]] : !torch.tensor +func @torch.prim.TupleIndex(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor { + %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple + %int1 = torch.constant.int 1 + %1 = torch.prim.TupleIndex %0, %int1 : !torch.tuple, !torch.int -> !torch.tensor + return %1 : !torch.tensor +} + +// CHECK-LABEL: func @torch.prim.TupleIndex$out_of_bound( +// CHECK-SAME: %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor { +// CHECK: %[[INDEX3:.*]] = torch.constant.int 3 +// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]], %[[T2]] : +// CHECK-SAME: !torch.tensor, !torch.tensor, !torch.tensor -> +// CHECK-SAME: !torch.tuple +// CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INDEX3]] : +// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor +// CHECK: return %[[RET]] : !torch.tensor +func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor { + %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple + %int3 = torch.constant.int 3 + %1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple, !torch.int -> !torch.tensor + return %1 : !torch.tensor +} diff --git a/test/Dialect/Torch/refine-public-return.mlir b/test/Dialect/Torch/refine-public-return.mlir index b4f5ca5f9..698b6f291 100644 --- a/test/Dialect/Torch/refine-public-return.mlir +++ b/test/Dialect/Torch/refine-public-return.mlir @@ -23,8 +23,6 @@ func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor return %2 : !torch.tensor } - - // ----- // Call to public function. 
diff --git a/test/RefBackend/munge-calling-conventions.mlir b/test/RefBackend/munge-calling-conventions.mlir index 30dc8cb72..d37094d27 100644 --- a/test/RefBackend/munge-calling-conventions.mlir +++ b/test/RefBackend/munge-calling-conventions.mlir @@ -4,7 +4,7 @@ // CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} { // CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref // CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref to memref<*xf32> -// CHECK: call @refbackend_consume_memref_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> () +// CHECK: call @refbackend_consume_func_return_mrf32(%[[RESULT]]) : (memref<*xf32>) -> () // CHECK: return func @f(%arg0: memref) -> memref { return %arg0 : memref @@ -16,7 +16,7 @@ func @f(%arg0: memref) -> memref { // CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} { // CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref // CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref to memref<*xi64> -// CHECK: call @refbackend_consume_memref_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> () +// CHECK: call @refbackend_consume_func_return_mri64(%[[RESULT]]) : (memref<*xi64>) -> () // CHECK: return func @i(%arg0: memref) -> memref { return %arg0 : memref @@ -28,9 +28,28 @@ func @i(%arg0: memref) -> memref { // CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} { // CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref // CHECK: %[[RESULT:.*]] = memref.load %[[VAL]][] : memref -// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (i64) -> () +// CHECK: call @refbackend_consume_func_return_i64(%[[RESULT]]) : (i64) -> () // CHECK: return func @elemental_type(%arg0: memref) -> i64 { %0 = memref.load %arg0[] : memref return %0 : i64 } + +// ----- + +// CHECK-LABEL: func @multiple_return_values( +// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xf32>, +// CHECK-SAME: 
%[[ARG2:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} { +// CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref +// CHECK: %[[VAL1:.*]] = memref.cast %[[ARG1]] : memref<*xf32> to memref +// CHECK: %[[VAL2:.*]] = memref.cast %[[ARG2]] : memref<*xf32> to memref +// CHECK: %[[RET0:.*]] = memref.cast %[[VAL0]] : memref to memref<*xf32> +// CHECK: %[[RET1:.*]] = memref.cast %[[VAL1]] : memref to memref<*xf32> +// CHECK: %[[RET2:.*]] = memref.cast %[[VAL2]] : memref to memref<*xf32> +// CHECK: call @refbackend_consume_func_return_mrf32_mrf32_mrf32(%[[RET0]], %[[RET1]], %[[RET2]]) +// CHECK-SAME: : (memref<*xf32>, memref<*xf32>, memref<*xf32>) -> () +// CHECK: return + +func @multiple_return_values(%arg0: memref, %arg1: memref, %arg2: memref) -> (memref, memref, memref) { + return %arg0 ,%arg1, %arg2 : memref, memref, memref +}