Add support for multiple return values

This change unblocks work on backprop ops that return more than one
tensor. We will need a more scalable approach in the future if more
flexible combinations of return types are needed.
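A concrete instance of the motivating pattern (illustrative only, not part of this change): an op like aten::native_layer_norm_backward produces three gradient tensors in a single call, and all of them have to be carried out of the compiled function.

import torch

x = torch.randn(2, 3)
normalized_shape = [3]
weight, bias = torch.ones(3), torch.zeros(3)
out, mean, rstd = torch.ops.aten.native_layer_norm(
    x, normalized_shape, weight, bias, 1e-5)
# The backward op returns three tensors at once: grad_input, grad_weight, grad_bias.
grad_input, grad_weight, grad_bias = torch.ops.aten.native_layer_norm_backward(
    torch.ones_like(out), x, normalized_shape, mean, rstd, weight, bias,
    [True, True, True])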
pull/418/head
Yi Zhang 2021-11-08 10:56:40 -05:00
parent 6e8d39642e
commit 0fe70994e5
15 changed files with 334 additions and 130 deletions


@@ -11,11 +11,11 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

class SoftmaxBackwardModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
@@ -33,6 +33,8 @@ class SoftmaxBackwardModule(torch.nn.Module):
def SoftmaxBackwardModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))

+# ==============================================================================
+
class TanhBackwardModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -43,10 +45,11 @@ class TanhBackwardModule(torch.nn.Module):
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
-    def forward(self, grad_out, output):
-        return torch.ops.aten.tanh_backward(grad_out, output)
+    def forward(self, out_grad, output):
+        return torch.ops.aten.tanh_backward(out_grad, output)

@register_test_case(module_factory=lambda: TanhBackwardModule())
def TanhBackward_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 3), torch.randn(3, 3))
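For reference, the op under test is a pure elementwise formula: since tanh'(x) = 1 - tanh(x)^2 and `output` already holds tanh(x), the op computes out_grad * (1 - output**2). A quick eager-mode check, independent of the e2e harness:

import torch

out_grad = torch.randn(3, 3)
output = torch.tanh(torch.randn(3, 3))
# The backward of tanh, given the forward output rather than the input.
expected = out_grad * (1 - output * output)
assert torch.allclose(expected, torch.ops.aten.tanh_backward(out_grad, output))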


@@ -577,3 +577,24 @@ class NumToTensorModule(torch.nn.Module):
@register_test_case(module_factory=lambda: NumToTensorModule())
def NumToTensorModule_basic(module, tu: TestUtils):
    module.forward()
+
+# This test can be removed once we have one real op returning 3 float32 tensors.
+class ReturnThreeTensorFloat32(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b, c):
+        return a, b, c
+
+
+@register_test_case(module_factory=lambda: ReturnThreeTensorFloat32())
+def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
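In eager mode this placeholder module simply echoes its three inputs, but through the backend it exercises the whole multi-return path: the tuple return is flattened by AdjustCallingConventions and ultimately consumed by refbackend_consume_func_return_mrf32_mrf32_mrf32 (see the RefBackend changes below). A quick eager sanity check:

import torch

m = ReturnThreeTensorFloat32()
a, b, c = m.forward(torch.rand(2, 3), torch.rand(2, 3), torch.rand(2, 3))
assert a.shape == b.shape == c.shape == torch.Size([2, 3])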


@@ -29,4 +29,5 @@ TOSA_PASS_SET = {
    "ElementwiseFloorModule_basic",
    "ElementwiseLogModule_basic",
    "TanhBackward_basic",
+    "ReturnThreeTensorFloat32_basic",
}


@@ -43,6 +43,7 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
    AnyTorchType:$result
  );
  let assemblyFormat = "$tup `,` $i attr-dict `:` type($tup) `,` type($i) `->` type($result)";
+  let hasCanonicalizer = 1;
}

def Torch_PrimDeviceOp : Torch_Op<"prim.device", [


@@ -121,8 +121,8 @@ def AdjustCallingConventions
    function arguments, which should be `!numpy.ndarray<...>`'s.
    - Python-isms are rewritten to MLIR-isms
    - NoneType return is rewritten to the absence of a return value.
-   - (Not implemented yet) Tuple return is rewritten to multiple return
-     values
+   - Tuple return is rewritten to multiple return values.
  }];
}


@@ -916,6 +916,29 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
  });
}
+//===----------------------------------------------------------------------===//
+// PrimTupleIndexOp
+//===----------------------------------------------------------------------===//
+
+void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                   MLIRContext *context) {
+  patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
+    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
+    if (!tupleConstruct)
+      return failure();
+    int64_t i;
+    if (!matchPattern(op.i(), m_TorchConstantInt(&i)))
+      return failure();
+    if (i >= (int64_t)tupleConstruct.elements().size())
+      return failure();
+    rewriter.replaceOp(op, tupleConstruct.elements()[i]);
+    return success();
+  });
+}
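The pattern only fires when the tuple comes straight from a visible prim.TupleConstruct and the index is a known in-range constant; anything else is left untouched (the out-of-bounds case is pinned down by a canonicalize test further down). A rough Python analogue of the guard logic, for illustration only:

def fold_tuple_index(elements, index):
    """Mirror of the canonicalization above: return the replacement value,
    or None when the pattern must not fire."""
    if elements is None:                # tuple not defined by prim.TupleConstruct
        return None
    if not isinstance(index, int):      # index is not a known constant
        return None
    if index >= len(elements):          # out of bounds: leave the op as-is
        return None
    return elements[index]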
//===----------------------------------------------------------------------===//
// PrimTupleUnpackOp
//===----------------------------------------------------------------------===//

@@ -923,9 +946,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
-    auto torchTuple = op.tup();
-    auto tupleConstruct =
-        torchTuple.getDefiningOp<Torch::PrimTupleConstructOp>();
+    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
    if (!tupleConstruct)
      return failure();


@@ -66,15 +66,20 @@ public:
      // TODO: add tuple type.
      conversion.addInputs(type.index(), type.value());
    }
-    rewriter.applySignatureConversion(&func.getBody(), conversion,
-                                      typeConverter);

    SmallVector<Type> newResultTypes;
    for (auto type : func.getType().getResults()) {
      if (auto none = type.dyn_cast<Torch::NoneType>()) {
        continue;
      }
+      if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
+        llvm::append_range(newResultTypes, tuple.getContainedTypes());
+        continue;
+      }
      newResultTypes.push_back(type);
    }
+    rewriter.applySignatureConversion(&func.getBody(), conversion,
+                                      typeConverter);
    rewriter.updateRootInPlace(func, [&] {
      func.setType(FunctionType::get(
          getContext(), conversion.getConvertedTypes(), newResultTypes));
@@ -131,6 +136,11 @@ public:
            rewriter.create<ConstantNoneOp>(call.getLoc(), type));
        continue;
      }
+      if (type.isa<Torch::TupleType>()) {
+        newResults.push_back(rewriter.create<PrimTupleConstructOp>(
+            call.getLoc(), type, newCall.getResults()));
+        continue;
+      }
      newResults.push_back(newCall.getResult(newOpResultIdx++));
    }
    rewriter.replaceOp(call, newResults);
@@ -151,12 +161,22 @@
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> newOperands;
-    for (auto operand : llvm::enumerate(adaptor.getOperands())) {
-      if (!operand.value())
+    for (auto operand : adaptor.getOperands()) {
+      if (!operand)
        continue;
-      if (operand.value().getType().isa<Torch::NoneType>())
+      if (operand.getType().isa<Torch::NoneType>())
        continue;
-      newOperands.push_back(operand.value());
+      if (auto tuple = operand.getType().dyn_cast<Torch::TupleType>()) {
+        Location loc = op.getLoc();
+        for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
+          auto i = rewriter.create<ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(en.index()));
+          newOperands.push_back(
+              rewriter.create<PrimTupleIndexOp>(loc, en.value(), operand, i));
+        }
+        continue;
+      }
+      newOperands.push_back(operand);
    }
    rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
    return success();
@@ -168,9 +188,14 @@ static LogicalResult adjustCallingConventions(FuncOp func,
                                              TypeBoundMap &typeBoundMap) {
  MLIRContext *context = func.getContext();
  RewritePatternSet patterns(context);
-  // TODO: TupleTypes
  TypeConverter typeConverter;
  typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion(
+      [](Torch::TupleType type,
+         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
+        llvm::append_range(types, type.getContainedTypes());
+        return success();
+      });
  typeConverter.addConversion(
      [](Torch::NoneType type,
         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
@@ -220,6 +245,9 @@ static LogicalResult adjustCallingConventions(FuncOp func,
  target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
  target.addLegalOp<TensorStaticInfoCastOp>();
  target.addLegalOp<ConstantNoneOp>();
+  target.addLegalOp<ConstantIntOp>();
+  target.addLegalOp<PrimTupleIndexOp>();
+  target.addLegalOp<PrimTupleConstructOp>();
  // We don't know how to rewrite it, so mark it as illegal.
  target.addIllegalOp<CallIndirectOp>();
  if (failed(applyPartialConversion(func.getOperation(), target,


@@ -89,6 +89,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
  pm.addPass(createAdjustCallingConventionsPass());
  if (options.optimize) {
+    // Eliminate the PrimTupleIndexOp generated from the
+    // adjustCallingConventions.
+    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
    // Inline global slots, which for most inference scenarios deletes them.
    // This also exposes more information to intraprocedural transformations
    // below like MaximizeValueSemantics and RefineTypes.


@@ -21,7 +21,9 @@
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "set"
#include "torch-mlir/RefBackend/Passes.h"
+#include <numeric>

using namespace mlir;
using namespace mlir::torch;
@@ -67,33 +69,47 @@ static Type getAbiTypeForMemRef(Type type) {
  return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
}

-// Passes the return op operands `val` to `funOp`. Also, adds the op to the
-// `toErase` vector.
-static void replaceCallToFunction(OpBuilder b, ReturnOp op, FuncOp funcOp,
-                                  Value val,
+// Helper function to get the type string for one return value like i32, f64,
+// mri32 etc. The strings from multiple return values are concatenated to get
+// the consumeFuncReturnFunc name.
+static std::string getTypeToken(Type type) {
+  if (type.isSignlessInteger())
+    return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
+  else if (type.isa<mlir::FloatType>())
+    return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
+  else if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
+    return "mr" + getTypeToken(memRefType.getElementType());
+  llvm_unreachable(
+      "Type token should handle all types: memref, float and int type");
+}
+
+// Systematically derive the consumeFuncReturnFunc name from return value types.
+static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) {
+  SmallVector<std::string> tokens = {"refbackend_consume_func_return"};
+  for (auto type : types)
+    tokens.push_back(getTypeToken(type));
+  return std::accumulate(tokens.begin(), tokens.end(), std::string(),
+                         [](std::string &a, std::string &b) {
+                           return a.empty() ? b : (a + "_" + b);
+                         });
+}
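For example, a function returning three unranked f32 memrefs gets the consumer name refbackend_consume_func_return_mrf32_mrf32_mrf32, and a single i64 scalar gets refbackend_consume_func_return_i64. A throwaway Python sketch of the same naming rule (illustrative only; consume_func_name is not a real helper in this patch):

def consume_func_name(type_tokens):
    # type_tokens are the per-result strings produced by getTypeToken,
    # e.g. "i64", "f32", or "mrf32" for an unranked memref of f32.
    return "_".join(["refbackend_consume_func_return"] + list(type_tokens))

assert (consume_func_name(["mrf32", "mrf32", "mrf32"])
        == "refbackend_consume_func_return_mrf32_mrf32_mrf32")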
+
+// Replace the original returnOp with a call to consumeFuncReturnFunc and add
+// the op to the `toErase` vector.
+static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName,
+                                  TypeRange retTypes,
+                                  SmallVectorImpl<Value> &vals,
                                   SmallVectorImpl<Operation *> &toErase) {
-  b.create<mlir::CallOp>(op.getLoc(), funcOp, val);
+  b.create<mlir::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
  b.create<mlir::ReturnOp>(op.getLoc());
  toErase.push_back(op);
}
-// Checks whether the return op is munge-compatible and the respective calling
-// function is defined.
-static bool isReturnOpCompatible(ReturnOp op,
-                                 DenseMap<Type, FuncOp> &consumeFuncReturnFuncs,
-                                 Type returnType) {
-  auto it = consumeFuncReturnFuncs.find(returnType);
-  if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
-    op.emitError("must have one return value of Memref type or Elemental types "
-                 "of i64, f64, f32");
-    return false;
-  }
-  return true;
-}
static LogicalResult mungeFunction(
-    FuncOp func,
-    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
+    FuncOp func, std::set<std::string> &supportedConsumeFuncReturnFuncs,
+    std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
  // Add `llvm.emit_c_interface`.
  // This allows ExecutionEngine to resolve the symbol properly.
  addEmitCInterfaceAttr(func);
@@ -120,37 +136,43 @@ static LogicalResult mungeFunction(
  }

  SmallVector<Operation *> toErase;
-  bool isCompatible = false;
+  bool isSupported = true;
  func.walk([&](ReturnOp op) {
-    auto returnType = op.getOperandTypes()[0];
+    auto types = op.getOperandTypes();
    b.setInsertionPoint(op);
    // Memref Types.
-    if (auto memrefReturnType = returnType.dyn_cast<MemRefType>()) {
-      auto elemType = memrefReturnType.getElementType();
-      auto unRankedType = UnrankedMemRefType::get(elemType, 0);
-      isCompatible =
-          isReturnOpCompatible(op, consumeFuncReturnFuncs, unRankedType);
-      if (!isCompatible)
-        return;
-
-      // Cast to unranked memref type before sending it as a function argument.
-      auto cast = b.create<memref::CastOp>(
-          op.getLoc(), op.getOperand(0),
-          getAbiTypeForMemRef(op.getOperandTypes()[0]));
-      replaceCallToFunction(b, op, consumeFuncReturnFuncs[unRankedType],
-                            cast.getResult(), toErase);
-      // Elemental types.
-    } else if (returnType.isa<IntegerType>() || returnType.isa<FloatType>()) {
-      isCompatible =
-          isReturnOpCompatible(op, consumeFuncReturnFuncs, returnType);
-      if (!isCompatible)
-        return;
-      replaceCallToFunction(b, op, consumeFuncReturnFuncs[returnType],
-                            op->getOperand(0), toErase);
+    std::vector<Type> retTypes;
+    SmallVector<Value> retVals;
+    for (auto en : llvm::enumerate(types)) {
+      Type retType = en.value();
+      Value retVal = op.getOperand(en.index());
+      if (auto memrefReturnType = retType.dyn_cast<MemRefType>()) {
+        auto elemType = memrefReturnType.getElementType();
+        retType = UnrankedMemRefType::get(elemType, 0);
+        // Cast to unranked memref type before sending it as a function
+        // argument.
+        retVal = b.create<memref::CastOp>(
+            op.getLoc(), retVal, getAbiTypeForMemRef(types[en.index()]));
+      }
+      retTypes.push_back(retType);
+      retVals.push_back(retVal);
    }
+    auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end();
+    std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
+    if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
+      op.emitError(
+          "must have one return value of memref types or scalar types "
+          "of i32, i64, f32, f64 or three return values of memref f32");
+      isSupported = false;
+    }
+    auto invokedFuncsEnd = invokedConsumeFuncReturnFuncs.end();
+    if (invokedConsumeFuncReturnFuncs.find(funcName) == invokedFuncsEnd)
+      invokedConsumeFuncReturnFuncs.insert({funcName, retTypes});
+    replaceReturnWithCall(b, op, funcName, retTypes, retVals, toErase);
  });
-  if (!isCompatible)
+  if (!isSupported)
    return failure();
  func.setType(FunctionType::get(func.getContext(), newArgTypes, {}));
  for (Operation *op : toErase)
@@ -158,50 +180,47 @@ static LogicalResult mungeFunction(
  return success();
}
+static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
+  std::set<std::string> funcNames;
+  Type mri32 = UnrankedMemRefType::get(b.getI32Type(), 0);
+  Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
+  Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
+  Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
+  Type i64 = b.getI64Type();
+  Type f32 = b.getF32Type();
+  Type f64 = b.getF64Type();
+  SmallVector<TypeRange> supportedReturnTypes = {
+      mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};
+  llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
+    funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
+  });
+  return funcNames;
+}
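For the record, the eight names this helper produces, which match the runtime registrations in the Python RefBackend further down:

SUPPORTED_CONSUME_FUNCS = [
    "refbackend_consume_func_return_mri32",
    "refbackend_consume_func_return_mri64",
    "refbackend_consume_func_return_mrf32",
    "refbackend_consume_func_return_mrf64",
    "refbackend_consume_func_return_i64",
    "refbackend_consume_func_return_f32",
    "refbackend_consume_func_return_f64",
    "refbackend_consume_func_return_mrf32_mrf32_mrf32",
]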
namespace {
class MungeCallingConventions
    : public MungeCallingConventionsBase<MungeCallingConventions> {
  void runOnOperation() override {
    auto module = getOperation();
    OpBuilder b(module.getBodyRegion());
-    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
-    DenseSet<FuncOp> consumeFuncReturnFuncsSet;
-    auto createConsumeFuncReturnFunc = [&](Type returnType,
-                                           std::string funcName) {
-      auto consumeFuncReturnFunc = b.create<FuncOp>(
-          module.getLoc(), funcName,
-          FunctionType::get(module.getContext(), returnType, {}),
-          b.getStringAttr("private"));
-      addEmitCInterfaceAttr(consumeFuncReturnFunc);
-      consumeFuncReturnFuncs[returnType] = consumeFuncReturnFunc;
-      consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc);
-    };
-
-    // Memref return types.
-    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0),
-                                "refbackend_consume_memref_int32_func_return");
-    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
-                                "refbackend_consume_memref_int64_func_return");
-    createConsumeFuncReturnFunc(
-        UnrankedMemRefType::get(b.getF32Type(), 0),
-        "refbackend_consume_memref_float32_func_return");
-    createConsumeFuncReturnFunc(
-        UnrankedMemRefType::get(b.getF64Type(), 0),
-        "refbackend_consume_memref_float64_func_return");
-
-    // Elemental return types.
-    createConsumeFuncReturnFunc(b.getI64Type(),
-                                "refbackend_consume_int64_func_return");
-    createConsumeFuncReturnFunc(b.getF32Type(),
-                                "refbackend_consume_float32_func_return");
-    createConsumeFuncReturnFunc(b.getF64Type(),
-                                "refbackend_consume_float64_func_return");
+    static std::set<std::string> supported =
+        getSupportedConsumeFuncReturnFuncs(b);
+    std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
    for (auto func : module.getOps<FuncOp>()) {
-      if (consumeFuncReturnFuncsSet.contains(func))
-        continue;
-      if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
+      if (failed(mungeFunction(func, supported, invokedConsumeFuncReturnFuncs)))
        return signalPassFailure();
    }
+    // Create FuncOp for consumeFuncReturnFuncs that are used.
+    for (auto &p : invokedConsumeFuncReturnFuncs) {
+      auto consumeFuncReturnFunc =
+          b.create<FuncOp>(module.getLoc(), p.first,
+                           FunctionType::get(module.getContext(), p.second, {}),
+                           b.getStringAttr("private"));
+      addEmitCInterfaceAttr(consumeFuncReturnFunc);
+    }
  }
};
} // namespace


@@ -399,7 +399,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
        emit_op(registry[key], f, **kwargs)

    emit("prim::layout : (Tensor) -> (int)")
-    emit("prim::TupleIndex : (Any, int) -> (Any)")
+    emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True)
    emit("prim::device : (Tensor) -> (Device)")
    emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
    emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)


@@ -27,59 +27,73 @@ def checkArgTypeIsSupported(ty):
    SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
    assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"

class RefBackendInvoker:
    def __init__(self, module):
        self.ee = ExecutionEngine(module)
        self.result = None

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_i32_return(a):
+        def consume_return_mri32(a):
            self.result = unranked_memref_to_numpy(a, np.int32)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_i64_return(a):
+        def consume_return_mri64(a):
            self.result = unranked_memref_to_numpy(a, np.int64)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_f32_return(a):
+        def consume_return_mrf32(a):
            self.result = unranked_memref_to_numpy(a, np.float32)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_f64_return(a):
+        def consume_return_mrf64(a):
            self.result = unranked_memref_to_numpy(a, np.float64)

        @ctypes.CFUNCTYPE(None, ctypes.c_int)
-        def consume_i64_return(a):
+        def consume_return_i64(a):
            self.result = a

        @ctypes.CFUNCTYPE(None, ctypes.c_float)
-        def consume_f32_return(a):
+        def consume_return_f32(a):
            self.result = a

        @ctypes.CFUNCTYPE(None, ctypes.c_double)
-        def consume_f64_return(a):
+        def consume_return_f64(a):
            self.result = a

-        self.ee.register_runtime("refbackend_consume_memref_int32_func_return",
-                                 consume_memref_i32_return)
-        self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
-                                 consume_memref_i64_return)
-        self.ee.register_runtime("refbackend_consume_memref_float32_func_return",
-                                 consume_memref_f32_return)
-        self.ee.register_runtime("refbackend_consume_memref_float64_func_return",
-                                 consume_memref_f64_return)
-        self.ee.register_runtime("refbackend_consume_int64_func_return",
-                                 consume_i64_return)
-        self.ee.register_runtime("refbackend_consume_float32_func_return",
-                                 consume_f32_return)
-        self.ee.register_runtime("refbackend_consume_float64_func_return",
-                                 consume_f64_return)
+        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
+                          ctypes.POINTER(UnrankedMemRefDescriptor),
+                          ctypes.POINTER(UnrankedMemRefDescriptor))
+        def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2):
+            self.result = unranked_memref_to_numpy(
+                arg0, np.float32), unranked_memref_to_numpy(
+                    arg1,
+                    np.float32), unranked_memref_to_numpy(arg2, np.float32)
+
+        self.ee.register_runtime("refbackend_consume_func_return_mri32",
+                                 consume_return_mri32)
+        self.ee.register_runtime("refbackend_consume_func_return_mri64",
+                                 consume_return_mri64)
+        self.ee.register_runtime("refbackend_consume_func_return_mrf32",
+                                 consume_return_mrf32)
+        self.ee.register_runtime("refbackend_consume_func_return_mrf64",
+                                 consume_return_mrf64)
+        self.ee.register_runtime("refbackend_consume_func_return_i64",
+                                 consume_return_i64)
+        self.ee.register_runtime("refbackend_consume_func_return_f32",
+                                 consume_return_f32)
+        self.ee.register_runtime("refbackend_consume_func_return_f64",
+                                 consume_return_f64)
+        self.ee.register_runtime(
+            "refbackend_consume_func_return_mrf32_mrf32_mrf32",
+            consume_return_mrf32_mrf32_mrf32)

    def __getattr__(self, function_name: str):
        def invoke(*args):
View File

@@ -47,3 +47,53 @@ func @none_call_return() {
  "test.use"(%0) : (!torch.none) -> ()
  return
}
+
+// CHECK-LABEL: func @tuple_return(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] :
+// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
+// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
+// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                   %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
+  %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
+  return %1 : !torch.tuple<!torch.tensor, !torch.tensor>
+}
+
+// CHECK-LABEL: func @call_tuple_return(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) :
+// CHECK-SAME: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
+// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 :
+// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
+// CHECK: %[[CST0:.*]] = torch.constant.int 0
+// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
+// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
+// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                        %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
+  %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor>
+  return %0 : !torch.tuple<!torch.tensor, !torch.tensor>
+}


@@ -568,3 +568,29 @@ func @torch.tensor_static_info_cast$upcast_first(%t: !torch.tensor<[?,?],f64>) -
  %downcast = torch.tensor_static_info_cast %upcast : !torch.tensor to !torch.tensor<[?,?],f64>
  return %downcast: !torch.tensor<[?,?],f64>
}
+
+// CHECK-LABEL: func @torch.prim.TupleIndex(
+// CHECK-SAME: %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
+// CHECK: return %[[T1]] : !torch.tensor
+func @torch.prim.TupleIndex(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
+  %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
+  %int1 = torch.constant.int 1
+  %1 = torch.prim.TupleIndex %0, %int1 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+  return %1 : !torch.tensor
+}
+
+// CHECK-LABEL: func @torch.prim.TupleIndex$out_of_bound(
+// CHECK-SAME: %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
+// CHECK: %[[INDEX3:.*]] = torch.constant.int 3
+// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]], %[[T2]] :
+// CHECK-SAME: !torch.tensor, !torch.tensor, !torch.tensor ->
+// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
+// CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INDEX3]] :
+// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK: return %[[RET]] : !torch.tensor
+func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
+  %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
+  %int3 = torch.constant.int 3
+  %1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+  return %1 : !torch.tensor
+}


@@ -23,8 +23,6 @@ func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor
  return %2 : !torch.tensor
}

// -----

// Call to public function.


@@ -4,7 +4,7 @@
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
-// CHECK: call @refbackend_consume_memref_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
+// CHECK: call @refbackend_consume_func_return_mrf32(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: return
func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
  return %arg0 : memref<?xf32>
@@ -16,7 +16,7 @@ func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
-// CHECK: call @refbackend_consume_memref_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
+// CHECK: call @refbackend_consume_func_return_mri64(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: return
func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
  return %arg0 : memref<?xi64>
@@ -28,9 +28,28 @@ func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<i64>
// CHECK: %[[RESULT:.*]] = memref.load %[[VAL]][] : memref<i64>
-// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (i64) -> ()
+// CHECK: call @refbackend_consume_func_return_i64(%[[RESULT]]) : (i64) -> ()
// CHECK: return
func @elemental_type(%arg0: memref<i64>) -> i64 {
  %0 = memref.load %arg0[] : memref<i64>
  return %0 : i64
}
+
+// -----
+
+// CHECK-LABEL: func @multiple_return_values(
+// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
+// CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
+// CHECK: %[[VAL1:.*]] = memref.cast %[[ARG1]] : memref<*xf32> to memref<?xf32>
+// CHECK: %[[VAL2:.*]] = memref.cast %[[ARG2]] : memref<*xf32> to memref<?xf32>
+// CHECK: %[[RET0:.*]] = memref.cast %[[VAL0]] : memref<?xf32> to memref<*xf32>
+// CHECK: %[[RET1:.*]] = memref.cast %[[VAL1]] : memref<?xf32> to memref<*xf32>
+// CHECK: %[[RET2:.*]] = memref.cast %[[VAL2]] : memref<?xf32> to memref<*xf32>
+// CHECK: call @refbackend_consume_func_return_mrf32_mrf32_mrf32(%[[RET0]], %[[RET1]], %[[RET2]])
+// CHECK-SAME: : (memref<*xf32>, memref<*xf32>, memref<*xf32>) -> ()
+// CHECK: return
+func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
+  return %arg0, %arg1, %arg2 : memref<?xf32>, memref<?xf32>, memref<?xf32>
+}