Add support for multiple return values

This change unblocks work on backprop ops that return more than one
tensor. We will need a more scalable approach in the future if more
flexible return-type combinations are needed.
pull/418/head
Yi Zhang 2021-11-08 10:56:40 -05:00
parent 6e8d39642e
commit 0fe70994e5
15 changed files with 334 additions and 130 deletions


@@ -11,11 +11,11 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class SoftmaxBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
@@ -33,6 +33,8 @@ class SoftmaxBackwardModule(torch.nn.Module):
def SoftmaxBackwardModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
# ==============================================================================
class TanhBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -43,10 +45,11 @@ class TanhBackwardModule(torch.nn.Module):
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, grad_out, output):
return torch.ops.aten.tanh_backward(grad_out, output)
def forward(self, out_grad, output):
return torch.ops.aten.tanh_backward(out_grad, output)
@register_test_case(module_factory=lambda: TanhBackwardModule())
def TanhBackward_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3), torch.randn(3, 3))


@@ -543,7 +543,7 @@ class TensorToInt(torch.nn.Module):
@register_test_case(module_factory=lambda: TensorToInt())
def TensorToInt_basic(module, tu: TestUtils):
module.forward(torch.randint(10,[]))
class LogSoftmaxIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -577,3 +577,24 @@ class NumToTensorModule(torch.nn.Module):
@register_test_case(module_factory=lambda: NumToTensorModule())
def NumToTensorModule_basic(module, tu: TestUtils):
module.forward()
# This test can be removed once we have a real op that returns three float32 tensors
class ReturnThreeTensorFloat32(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, a, b, c):
return a, b, c
@register_test_case(module_factory=lambda: ReturnThreeTensorFloat32())
def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))


@@ -29,4 +29,5 @@ TOSA_PASS_SET = {
"ElementwiseFloorModule_basic",
"ElementwiseLogModule_basic",
"TanhBackward_basic",
"ReturnThreeTensorFloat32_basic",
}


@@ -43,6 +43,7 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
AnyTorchType:$result
);
let assemblyFormat = "$tup `,` $i attr-dict `:` type($tup) `,` type($i) `->` type($result)";
let hasCanonicalizer = 1;
}
def Torch_PrimDeviceOp : Torch_Op<"prim.device", [


@@ -121,8 +121,8 @@ def AdjustCallingConventions
function arguments, which should be `!numpy.ndarray<...>`'s.
- Python-isms are rewritten to MLIR-isms
- NoneType return is rewritten to the absence of a return value.
- (Not implemented yet) Tuple return is rewritten to multiple return
values
- Tuple return is rewritten to multiple return values.
}];
}
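As a rough illustration of the tuple-return rewrite (a minimal Python sketch with made-up names such as adjust_result_types; the pass itself operates on MLIR function types, not strings, and this sketch ignores nested tuples):

from typing import List

def adjust_result_types(result_types: List[str]) -> List[str]:
    # NoneType results are dropped; tuple results are splatted into their
    # contained types; everything else passes through unchanged.
    new_results: List[str] = []
    for ty in result_types:
        if ty == "none":
            continue
        if ty.startswith("tuple<") and ty.endswith(">"):
            new_results.extend(t.strip() for t in ty[len("tuple<"):-1].split(","))
            continue
        new_results.append(ty)
    return new_results

# A tuple return becomes multiple return values; a none return vanishes.
assert adjust_result_types(["tuple<tensor, tensor>", "none"]) == ["tensor", "tensor"]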
@@ -219,14 +219,14 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> {
let summary = "Decompose complicated torch operations";
let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
let description = [{
Decompose torch operations that are losslessly represented as combinations of
other operations, modulo appropriate compiler fusion. Note that this pass
is similar in spirit to ReduceOpVariants, but ReduceOpVariants is about
systematic reductions of a large number of ops at once, guided mostly by
traits.
An example of the transformations done in this pass is:
- convert aten.softmax to softmax(x, dim)
=> tmp=exp(x); tmp / sum(tmp, dim, keepdim=True)
}];
}
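The softmax example is easy to sanity-check numerically in PyTorch (illustrative only; the actual pass rewrites Torch-MLIR IR, and torch.softmax subtracts the row max internally for stability, so a small tolerance is used):

import torch

x = torch.randn(3, 4)
dim = 1
tmp = torch.exp(x)
decomposed = tmp / torch.sum(tmp, dim, keepdim=True)
assert torch.allclose(decomposed, torch.softmax(x, dim), atol=1e-6)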


@@ -916,6 +916,29 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
});
}
//===----------------------------------------------------------------------===//
// PrimTupleIndexOp
//===----------------------------------------------------------------------===//
void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
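// Fold prim.TupleIndex into the indexed element when the tuple is produced
// by a visible prim.TupleConstruct and the index is a constant in range.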
patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
if (!tupleConstruct)
return failure();
int64_t i;
if (!matchPattern(op.i(), m_TorchConstantInt(&i)))
return failure();
if (i >= (int64_t)tupleConstruct.elements().size())
return failure();
rewriter.replaceOp(op, tupleConstruct.elements()[i]);
return success();
});
}
//===----------------------------------------------------------------------===//
// PrimTupleUnpackOp
//===----------------------------------------------------------------------===//
@@ -923,9 +946,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
auto torchTuple = op.tup();
auto tupleConstruct =
torchTuple.getDefiningOp<Torch::PrimTupleConstructOp>();
auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
if (!tupleConstruct)
return failure();


@@ -66,15 +66,20 @@ public:
// TODO: add tuple type.
conversion.addInputs(type.index(), type.value());
}
rewriter.applySignatureConversion(&func.getBody(), conversion,
typeConverter);
SmallVector<Type> newResultTypes;
for (auto type : func.getType().getResults()) {
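// NoneType results disappear entirely; tuple results are flattened into
// their contained types.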
if (auto none = type.dyn_cast<Torch::NoneType>()) {
continue;
}
if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
llvm::append_range(newResultTypes, tuple.getContainedTypes());
continue;
}
newResultTypes.push_back(type);
}
rewriter.applySignatureConversion(&func.getBody(), conversion,
typeConverter);
rewriter.updateRootInPlace(func, [&] {
func.setType(FunctionType::get(
getContext(), conversion.getConvertedTypes(), newResultTypes));
@@ -131,6 +136,11 @@ public:
rewriter.create<ConstantNoneOp>(call.getLoc(), type));
continue;
}
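// The callee's tuple return was flattened, so rebuild the tuple at the
// call site from the new call's results.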
if (type.isa<Torch::TupleType>()) {
newResults.push_back(rewriter.create<PrimTupleConstructOp>(
call.getLoc(), type, newCall.getResults()));
continue;
}
newResults.push_back(newCall.getResult(newOpResultIdx++));
}
rewriter.replaceOp(call, newResults);
@@ -151,12 +161,22 @@ public:
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> newOperands;
for (auto operand : llvm::enumerate(adaptor.getOperands())) {
if (!operand.value())
for (auto operand : adaptor.getOperands()) {
if (!operand)
continue;
if (operand.value().getType().isa<Torch::NoneType>())
if (operand.getType().isa<Torch::NoneType>())
continue;
newOperands.push_back(operand.value());
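// A tuple operand of the return is expanded into one operand per element,
// using prim.TupleIndex with a constant index.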
if (auto tuple = operand.getType().dyn_cast<Torch::TupleType>()) {
Location loc = op.getLoc();
for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
auto i = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(en.index()));
newOperands.push_back(
rewriter.create<PrimTupleIndexOp>(loc, en.value(), operand, i));
}
continue;
}
newOperands.push_back(operand);
}
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
return success();
@@ -168,9 +188,14 @@ static LogicalResult adjustCallingConventions(FuncOp func,
TypeBoundMap &typeBoundMap) {
MLIRContext *context = func.getContext();
RewritePatternSet patterns(context);
// TODO: TupleTypes
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
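// 1->N conversion: a tuple type converts to the list of its contained types.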
typeConverter.addConversion(
[](Torch::TupleType type,
SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
llvm::append_range(types, type.getContainedTypes());
return success();
});
typeConverter.addConversion(
[](Torch::NoneType type,
SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
@@ -220,6 +245,9 @@ static LogicalResult adjustCallingConventions(FuncOp func,
target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
target.addLegalOp<TensorStaticInfoCastOp>();
target.addLegalOp<ConstantNoneOp>();
target.addLegalOp<ConstantIntOp>();
target.addLegalOp<PrimTupleIndexOp>();
target.addLegalOp<PrimTupleConstructOp>();
// We don't know how to rewrite it, so mark it as illegal.
target.addIllegalOp<CallIndirectOp>();
if (failed(applyPartialConversion(func.getOperation(), target,


@@ -89,6 +89,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
pm.addPass(createAdjustCallingConventionsPass());
if (options.optimize) {
// Eliminate the PrimTupleIndexOp generated from the
// adjustCallingConventions
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
// Inline global slots, which for most inference scenarios deletes them.
// This also exposes more information to intraprocedural transformations
// below like MaximizeValueSemantics and RefineTypes.


@@ -21,7 +21,9 @@
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "set"
#include "torch-mlir/RefBackend/Passes.h"
#include <numeric>
using namespace mlir;
using namespace mlir::torch;
@@ -67,33 +69,47 @@ static Type getAbiTypeForMemRef(Type type) {
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
}
// Passes the return op operands `val` to `funOp`. Also, adds the op to the
// `toErase` vector.
static void replaceCallToFunction(OpBuilder b, ReturnOp op, FuncOp funcOp,
Value val,
// Helper function to get the type string for one return value like i32, f64,
// mri32 etc. The strings from multiple return values are concatenated to get
// the consumeFuncReturnFunc name.
static std::string getTypeToken(Type type) {
if (type.isSignlessInteger())
return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
else if (type.isa<mlir::FloatType>())
return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
else if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
return "mr" + getTypeToken(memRefType.getElementType());
llvm_unreachable(
"Type token should handle all types: memref, float and int type");
}
// Systematically derive the consumeFuncReturnFunc name from return value types.
static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) {
SmallVector<std::string> tokens = {"refbackend_consume_func_return"};
for (auto type : types)
tokens.push_back(getTypeToken(type));
return std::accumulate(tokens.begin(), tokens.end(), std::string(),
[](std::string &a, std::string &b) {
return a.empty() ? b : (a + "_" + b);
});
}
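The resulting naming scheme is simple enough to restate in Python (a sketch; consume_func_name is an illustrative stand-in for the C++ helpers above):

def consume_func_name(type_tokens):
    # e.g. ["mrf32", "mrf32", "mrf32"] yields
    # "refbackend_consume_func_return_mrf32_mrf32_mrf32"
    return "_".join(["refbackend_consume_func_return"] + list(type_tokens))

assert consume_func_name(["i64"]) == "refbackend_consume_func_return_i64"
assert (consume_func_name(["mrf32", "mrf32", "mrf32"]) ==
        "refbackend_consume_func_return_mrf32_mrf32_mrf32")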
// Replace the original returnOp with a call to consumeFuncReturnFunc and add
// the op to the `toErase` vector.
static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName,
TypeRange retTypes,
SmallVectorImpl<Value> &vals,
SmallVectorImpl<Operation *> &toErase) {
b.create<mlir::CallOp>(op.getLoc(), funcOp, val);
b.create<mlir::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
b.create<mlir::ReturnOp>(op.getLoc());
toErase.push_back(op);
}
// Checks whether the return op is munge-compatible and the respective calling
// function is defined.
static bool isReturnOpCompatible(ReturnOp op,
DenseMap<Type, FuncOp> &consumeFuncReturnFuncs,
Type returnType) {
auto it = consumeFuncReturnFuncs.find(returnType);
if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
op.emitError("must have one return value of Memref type or Elemental types "
"of i64, f64, f32");
return false;
}
return true;
}
static LogicalResult mungeFunction(
FuncOp func,
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
FuncOp func, std::set<std::string> &supportedConsumeFuncReturnFuncs,
std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
// Add `llvm.emit_c_interface`.
// This allows ExecutionEngine to resolve the symbol properly.
addEmitCInterfaceAttr(func);
@@ -120,37 +136,43 @@ static LogicalResult mungeFunction(
}
SmallVector<Operation *> toErase;
bool isCompatible = false;
bool isSupported = true;
func.walk([&](ReturnOp op) {
auto returnType = op.getOperandTypes()[0];
auto types = op.getOperandTypes();
b.setInsertionPoint(op);
// Memref Types.
if (auto memrefReturnType = returnType.dyn_cast<MemRefType>()) {
auto elemType = memrefReturnType.getElementType();
auto unRankedType = UnrankedMemRefType::get(elemType, 0);
isCompatible =
isReturnOpCompatible(op, consumeFuncReturnFuncs, unRankedType);
if (!isCompatible)
return;
// Cast to unranked memref type before sending it as a function argument.
auto cast = b.create<memref::CastOp>(
op.getLoc(), op.getOperand(0),
getAbiTypeForMemRef(op.getOperandTypes()[0]));
replaceCallToFunction(b, op, consumeFuncReturnFuncs[unRankedType],
cast.getResult(), toErase);
// Elemental types.
} else if (returnType.isa<IntegerType>() || returnType.isa<FloatType>()) {
isCompatible =
isReturnOpCompatible(op, consumeFuncReturnFuncs, returnType);
if (!isCompatible)
return;
replaceCallToFunction(b, op, consumeFuncReturnFuncs[returnType],
op->getOperand(0), toErase);
std::vector<Type> retTypes;
SmallVector<Value> retVals;
for (auto en : llvm::enumerate(types)) {
Type retType = en.value();
Value retVal = op.getOperand(en.index());
if (auto memrefReturnType = retType.dyn_cast<MemRefType>()) {
auto elemType = memrefReturnType.getElementType();
retType = UnrankedMemRefType::get(elemType, 0);
// Cast to unranked memref type before sending it as a function
// argument.
retVal = b.create<memref::CastOp>(
op.getLoc(), retVal, getAbiTypeForMemRef(types[en.index()]));
}
retTypes.push_back(retType);
retVals.push_back(retVal);
}
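// Derive the consume-function name from the (possibly multiple) return
// types, then check it against the supported signatures.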
auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end();
std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
op.emitError(
"must have one return value of memref types or scalar types "
"of i32, i64, f32, f64 or three return values of memref f32");
isSupported = false;
}
auto invokedFuncsEnd = invokedConsumeFuncReturnFuncs.end();
if (invokedConsumeFuncReturnFuncs.find(funcName) == invokedFuncsEnd)
invokedConsumeFuncReturnFuncs.insert({funcName, retTypes});
replaceReturnWithCall(b, op, funcName, retTypes, retVals, toErase);
});
if (!isCompatible)
if (!isSupported)
return failure();
func.setType(FunctionType::get(func.getContext(), newArgTypes, {}));
for (Operation *op : toErase)
@@ -158,50 +180,47 @@ static LogicalResult mungeFunction(
return success();
}
static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
std::set<std::string> funcNames;
Type mri32 = UnrankedMemRefType::get(b.getI32Type(), 0);
Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
Type i64 = b.getI64Type();
Type f32 = b.getF32Type();
Type f64 = b.getF64Type();
SmallVector<TypeRange> supportedReturnTypes = {
mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};
llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
});
return funcNames;
}
namespace {
class MungeCallingConventions
: public MungeCallingConventionsBase<MungeCallingConventions> {
void runOnOperation() override {
auto module = getOperation();
OpBuilder b(module.getBodyRegion());
DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
DenseSet<FuncOp> consumeFuncReturnFuncsSet;
auto createConsumeFuncReturnFunc = [&](Type returnType,
std::string funcName) {
auto consumeFuncReturnFunc = b.create<FuncOp>(
module.getLoc(), funcName,
FunctionType::get(module.getContext(), returnType, {}),
b.getStringAttr("private"));
addEmitCInterfaceAttr(consumeFuncReturnFunc);
consumeFuncReturnFuncs[returnType] = consumeFuncReturnFunc;
consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc);
};
// Memref return types.
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0),
"refbackend_consume_memref_int32_func_return");
createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
"refbackend_consume_memref_int64_func_return");
createConsumeFuncReturnFunc(
UnrankedMemRefType::get(b.getF32Type(), 0),
"refbackend_consume_memref_float32_func_return");
createConsumeFuncReturnFunc(
UnrankedMemRefType::get(b.getF64Type(), 0),
"refbackend_consume_memref_float64_func_return");
// Elemental return types.
createConsumeFuncReturnFunc(b.getI64Type(),
"refbackend_consume_int64_func_return");
createConsumeFuncReturnFunc(b.getF32Type(),
"refbackend_consume_float32_func_return");
createConsumeFuncReturnFunc(b.getF64Type(),
"refbackend_consume_float64_func_return");
static std::set<std::string> supported =
getSupportedConsumeFuncReturnFuncs(b);
std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
for (auto func : module.getOps<FuncOp>()) {
if (consumeFuncReturnFuncsSet.contains(func))
continue;
if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
if (failed(mungeFunction(func, supported, invokedConsumeFuncReturnFuncs)))
return signalPassFailure();
}
// Create FuncOp for consumeFuncReturnFuncs that are used.
for (auto &p : invokedConsumeFuncReturnFuncs) {
auto consumeFuncReturnFunc =
b.create<FuncOp>(module.getLoc(), p.first,
FunctionType::get(module.getContext(), p.second, {}),
b.getStringAttr("private"));
addEmitCInterfaceAttr(consumeFuncReturnFunc);
}
}
};
} // namespace


@@ -399,7 +399,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
emit_op(registry[key], f, **kwargs)
emit("prim::layout : (Tensor) -> (int)")
emit("prim::TupleIndex : (Any, int) -> (Any)")
emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True)
emit("prim::device : (Tensor) -> (Device)")
emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)


@@ -27,59 +27,73 @@ def checkArgTypeIsSupported(ty):
SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
class RefBackendInvoker:
def __init__(self, module):
self.ee = ExecutionEngine(module)
self.result = None
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_memref_i32_return(a):
def consume_return_mri32(a):
self.result = unranked_memref_to_numpy(a, np.int32)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_memref_i64_return(a):
def consume_return_mri64(a):
self.result = unranked_memref_to_numpy(a, np.int64)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_memref_f32_return(a):
def consume_return_mrf32(a):
self.result = unranked_memref_to_numpy(a, np.float32)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_memref_f64_return(a):
def consume_return_mrf64(a):
self.result = unranked_memref_to_numpy(a, np.float64)
@ctypes.CFUNCTYPE(None, ctypes.c_int)
def consume_i64_return(a):
def consume_return_i64(a):
self.result = a
@ctypes.CFUNCTYPE(None, ctypes.c_float)
def consume_f32_return(a):
def consume_return_f32(a):
self.result = a
@ctypes.CFUNCTYPE(None, ctypes.c_double)
def consume_f64_return(a):
def consume_return_f64(a):
self.result = a
self.ee.register_runtime("refbackend_consume_memref_int32_func_return",
consume_memref_i32_return)
@ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
ctypes.POINTER(UnrankedMemRefDescriptor),
ctypes.POINTER(UnrankedMemRefDescriptor))
def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2):
self.result = (unranked_memref_to_numpy(arg0, np.float32),
unranked_memref_to_numpy(arg1, np.float32),
unranked_memref_to_numpy(arg2, np.float32))
self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
consume_memref_i64_return)
self.ee.register_runtime("refbackend_consume_func_return_mri32",
consume_return_mri32)
self.ee.register_runtime("refbackend_consume_memref_float32_func_return",
consume_memref_f32_return)
self.ee.register_runtime("refbackend_consume_func_return_mri64",
consume_return_mri64)
self.ee.register_runtime("refbackend_consume_memref_float64_func_return",
consume_memref_f64_return)
self.ee.register_runtime("refbackend_consume_func_return_mrf32",
consume_return_mrf32)
self.ee.register_runtime("refbackend_consume_int64_func_return",
consume_i64_return)
self.ee.register_runtime("refbackend_consume_func_return_mrf64",
consume_return_mrf64)
self.ee.register_runtime("refbackend_consume_float32_func_return",
consume_f32_return)
self.ee.register_runtime("refbackend_consume_func_return_i64",
consume_return_i64)
self.ee.register_runtime("refbackend_consume_float64_func_return",
consume_f64_return)
self.ee.register_runtime("refbackend_consume_func_return_f32",
consume_return_f32)
self.ee.register_runtime("refbackend_consume_func_return_f64",
consume_return_f64)
self.ee.register_runtime(
"refbackend_consume_func_return_mrf32_mrf32_mrf32",
consume_return_mrf32_mrf32_mrf32)
def __getattr__(self, function_name: str):
def invoke(*args):


@@ -47,3 +47,53 @@ func @none_call_return() {
"test.use"(%0) : (!torch.none) -> ()
return
}
// CHECK-LABEL: func @tuple_return(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] :
// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
%1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
return %1 : !torch.tuple<!torch.tensor, !torch.tensor>
}
// CHECK-LABEL: func @call_tuple_return(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
// CHECK: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) :
// CHECK-SAME: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 :
// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
%arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
%0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor>
return %0 : !torch.tuple<!torch.tensor, !torch.tensor>
}


@@ -568,3 +568,29 @@ func @torch.tensor_static_info_cast$upcast_first(%t: !torch.tensor<[?,?],f64>) -
%downcast = torch.tensor_static_info_cast %upcast : !torch.tensor to !torch.tensor<[?,?],f64>
return %downcast: !torch.tensor<[?,?],f64>
}
// CHECK-LABEL: func @torch.prim.TupleIndex(
// CHECK-SAME: %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[T1]] : !torch.tensor
func @torch.prim.TupleIndex(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
%0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
%int1 = torch.constant.int 1
%1 = torch.prim.TupleIndex %0, %int1 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
return %1 : !torch.tensor
}
// CHECK-LABEL: func @torch.prim.TupleIndex$out_of_bound(
// CHECK-SAME: %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[INDEX3:.*]] = torch.constant.int 3
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]], %[[T2]] :
// CHECK-SAME: !torch.tensor, !torch.tensor, !torch.tensor ->
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
// CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INDEX3]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
%0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
%int3 = torch.constant.int 3
%1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
return %1 : !torch.tensor
}


@@ -23,8 +23,6 @@ func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor
return %2 : !torch.tensor
}
// -----
// Call to public function.


@@ -4,7 +4,7 @@
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
// CHECK: call @refbackend_consume_memref_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: call @refbackend_consume_func_return_mrf32(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: return
func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
return %arg0 : memref<?xf32>
@@ -16,7 +16,7 @@ func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
// CHECK: call @refbackend_consume_memref_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: call @refbackend_consume_func_return_mri64(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: return
func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
return %arg0 : memref<?xi64>
@@ -28,9 +28,28 @@ func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<i64>
// CHECK: %[[RESULT:.*]] = memref.load %[[VAL]][] : memref<i64>
// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (i64) -> ()
// CHECK: call @refbackend_consume_func_return_i64(%[[RESULT]]) : (i64) -> ()
// CHECK: return
func @elemental_type(%arg0: memref<i64>) -> i64 {
%0 = memref.load %arg0[] : memref<i64>
return %0 : i64
}
// -----
// CHECK-LABEL: func @multiple_return_values(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[VAL1:.*]] = memref.cast %[[ARG1]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[VAL2:.*]] = memref.cast %[[ARG2]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RET0:.*]] = memref.cast %[[VAL0]] : memref<?xf32> to memref<*xf32>
// CHECK: %[[RET1:.*]] = memref.cast %[[VAL1]] : memref<?xf32> to memref<*xf32>
// CHECK: %[[RET2:.*]] = memref.cast %[[VAL2]] : memref<?xf32> to memref<*xf32>
// CHECK: call @refbackend_consume_func_return_mrf32_mrf32_mrf32(%[[RET0]], %[[RET1]], %[[RET2]])
// CHECK-SAME: : (memref<*xf32>, memref<*xf32>, memref<*xf32>) -> ()
// CHECK: return
func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
return %arg0, %arg1, %arg2 : memref<?xf32>, memref<?xf32>, memref<?xf32>
}