mirror of https://github.com/llvm/torch-mlir

Add support for multiple return values

This change unblocks the work on backprop ops that return more than one tensor. We will need to think of a more scalable approach in the future if more flexible combinations of return types are needed.

branch: pull/418/head
parent 6e8d39642e
commit 0fe70994e5
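For context, the capability this unblocks is a scripted module whose forward returns more than one tensor. A minimal sketch, not part of this diff (the torch-mlir compile-and-run step is elided; only TorchScript behavior is shown):

```python
import torch

class MultiReturn(torch.nn.Module):
    # A forward returning two tensors. Lowering modules like this previously
    # failed because the calling convention handled only a single return value.
    def forward(self, a, b):
        return a + b, a * b

scripted = torch.jit.script(MultiReturn())
s, p = scripted(torch.rand(2, 3), torch.rand(2, 3))
assert s.shape == (2, 3) and p.shape == (2, 3)
```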
@@ -11,11 +11,11 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 # ==============================================================================


 class SoftmaxBackwardModule(torch.nn.Module):
     def __init__(self):
         super().__init__()


     @export
     @annotate_args([
         None,

@@ -33,6 +33,8 @@ class SoftmaxBackwardModule(torch.nn.Module):
 def SoftmaxBackwardModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))

+# ==============================================================================
+
 class TanhBackwardModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -43,10 +45,11 @@ class TanhBackwardModule(torch.nn.Module):
         ([-1, -1], torch.float32, True),
         ([-1, -1], torch.float32, True),
     ])
-    def forward(self, out_grad, output):
-        return torch.ops.aten.tanh_backward(out_grad, output)
+    def forward(self, grad_out, output):
+        return torch.ops.aten.tanh_backward(grad_out, output)


 @register_test_case(module_factory=lambda: TanhBackwardModule())
 def TanhBackward_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 3), torch.randn(3, 3))
@@ -577,3 +577,24 @@ class NumToTensorModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: NumToTensorModule())
 def NumToTensorModule_basic(module, tu: TestUtils):
     module.forward()
+
+
+# This test can be removed once we have one real op returning 3 float32 tensors
+class ReturnThreeTensorFloat32(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b, c):
+        return a, b, c
+
+
+@register_test_case(module_factory=lambda: ReturnThreeTensorFloat32())
+def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
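The new test is an identity function on three float32 tensors, so the property it locks down is easy to state. A hedged sketch of the expected behavior (the real harness compares a compiled artifact against native PyTorch; `reference_forward` below is a hypothetical stand-in for the module):

```python
import torch

def reference_forward(a, b, c):
    # Eager-mode reference for ReturnThreeTensorFloat32: identity on three tensors.
    return a, b, c

inputs = tuple(torch.rand(2, 3) for _ in range(3))
outputs = reference_forward(*inputs)
# A compiled version must produce the same three tensors, in the same order.
assert all(torch.allclose(x, y) for x, y in zip(outputs, inputs))
```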
@@ -29,4 +29,5 @@ TOSA_PASS_SET = {
     "ElementwiseFloorModule_basic",
     "ElementwiseLogModule_basic",
     "TanhBackward_basic",
+    "ReturnThreeTensorFloat32_basic",
 }
@@ -43,6 +43,7 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
     AnyTorchType:$result
   );
   let assemblyFormat = "$tup `,` $i attr-dict `:` type($tup) `,` type($i) `->` type($result)";
+  let hasCanonicalizer = 1;
 }

 def Torch_PrimDeviceOp : Torch_Op<"prim.device", [
@@ -121,8 +121,8 @@ def AdjustCallingConventions
       function arguments, which should be `!numpy.ndarray<...>`'s.
     - Python-isms are rewritten to MLIR-isms
     - NoneType return is rewritten to the absence of a return value.
-    - (Not implemented yet) Tuple return is rewritten to multiple return
-      values
+    - Tuple return is rewritten to multiple return values.
   }];
 }
@@ -916,6 +916,29 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
   });
 }

+//===----------------------------------------------------------------------===//
+// PrimTupleIndexOp
+//===----------------------------------------------------------------------===//
+
+void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                   MLIRContext *context) {
+  patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
+    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
+    if (!tupleConstruct)
+      return failure();
+
+    int64_t i;
+    if (!matchPattern(op.i(), m_TorchConstantInt(&i)))
+      return failure();
+
+    if (i >= (int64_t)tupleConstruct.elements().size())
+      return failure();
+
+    rewriter.replaceOp(op, tupleConstruct.elements()[i]);
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // PrimTupleUnpackOp
 //===----------------------------------------------------------------------===//
@@ -923,9 +946,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
 void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                     MLIRContext *context) {
   patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
-    auto torchTuple = op.tup();
-    auto tupleConstruct =
-        torchTuple.getDefiningOp<Torch::PrimTupleConstructOp>();
+    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
     if (!tupleConstruct)
       return failure();
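In plain terms, both canonicalizations are "construct-then-project" folds. A rough Python model of the new PrimTupleIndexOp pattern (hypothetical helper; the real pattern matches MLIR ops, not Python values):

```python
from typing import Any, Optional

def fold_tuple_index(tuple_value: Any, index: Any) -> Optional[Any]:
    """Model of the PrimTupleIndexOp canonicalizer: if the tuple operand comes
    directly from a prim.TupleConstruct and the index is a known constant that
    is in bounds, replace the projection with the original element."""
    if not isinstance(tuple_value, tuple):  # operand not a TupleConstruct
        return None                         # pattern fails to match
    if not isinstance(index, int):          # index is not a constant int
        return None
    if index >= len(tuple_value):           # out of bounds: leave the op alone
        return None
    return tuple_value[index]

assert fold_tuple_index((10, 20, 30), 1) == 20
assert fold_tuple_index((10, 20, 30), 3) is None  # mirrors the out-of-bound test added later in this diff
```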
@@ -66,15 +66,20 @@ public:
       // TODO: add tuple type.
       conversion.addInputs(type.index(), type.value());
     }
+    rewriter.applySignatureConversion(&func.getBody(), conversion,
+                                      typeConverter);
+
     SmallVector<Type> newResultTypes;
     for (auto type : func.getType().getResults()) {
       if (auto none = type.dyn_cast<Torch::NoneType>()) {
         continue;
       }
+      if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
+        llvm::append_range(newResultTypes, tuple.getContainedTypes());
+        continue;
+      }
       newResultTypes.push_back(type);
     }
-    rewriter.applySignatureConversion(&func.getBody(), conversion,
-                                      typeConverter);
     rewriter.updateRootInPlace(func, [&] {
       func.setType(FunctionType::get(
           getContext(), conversion.getConvertedTypes(), newResultTypes));
@@ -131,6 +136,11 @@ public:
             rewriter.create<ConstantNoneOp>(call.getLoc(), type));
         continue;
       }
+      if (type.isa<Torch::TupleType>()) {
+        newResults.push_back(rewriter.create<PrimTupleConstructOp>(
+            call.getLoc(), type, newCall.getResults()));
+        continue;
+      }
       newResults.push_back(newCall.getResult(newOpResultIdx++));
     }
     rewriter.replaceOp(call, newResults);
@@ -151,12 +161,22 @@ public:
                   ConversionPatternRewriter &rewriter) const override {

     SmallVector<Value> newOperands;
-    for (auto operand : llvm::enumerate(adaptor.getOperands())) {
-      if (!operand.value())
+    for (auto operand : adaptor.getOperands()) {
+      if (!operand)
         continue;
-      if (operand.value().getType().isa<Torch::NoneType>())
+      if (operand.getType().isa<Torch::NoneType>())
         continue;
-      newOperands.push_back(operand.value());
+      if (auto tuple = operand.getType().dyn_cast<Torch::TupleType>()) {
+        Location loc = op.getLoc();
+        for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
+          auto i = rewriter.create<ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(en.index()));
+          newOperands.push_back(
+              rewriter.create<PrimTupleIndexOp>(loc, en.value(), operand, i));
+        }
+        continue;
+      }
+      newOperands.push_back(operand);
     }
     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
     return success();
@@ -168,9 +188,14 @@ static LogicalResult adjustCallingConventions(FuncOp func,
                                               TypeBoundMap &typeBoundMap) {
   MLIRContext *context = func.getContext();
   RewritePatternSet patterns(context);
-  // TODO: TupleTypes
   TypeConverter typeConverter;
   typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion(
+      [](Torch::TupleType type,
+         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
+        llvm::append_range(types, type.getContainedTypes());
+        return success();
+      });
   typeConverter.addConversion(
       [](Torch::NoneType type,
          SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
@@ -220,6 +245,9 @@ static LogicalResult adjustCallingConventions(FuncOp func,
   target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
   target.addLegalOp<TensorStaticInfoCastOp>();
   target.addLegalOp<ConstantNoneOp>();
+  target.addLegalOp<ConstantIntOp>();
+  target.addLegalOp<PrimTupleIndexOp>();
+  target.addLegalOp<PrimTupleConstructOp>();
   // We don't know how to rewrite it, so mark it as illegal.
   target.addIllegalOp<CallIndirectOp>();
   if (failed(applyPartialConversion(func.getOperation(), target,
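Conceptually, the pass now turns one tuple-typed return into one return value per tuple element (materialized as prim.TupleIndex ops), and call sites rebuild the tuple with prim.TupleConstruct. A rough Python model of the return-side rewrite (hypothetical function; the real code builds ConstantIntOp/PrimTupleIndexOp as shown above):

```python
def expand_return_operands(operands):
    """Model of the ReturnOp pattern: None disappears, a tuple is expanded
    into one operand per contained value, everything else passes through."""
    new_operands = []
    for operand in operands:
        if operand is None:
            continue
        if isinstance(operand, tuple):
            # Stands in for the per-index prim.TupleIndex ops created in the IR.
            new_operands.extend(operand)
            continue
        new_operands.append(operand)
    return new_operands

assert expand_return_operands([None, (1, 2), 3]) == [1, 2, 3]
```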
@@ -89,6 +89,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
   pm.addPass(createAdjustCallingConventionsPass());

   if (options.optimize) {
+    // Eliminate the PrimTupleIndexOp generated from the
+    // adjustCallingConventions
+    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
     // Inline global slots, which for most inference scenarios deletes them.
     // This also exposes more information to intraprocedural transformations
     // below like MaximizeValueSemantics and RefineTypes.
@@ -21,7 +21,9 @@
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "set"
 #include "torch-mlir/RefBackend/Passes.h"
+#include <numeric>

 using namespace mlir;
 using namespace mlir::torch;
@@ -67,33 +69,47 @@ static Type getAbiTypeForMemRef(Type type) {
   return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
 }

-// Passes the return op operands `val` to `funOp`. Also, adds the op to the
-// `toErase` vector.
-static void replaceCallToFunction(OpBuilder b, ReturnOp op, FuncOp funcOp,
-                                  Value val,
+// Helper function to get the type string for one return value like i32, f64,
+// mri32 etc. The strings from multiple return values are concatenated to get
+// the consumeFuncReturnFunc name.
+static std::string getTypeToken(Type type) {
+  if (type.isSignlessInteger())
+    return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
+  else if (type.isa<mlir::FloatType>())
+    return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
+  else if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
+    return "mr" + getTypeToken(memRefType.getElementType());
+
+  llvm_unreachable(
+      "Type token should handle all types: memref, float and int type");
+}
+
+// Systematically derive the consumeFuncReturnFunc name from return value types.
+static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) {
+  SmallVector<std::string> tokens = {"refbackend_consume_func_return"};
+  for (auto type : types)
+    tokens.push_back(getTypeToken(type));
+
+  return std::accumulate(tokens.begin(), tokens.end(), std::string(),
+                         [](std::string &a, std::string &b) {
+                           return a.empty() ? b : (a + "_" + b);
+                         });
+}
+
+// Replace the original returnOp with a call to consumeFuncReturnFunc and add
+// the op to the `toErase` vector.
+static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName,
+                                  TypeRange retTypes,
+                                  SmallVectorImpl<Value> &vals,
                                   SmallVectorImpl<Operation *> &toErase) {
-  b.create<mlir::CallOp>(op.getLoc(), funcOp, val);
+  b.create<mlir::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
   b.create<mlir::ReturnOp>(op.getLoc());
   toErase.push_back(op);
 }

-// Checks whether the return op is munge-compatible and the respective calling
-// function is defined.
-static bool isReturnOpCompatible(ReturnOp op,
-                                 DenseMap<Type, FuncOp> &consumeFuncReturnFuncs,
-                                 Type returnType) {
-  auto it = consumeFuncReturnFuncs.find(returnType);
-  if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
-    op.emitError("must have one return value of Memref type or Elemental types "
-                 "of i64, f64, f32");
-    return false;
-  }
-  return true;
-}
-
 static LogicalResult mungeFunction(
-    FuncOp func,
-    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
+    FuncOp func, std::set<std::string> &supportedConsumeFuncReturnFuncs,
+    std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
   // Add `llvm.emit_c_interface`.
   // This allows ExecutionEngine to resolve the symbol properly.
   addEmitCInterfaceAttr(func);
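The naming scheme is easy to restate in Python. A sketch mirroring getTypeToken and getConsumeReturnFunctionNameForReturnTypes above (types are modeled as plain strings here rather than MLIR Type objects):

```python
def get_type_token(ty: str) -> str:
    # "i32"/"i64"/"f32"/"f64" map to themselves; an unranked memref maps to
    # "mr" plus the token of its element type, e.g. "memref<*xf32>" -> "mrf32".
    prefix = "memref<*x"
    if ty.startswith(prefix):
        return "mr" + get_type_token(ty[len(prefix):-1])
    return ty

def consume_func_name(return_types) -> str:
    return "_".join(["refbackend_consume_func_return"] +
                    [get_type_token(t) for t in return_types])

assert consume_func_name(["memref<*xf32>"] * 3) == \
    "refbackend_consume_func_return_mrf32_mrf32_mrf32"
assert consume_func_name(["i64"]) == "refbackend_consume_func_return_i64"
```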
@@ -120,37 +136,43 @@ static LogicalResult mungeFunction(
   }

   SmallVector<Operation *> toErase;
-  bool isCompatible = false;
+  bool isSupported = true;
   func.walk([&](ReturnOp op) {
-    auto returnType = op.getOperandTypes()[0];
+    auto types = op.getOperandTypes();

     b.setInsertionPoint(op);
     // Memref Types.
-    if (auto memrefReturnType = returnType.dyn_cast<MemRefType>()) {
-      auto elemType = memrefReturnType.getElementType();
-      auto unRankedType = UnrankedMemRefType::get(elemType, 0);
-      isCompatible =
-          isReturnOpCompatible(op, consumeFuncReturnFuncs, unRankedType);
-      if (!isCompatible)
-        return;
-
-      // Cast to unranked memref type before sending it as a function argument.
-      auto cast = b.create<memref::CastOp>(
-          op.getLoc(), op.getOperand(0),
-          getAbiTypeForMemRef(op.getOperandTypes()[0]));
-      replaceCallToFunction(b, op, consumeFuncReturnFuncs[unRankedType],
-                            cast.getResult(), toErase);
-      // Elemental types.
-    } else if (returnType.isa<IntegerType>() || returnType.isa<FloatType>()) {
-      isCompatible =
-          isReturnOpCompatible(op, consumeFuncReturnFuncs, returnType);
-      if (!isCompatible)
-        return;
-      replaceCallToFunction(b, op, consumeFuncReturnFuncs[returnType],
-                            op->getOperand(0), toErase);
+    std::vector<Type> retTypes;
+    SmallVector<Value> retVals;
+    for (auto en : llvm::enumerate(types)) {
+      Type retType = en.value();
+      Value retVal = op.getOperand(en.index());
+      if (auto memrefReturnType = retType.dyn_cast<MemRefType>()) {
+        auto elemType = memrefReturnType.getElementType();
+        retType = UnrankedMemRefType::get(elemType, 0);
+        // Cast to unranked memref type before sending it as a function
+        // argument.
+        retVal = b.create<memref::CastOp>(
+            op.getLoc(), retVal, getAbiTypeForMemRef(types[en.index()]));
+      }
+      retTypes.push_back(retType);
+      retVals.push_back(retVal);
     }
+
+    auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end();
+    std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
+    if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
+      op.emitError(
+          "must have one return value of memref types or scalar types "
+          "of i32, i64, f32, f64 or three return values of memref f32");
+      isSupported = false;
+    }
+
+    auto invokedFuncsEnd = invokedConsumeFuncReturnFuncs.end();
+    if (invokedConsumeFuncReturnFuncs.find(funcName) == invokedFuncsEnd)
+      invokedConsumeFuncReturnFuncs.insert({funcName, retTypes});
+    replaceReturnWithCall(b, op, funcName, retTypes, retVals, toErase);
   });
-  if (!isCompatible)
+  if (!isSupported)
     return failure();
   func.setType(FunctionType::get(func.getContext(), newArgTypes, {}));
   for (Operation *op : toErase)
@@ -158,50 +180,47 @@ static LogicalResult mungeFunction(
   return success();
 }

+static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
+  std::set<std::string> funcNames;
+  Type mri32 = UnrankedMemRefType::get(b.getI32Type(), 0);
+  Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
+  Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
+  Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
+  Type i64 = b.getI64Type();
+  Type f32 = b.getF32Type();
+  Type f64 = b.getF64Type();
+
+  SmallVector<TypeRange> supportedReturnTypes = {
+      mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};
+
+  llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
+    funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
+  });
+  return funcNames;
+}
+
 namespace {
 class MungeCallingConventions
     : public MungeCallingConventionsBase<MungeCallingConventions> {
   void runOnOperation() override {
     auto module = getOperation();
     OpBuilder b(module.getBodyRegion());
-    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
-    DenseSet<FuncOp> consumeFuncReturnFuncsSet;
-    auto createConsumeFuncReturnFunc = [&](Type returnType,
-                                           std::string funcName) {
-      auto consumeFuncReturnFunc = b.create<FuncOp>(
-          module.getLoc(), funcName,
-          FunctionType::get(module.getContext(), returnType, {}),
-          b.getStringAttr("private"));
-      addEmitCInterfaceAttr(consumeFuncReturnFunc);
-      consumeFuncReturnFuncs[returnType] = consumeFuncReturnFunc;
-      consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc);
-    };
-
-    // Memref return types.
-    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0),
-                                "refbackend_consume_memref_int32_func_return");
-    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
-                                "refbackend_consume_memref_int64_func_return");
-    createConsumeFuncReturnFunc(
-        UnrankedMemRefType::get(b.getF32Type(), 0),
-        "refbackend_consume_memref_float32_func_return");
-    createConsumeFuncReturnFunc(
-        UnrankedMemRefType::get(b.getF64Type(), 0),
-        "refbackend_consume_memref_float64_func_return");
-
-    // Elemental return types.
-    createConsumeFuncReturnFunc(b.getI64Type(),
-                                "refbackend_consume_int64_func_return");
-    createConsumeFuncReturnFunc(b.getF32Type(),
-                                "refbackend_consume_float32_func_return");
-    createConsumeFuncReturnFunc(b.getF64Type(),
-                                "refbackend_consume_float64_func_return");
+    static std::set<std::string> supported =
+        getSupportedConsumeFuncReturnFuncs(b);
+    std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
     for (auto func : module.getOps<FuncOp>()) {
-      if (consumeFuncReturnFuncsSet.contains(func))
-        continue;
-      if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
+      if (failed(mungeFunction(func, supported, invokedConsumeFuncReturnFuncs)))
         return signalPassFailure();
     }
+
+    // Create FuncOp for consumeFuncReturnFuncs that are used.
+    for (auto &p : invokedConsumeFuncReturnFuncs) {
+      auto consumeFuncReturnFunc =
+          b.create<FuncOp>(module.getLoc(), p.first,
+                           FunctionType::get(module.getContext(), p.second, {}),
+                           b.getStringAttr("private"));
+      addEmitCInterfaceAttr(consumeFuncReturnFunc);
+    }
   }
 };
 } // namespace
@@ -399,7 +399,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
         emit_op(registry[key], f, **kwargs)

     emit("prim::layout : (Tensor) -> (int)")
-    emit("prim::TupleIndex : (Any, int) -> (Any)")
+    emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True)
     emit("prim::device : (Tensor) -> (Device)")
     emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
     emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
@@ -27,59 +27,73 @@ def checkArgTypeIsSupported(ty):
     SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
     assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"


 class RefBackendInvoker:
     def __init__(self, module):
         self.ee = ExecutionEngine(module)
         self.result = None

         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_i32_return(a):
+        def consume_return_mri32(a):
             self.result = unranked_memref_to_numpy(a, np.int32)

         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_i64_return(a):
+        def consume_return_mri64(a):
             self.result = unranked_memref_to_numpy(a, np.int64)

         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_f32_return(a):
+        def consume_return_mrf32(a):
             self.result = unranked_memref_to_numpy(a, np.float32)

         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_f64_return(a):
+        def consume_return_mrf64(a):
             self.result = unranked_memref_to_numpy(a, np.float64)

         @ctypes.CFUNCTYPE(None, ctypes.c_int)
-        def consume_i64_return(a):
+        def consume_return_i64(a):
             self.result = a

         @ctypes.CFUNCTYPE(None, ctypes.c_float)
-        def consume_f32_return(a):
+        def consume_return_f32(a):
             self.result = a

         @ctypes.CFUNCTYPE(None, ctypes.c_double)
-        def consume_f64_return(a):
+        def consume_return_f64(a):
             self.result = a

-        self.ee.register_runtime("refbackend_consume_memref_int32_func_return",
-                                 consume_memref_i32_return)
+        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
+                          ctypes.POINTER(UnrankedMemRefDescriptor),
+                          ctypes.POINTER(UnrankedMemRefDescriptor))
+        def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2):
+            self.result = unranked_memref_to_numpy(
+                arg0, np.float32), unranked_memref_to_numpy(
+                    arg1,
+                    np.float32), unranked_memref_to_numpy(arg2, np.float32)

-        self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
-                                 consume_memref_i64_return)
+        self.ee.register_runtime("refbackend_consume_func_return_mri32",
+                                 consume_return_mri32)

-        self.ee.register_runtime("refbackend_consume_memref_float32_func_return",
-                                 consume_memref_f32_return)
+        self.ee.register_runtime("refbackend_consume_func_return_mri64",
+                                 consume_return_mri64)

-        self.ee.register_runtime("refbackend_consume_memref_float64_func_return",
-                                 consume_memref_f64_return)
+        self.ee.register_runtime("refbackend_consume_func_return_mrf32",
+                                 consume_return_mrf32)

-        self.ee.register_runtime("refbackend_consume_int64_func_return",
-                                 consume_i64_return)
+        self.ee.register_runtime("refbackend_consume_func_return_mrf64",
+                                 consume_return_mrf64)

-        self.ee.register_runtime("refbackend_consume_float32_func_return",
-                                 consume_f32_return)
+        self.ee.register_runtime("refbackend_consume_func_return_i64",
+                                 consume_return_i64)

-        self.ee.register_runtime("refbackend_consume_float64_func_return",
-                                 consume_f64_return)
+        self.ee.register_runtime("refbackend_consume_func_return_f32",
+                                 consume_return_f32)
+
+        self.ee.register_runtime("refbackend_consume_func_return_f64",
+                                 consume_return_f64)
+
+        self.ee.register_runtime(
+            "refbackend_consume_func_return_mrf32_mrf32_mrf32",
+            consume_return_mrf32_mrf32_mrf32)

     def __getattr__(self, function_name: str):
         def invoke(*args):
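The multi-result callback follows the same stash-and-return protocol as the single-result ones. A simplified, self-contained model of that protocol (hypothetical class; the real invoker wires the callbacks through ExecutionEngine.register_runtime):

```python
import numpy as np

class ResultStash:
    """Model of RefBackendInvoker's return path: the compiled function calls
    back into Python, the callback stashes the converted arrays on `result`,
    and invoke() hands the stash back to the caller afterwards."""
    def __init__(self):
        self.result = None

    def consume(self, *arrays):
        # Stands in for consume_return_mrf32_mrf32_mrf32 above.
        self.result = tuple(arrays)

stash = ResultStash()
outs = tuple(np.zeros((2, 3), np.float32) for _ in range(3))
stash.consume(*outs)  # what the JIT-compiled function would do at return
assert all(r is o for r, o in zip(stash.result, outs))
```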
@@ -47,3 +47,53 @@ func @none_call_return() {
   "test.use"(%0) : (!torch.none) -> ()
   return
 }
+
+// CHECK-LABEL: func @tuple_return(
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK-SAME:      %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK:         %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK:         %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK:         %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK:         %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK:         %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] :
+// CHECK-SAME:        !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
+// CHECK:         %[[CST0:.*]] = torch.constant.int 0
+// CHECK:         %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
+// CHECK-SAME:        !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:         %[[CST1:.*]] = torch.constant.int 1
+// CHECK:         %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
+// CHECK-SAME:        !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:         return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                   %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
+  %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
+  return %1 : !torch.tuple<!torch.tensor, !torch.tensor>
+}
+
+// CHECK-LABEL: func @call_tuple_return(
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK-SAME:      %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK:         %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK:         %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK:         %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK:         %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK:         %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK:         %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK:         %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK:         %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK:         %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) :
+// CHECK-SAME:        (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
+// CHECK:         %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 :
+// CHECK-SAME:        !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
+// CHECK:         %[[CST0:.*]] = torch.constant.int 0
+// CHECK:         %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
+// CHECK-SAME:        !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:         %[[CST1:.*]] = torch.constant.int 1
+// CHECK:         %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
+// CHECK-SAME:        !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:         return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                        %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
+  %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor>
+  return %0 : !torch.tuple<!torch.tensor, !torch.tensor>
+}
@@ -568,3 +568,29 @@ func @torch.tensor_static_info_cast$upcast_first(%t: !torch.tensor<[?,?],f64>) -
   %downcast = torch.tensor_static_info_cast %upcast : !torch.tensor to !torch.tensor<[?,?],f64>
   return %downcast: !torch.tensor<[?,?],f64>
 }
+
+// CHECK-LABEL: func @torch.prim.TupleIndex(
+// CHECK-SAME:      %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
+// CHECK:         return %[[T1]] : !torch.tensor
+func @torch.prim.TupleIndex(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
+  %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
+  %int1 = torch.constant.int 1
+  %1 = torch.prim.TupleIndex %0, %int1 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+  return %1 : !torch.tensor
+}
+
+// CHECK-LABEL: func @torch.prim.TupleIndex$out_of_bound(
+// CHECK-SAME:      %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
+// CHECK:         %[[INDEX3:.*]] = torch.constant.int 3
+// CHECK:         %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]], %[[T2]] :
+// CHECK-SAME:        !torch.tensor, !torch.tensor, !torch.tensor ->
+// CHECK-SAME:        !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
+// CHECK:         %[[RET:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INDEX3]] :
+// CHECK-SAME:        !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:         return %[[RET]] : !torch.tensor
+func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
+  %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
+  %int3 = torch.constant.int 3
+  %1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+  return %1 : !torch.tensor
+}
@@ -23,8 +23,6 @@ func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor
   return %2 : !torch.tensor
 }
-
-
 // -----

 // Call to public function.
@@ -4,7 +4,7 @@
 // CHECK-SAME:    %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
 // CHECK:         %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
 // CHECK:         %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
-// CHECK:         call @refbackend_consume_memref_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
+// CHECK:         call @refbackend_consume_func_return_mrf32(%[[RESULT]]) : (memref<*xf32>) -> ()
 // CHECK:         return
 func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
   return %arg0 : memref<?xf32>

@@ -16,7 +16,7 @@ func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
 // CHECK-SAME:    %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
 // CHECK:         %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
 // CHECK:         %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
-// CHECK:         call @refbackend_consume_memref_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
+// CHECK:         call @refbackend_consume_func_return_mri64(%[[RESULT]]) : (memref<*xi64>) -> ()
 // CHECK:         return
 func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
   return %arg0 : memref<?xi64>

@@ -28,9 +28,28 @@ func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
 // CHECK-SAME:    %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
 // CHECK:         %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<i64>
 // CHECK:         %[[RESULT:.*]] = memref.load %[[VAL]][] : memref<i64>
-// CHECK:         call @refbackend_consume_int64_func_return(%[[RESULT]]) : (i64) -> ()
+// CHECK:         call @refbackend_consume_func_return_i64(%[[RESULT]]) : (i64) -> ()
 // CHECK:         return
 func @elemental_type(%arg0: memref<i64>) -> i64 {
   %0 = memref.load %arg0[] : memref<i64>
   return %0 : i64
 }
+
+// -----
+
+// CHECK-LABEL: func @multiple_return_values(
+// CHECK-SAME:    %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xf32>,
+// CHECK-SAME:    %[[ARG2:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
+// CHECK:         %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
+// CHECK:         %[[VAL1:.*]] = memref.cast %[[ARG1]] : memref<*xf32> to memref<?xf32>
+// CHECK:         %[[VAL2:.*]] = memref.cast %[[ARG2]] : memref<*xf32> to memref<?xf32>
+// CHECK:         %[[RET0:.*]] = memref.cast %[[VAL0]] : memref<?xf32> to memref<*xf32>
+// CHECK:         %[[RET1:.*]] = memref.cast %[[VAL1]] : memref<?xf32> to memref<*xf32>
+// CHECK:         %[[RET2:.*]] = memref.cast %[[VAL2]] : memref<?xf32> to memref<*xf32>
+// CHECK:         call @refbackend_consume_func_return_mrf32_mrf32_mrf32(%[[RET0]], %[[RET1]], %[[RET2]])
+// CHECK-SAME:      : (memref<*xf32>, memref<*xf32>, memref<*xf32>) -> ()
+// CHECK:         return
+
+func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
+  return %arg0 ,%arg1, %arg2 : memref<?xf32>, memref<?xf32>, memref<?xf32>
+}