Add support for multiple return values

This change unblocks the work on backprop ops that return more than
one tensor. We will need a more scalable approach in the future if
more flexible combinations of return types are needed.
pull/418/head
Yi Zhang 2021-11-08 10:56:40 -05:00
parent 6e8d39642e
commit 0fe70994e5
15 changed files with 334 additions and 130 deletions
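For context, a minimal sketch (plain PyTorch; the function name is illustrative and not part of this change) of what a multi-value return looks like at the TorchScript level, which is what this change teaches the lowering to handle:

import torch

@torch.jit.script
def two_results(x: torch.Tensor):
    # TorchScript represents this as a prim::TupleConstruct feeding the
    # graph return; this change lowers such tuples to multiple MLIR
    # return values.
    return x + 1, x * 2

print(two_results.graph)  # the graph returns a (Tensor, Tensor) tuple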


@@ -11,11 +11,11 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 # ==============================================================================
 
 class SoftmaxBackwardModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
     @export
     @annotate_args([
         None,
@@ -33,6 +33,8 @@ class SoftmaxBackwardModule(torch.nn.Module):
 def SoftmaxBackwardModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
 
+# ==============================================================================
+
 class TanhBackwardModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -43,10 +45,11 @@ class TanhBackwardModule(torch.nn.Module):
         ([-1, -1], torch.float32, True),
         ([-1, -1], torch.float32, True),
     ])
-    def forward(self, grad_out, output):
-        return torch.ops.aten.tanh_backward(grad_out, output)
+    def forward(self, out_grad, output):
+        return torch.ops.aten.tanh_backward(out_grad, output)
 
 @register_test_case(module_factory=lambda: TanhBackwardModule())
 def TanhBackward_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 3), torch.randn(3, 3))


@@ -543,7 +543,7 @@ class TensorToInt(torch.nn.Module):
 @register_test_case(module_factory=lambda: TensorToInt())
 def TensorToInt_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, []))
 
 class LogSoftmaxIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -577,3 +577,24 @@ class NumToTensorModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: NumToTensorModule())
 def NumToTensorModule_basic(module, tu: TestUtils):
     module.forward()
+
+# This test can be removed once we have one real op returning 3 float32 tensors.
+class ReturnThreeTensorFloat32(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b, c):
+        return a, b, c
+
+@register_test_case(module_factory=lambda: ReturnThreeTensorFloat32())
+def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))


@@ -29,4 +29,5 @@ TOSA_PASS_SET = {
     "ElementwiseFloorModule_basic",
     "ElementwiseLogModule_basic",
     "TanhBackward_basic",
+    "ReturnThreeTensorFloat32_basic",
 }


@@ -43,6 +43,7 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
     AnyTorchType:$result
   );
   let assemblyFormat = "$tup `,` $i attr-dict `:` type($tup) `,` type($i) `->` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Torch_PrimDeviceOp : Torch_Op<"prim.device", [


@@ -121,8 +121,8 @@ def AdjustCallingConventions
       function arguments, which should be `!numpy.ndarray<...>`'s.
     - Python-isms are rewritten to MLIR-isms
       - NoneType return is rewritten to the absence of a return value.
-      - (Not implemented yet) Tuple return is rewritten to multiple return
-        values
+      - Tuple return is rewritten to multiple return values.
   }];
 }
@@ -219,14 +219,14 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> {
   let summary = "Decompose complicated torch operations";
   let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
   let description = [{
     Decompose torch operations that are losslessly represented as combinations of
     other operations, modulo appropriate compiler fusion. Note that this pass
     is similar in spirit to ReduceOpVariants, but ReduceOpVariants is about
     systematic reductions of a large number of ops at once, guided mostly by
     traits.
 
     An example of the transformations done in this pass is:
     - convert aten.softmax to softmax(x, dim)
         => tmp=exp(x); tmp / sum(tmp, dim, keepdim=True)
   }];
 }
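As a quick numeric sanity check of the softmax decomposition described above (a PyTorch sketch, not part of the diff):

import torch

x = torch.randn(3, 4)
tmp = torch.exp(x)
decomposed = tmp / torch.sum(tmp, dim=1, keepdim=True)
assert torch.allclose(decomposed, torch.softmax(x, dim=1))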


@@ -916,6 +916,29 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// PrimTupleIndexOp
+//===----------------------------------------------------------------------===//
+
+void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                   MLIRContext *context) {
+  patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
+    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
+    if (!tupleConstruct)
+      return failure();
+    int64_t i;
+    if (!matchPattern(op.i(), m_TorchConstantInt(&i)))
+      return failure();
+    if (i >= (int64_t)tupleConstruct.elements().size())
+      return failure();
+    rewriter.replaceOp(op, tupleConstruct.elements()[i]);
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // PrimTupleUnpackOp
 //===----------------------------------------------------------------------===//
@@ -923,9 +946,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
 void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                     MLIRContext *context) {
   patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
-    auto torchTuple = op.tup();
-    auto tupleConstruct =
-        torchTuple.getDefiningOp<Torch::PrimTupleConstructOp>();
+    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
     if (!tupleConstruct)
       return failure();
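The new pattern only fires under two guards: the index must be a compile-time constant, and it must be in bounds. A hedged Python sketch of the same logic (hypothetical helper, not part of the codebase):

def fold_tuple_index(tuple_elements, index):
    # Mirror of the PrimTupleIndexOp canonicalization above.
    if index is None:
        return None  # index is not a constant: pattern does not apply
    if index >= len(tuple_elements):
        return None  # out of bounds: leave the op untouched
    return tuple_elements[index]

assert fold_tuple_index(["a", "b", "c"], 1) == "b"
assert fold_tuple_index(["a", "b", "c"], 3) is None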


@@ -66,15 +66,20 @@ public:
       // TODO: add tuple type.
       conversion.addInputs(type.index(), type.value());
     }
-    rewriter.applySignatureConversion(&func.getBody(), conversion,
-                                      typeConverter);
 
     SmallVector<Type> newResultTypes;
     for (auto type : func.getType().getResults()) {
       if (auto none = type.dyn_cast<Torch::NoneType>()) {
         continue;
       }
+      if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
+        llvm::append_range(newResultTypes, tuple.getContainedTypes());
+        continue;
+      }
       newResultTypes.push_back(type);
     }
+    rewriter.applySignatureConversion(&func.getBody(), conversion,
+                                      typeConverter);
     rewriter.updateRootInPlace(func, [&] {
       func.setType(FunctionType::get(
           getContext(), conversion.getConvertedTypes(), newResultTypes));
@@ -131,6 +136,11 @@ public:
             rewriter.create<ConstantNoneOp>(call.getLoc(), type));
         continue;
       }
+      if (type.isa<Torch::TupleType>()) {
+        newResults.push_back(rewriter.create<PrimTupleConstructOp>(
+            call.getLoc(), type, newCall.getResults()));
+        continue;
+      }
       newResults.push_back(newCall.getResult(newOpResultIdx++));
     }
     rewriter.replaceOp(call, newResults);
@@ -151,12 +161,22 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value> newOperands;
-    for (auto operand : llvm::enumerate(adaptor.getOperands())) {
-      if (!operand.value())
+    for (auto operand : adaptor.getOperands()) {
+      if (!operand)
         continue;
-      if (operand.value().getType().isa<Torch::NoneType>())
+      if (operand.getType().isa<Torch::NoneType>())
         continue;
-      newOperands.push_back(operand.value());
+      if (auto tuple = operand.getType().dyn_cast<Torch::TupleType>()) {
+        Location loc = op.getLoc();
+        for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
+          auto i = rewriter.create<ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(en.index()));
+          newOperands.push_back(
+              rewriter.create<PrimTupleIndexOp>(loc, en.value(), operand, i));
+        }
+        continue;
+      }
+      newOperands.push_back(operand);
     }
     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
     return success();
@@ -168,9 +188,14 @@ static LogicalResult adjustCallingConventions(FuncOp func,
                                               TypeBoundMap &typeBoundMap) {
   MLIRContext *context = func.getContext();
   RewritePatternSet patterns(context);
-  // TODO: TupleTypes
   TypeConverter typeConverter;
   typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion(
+      [](Torch::TupleType type,
+         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
+        llvm::append_range(types, type.getContainedTypes());
+        return success();
+      });
   typeConverter.addConversion(
       [](Torch::NoneType type,
          SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
@@ -220,6 +245,9 @@ static LogicalResult adjustCallingConventions(FuncOp func,
   target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
   target.addLegalOp<TensorStaticInfoCastOp>();
   target.addLegalOp<ConstantNoneOp>();
+  target.addLegalOp<ConstantIntOp>();
+  target.addLegalOp<PrimTupleIndexOp>();
+  target.addLegalOp<PrimTupleConstructOp>();
   // We don't know how to rewrite it, so mark it as illegal.
   target.addIllegalOp<CallIndirectOp>();
   if (failed(applyPartialConversion(func.getOperation(), target,
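In short, the rewritten return handling drops None operands and expands each tuple operand into one constant index plus one PrimTupleIndex per contained element. A hedged Python sketch of that expansion (hypothetical helper, not the actual pass):

def adjust_return_operands(operands):
    new_operands = []
    for operand in operands:
        if operand is None:
            continue  # NoneType returns are dropped entirely
        if isinstance(operand, tuple):
            # One index per contained element, in order, standing in for
            # the ConstantIntOp + PrimTupleIndexOp pair created above.
            new_operands.extend(operand[i] for i in range(len(operand)))
            continue
        new_operands.append(operand)
    return new_operands

assert adjust_return_operands([None, (1, 2, 3), 4]) == [1, 2, 3, 4]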


@@ -89,6 +89,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
   pm.addPass(createAdjustCallingConventionsPass());
   if (options.optimize) {
+    // Eliminate the PrimTupleIndexOps generated by
+    // AdjustCallingConventions.
+    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
     // Inline global slots, which for most inference scenarios deletes them.
     // This also exposes more information to intraprocedural transformations
     // below like MaximizeValueSemantics and RefineTypes.


@@ -21,7 +21,9 @@
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "set"
 #include "torch-mlir/RefBackend/Passes.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -67,33 +69,47 @@ static Type getAbiTypeForMemRef(Type type) {
   return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
 }
 
-// Passes the return op operands `val` to `funOp`. Also, adds the op to the
-// `toErase` vector.
-static void replaceCallToFunction(OpBuilder b, ReturnOp op, FuncOp funcOp,
-                                  Value val,
+// Helper function to get the type string for one return value like i32, f64,
+// mri32 etc. The strings from multiple return values are concatenated to get
+// the consumeFuncReturnFunc name.
+static std::string getTypeToken(Type type) {
+  if (type.isSignlessInteger())
+    return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
+  else if (type.isa<mlir::FloatType>())
+    return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
+  else if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
+    return "mr" + getTypeToken(memRefType.getElementType());
+  llvm_unreachable(
+      "Type token should handle all types: memref, float and int type");
+}
+
+// Systematically derive the consumeFuncReturnFunc name from return value types.
+static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) {
+  SmallVector<std::string> tokens = {"refbackend_consume_func_return"};
+  for (auto type : types)
+    tokens.push_back(getTypeToken(type));
+  return std::accumulate(tokens.begin(), tokens.end(), std::string(),
+                         [](std::string &a, std::string &b) {
+                           return a.empty() ? b : (a + "_" + b);
+                         });
+}
+
+// Replace the original returnOp with a call to consumeFuncReturnFunc and add
+// the op to the `toErase` vector.
+static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName,
+                                  TypeRange retTypes,
+                                  SmallVectorImpl<Value> &vals,
                                   SmallVectorImpl<Operation *> &toErase) {
-  b.create<mlir::CallOp>(op.getLoc(), funcOp, val);
+  b.create<mlir::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
   b.create<mlir::ReturnOp>(op.getLoc());
   toErase.push_back(op);
 }
 
-// Checks whether the return op is munge-compatible and the respective calling
-// function is defined.
-static bool isReturnOpCompatible(ReturnOp op,
-                                 DenseMap<Type, FuncOp> &consumeFuncReturnFuncs,
-                                 Type returnType) {
-  auto it = consumeFuncReturnFuncs.find(returnType);
-  if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
-    op.emitError("must have one return value of Memref type or Elemental types "
-                 "of i64, f64, f32");
-    return false;
-  }
-  return true;
-}
-
 static LogicalResult mungeFunction(
-    FuncOp func,
-    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
+    FuncOp func, std::set<std::string> &supportedConsumeFuncReturnFuncs,
+    std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
   // Add `llvm.emit_c_interface`.
   // This allows ExecutionEngine to resolve the symbol properly.
   addEmitCInterfaceAttr(func);
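A hedged Python sketch of the name-mangling scheme introduced above (the helper names here are illustrative; the C++ in this hunk is authoritative):

def type_token(ty: str) -> str:
    # `ty` is a simplified type string: "i32"/"i64"/"f32"/"f64" for scalars,
    # "memref<f32>" for an unranked memref of f32.
    if ty.startswith("memref<") and ty.endswith(">"):
        return "mr" + type_token(ty[len("memref<"):-1])
    return ty

def consume_func_name(return_types) -> str:
    tokens = ["refbackend_consume_func_return"]
    tokens += [type_token(t) for t in return_types]
    return "_".join(tokens)

# Three unranked f32 memref results map to the callback registered below:
assert (consume_func_name(["memref<f32>"] * 3)
        == "refbackend_consume_func_return_mrf32_mrf32_mrf32")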
@@ -120,37 +136,43 @@ static LogicalResult mungeFunction(
   }
 
   SmallVector<Operation *> toErase;
-  bool isCompatible = false;
+  bool isSupported = true;
   func.walk([&](ReturnOp op) {
-    auto returnType = op.getOperandTypes()[0];
+    auto types = op.getOperandTypes();
     b.setInsertionPoint(op);
     // Memref Types.
-    if (auto memrefReturnType = returnType.dyn_cast<MemRefType>()) {
-      auto elemType = memrefReturnType.getElementType();
-      auto unRankedType = UnrankedMemRefType::get(elemType, 0);
-      isCompatible =
-          isReturnOpCompatible(op, consumeFuncReturnFuncs, unRankedType);
-      if (!isCompatible)
-        return;
-
-      // Cast to unranked memref type before sending it as a function argument.
-      auto cast = b.create<memref::CastOp>(
-          op.getLoc(), op.getOperand(0),
-          getAbiTypeForMemRef(op.getOperandTypes()[0]));
-      replaceCallToFunction(b, op, consumeFuncReturnFuncs[unRankedType],
-                            cast.getResult(), toErase);
-      // Elemental types.
-    } else if (returnType.isa<IntegerType>() || returnType.isa<FloatType>()) {
-      isCompatible =
-          isReturnOpCompatible(op, consumeFuncReturnFuncs, returnType);
-      if (!isCompatible)
-        return;
-      replaceCallToFunction(b, op, consumeFuncReturnFuncs[returnType],
-                            op->getOperand(0), toErase);
+    std::vector<Type> retTypes;
+    SmallVector<Value> retVals;
+    for (auto en : llvm::enumerate(types)) {
+      Type retType = en.value();
+      Value retVal = op.getOperand(en.index());
+      if (auto memrefReturnType = retType.dyn_cast<MemRefType>()) {
+        auto elemType = memrefReturnType.getElementType();
+        retType = UnrankedMemRefType::get(elemType, 0);
+        // Cast to unranked memref type before sending it as a function
+        // argument.
+        retVal = b.create<memref::CastOp>(
+            op.getLoc(), retVal, getAbiTypeForMemRef(types[en.index()]));
+      }
+      retTypes.push_back(retType);
+      retVals.push_back(retVal);
     }
+    auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end();
+    std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
+    if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
+      op.emitError(
+          "must have one return value of memref types or scalar types "
+          "of i32, i64, f32, f64 or three return values of memref f32");
+      isSupported = false;
+    }
+    auto invokedFuncsEnd = invokedConsumeFuncReturnFuncs.end();
+    if (invokedConsumeFuncReturnFuncs.find(funcName) == invokedFuncsEnd)
+      invokedConsumeFuncReturnFuncs.insert({funcName, retTypes});
+    replaceReturnWithCall(b, op, funcName, retTypes, retVals, toErase);
   });
-  if (!isCompatible)
+  if (!isSupported)
     return failure();
 
   func.setType(FunctionType::get(func.getContext(), newArgTypes, {}));
   for (Operation *op : toErase)
@@ -158,50 +180,47 @@ static LogicalResult mungeFunction(
   return success();
 }
 
+static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
+  std::set<std::string> funcNames;
+  Type mri32 = UnrankedMemRefType::get(b.getI32Type(), 0);
+  Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
+  Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
+  Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
+  Type i64 = b.getI64Type();
+  Type f32 = b.getF32Type();
+  Type f64 = b.getF64Type();
+  SmallVector<TypeRange> supportedReturnTypes = {
+      mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};
+  llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
+    funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
+  });
+  return funcNames;
+}
+
 namespace {
 class MungeCallingConventions
     : public MungeCallingConventionsBase<MungeCallingConventions> {
   void runOnOperation() override {
     auto module = getOperation();
     OpBuilder b(module.getBodyRegion());
-    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
-    DenseSet<FuncOp> consumeFuncReturnFuncsSet;
-    auto createConsumeFuncReturnFunc = [&](Type returnType,
-                                           std::string funcName) {
-      auto consumeFuncReturnFunc = b.create<FuncOp>(
-          module.getLoc(), funcName,
-          FunctionType::get(module.getContext(), returnType, {}),
-          b.getStringAttr("private"));
-      addEmitCInterfaceAttr(consumeFuncReturnFunc);
-      consumeFuncReturnFuncs[returnType] = consumeFuncReturnFunc;
-      consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc);
-    };
-    // Memref return types.
-    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0),
-                                "refbackend_consume_memref_int32_func_return");
-    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
-                                "refbackend_consume_memref_int64_func_return");
-    createConsumeFuncReturnFunc(
-        UnrankedMemRefType::get(b.getF32Type(), 0),
-        "refbackend_consume_memref_float32_func_return");
-    createConsumeFuncReturnFunc(
-        UnrankedMemRefType::get(b.getF64Type(), 0),
-        "refbackend_consume_memref_float64_func_return");
-    // Elemental return types.
-    createConsumeFuncReturnFunc(b.getI64Type(),
-                                "refbackend_consume_int64_func_return");
-    createConsumeFuncReturnFunc(b.getF32Type(),
-                                "refbackend_consume_float32_func_return");
-    createConsumeFuncReturnFunc(b.getF64Type(),
-                                "refbackend_consume_float64_func_return");
+    static std::set<std::string> supported =
+        getSupportedConsumeFuncReturnFuncs(b);
+    std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
     for (auto func : module.getOps<FuncOp>()) {
-      if (consumeFuncReturnFuncsSet.contains(func))
-        continue;
-      if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
+      if (failed(mungeFunction(func, supported, invokedConsumeFuncReturnFuncs)))
         return signalPassFailure();
     }
+    // Create FuncOp for consumeFuncReturnFuncs that are used.
+    for (auto &p : invokedConsumeFuncReturnFuncs) {
+      auto consumeFuncReturnFunc = b.create<FuncOp>(
+          module.getLoc(), p.first,
+          FunctionType::get(module.getContext(), p.second, {}),
+          b.getStringAttr("private"));
+      addEmitCInterfaceAttr(consumeFuncReturnFunc);
+    }
   }
 };
 } // namespace


@@ -399,7 +399,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
         emit_op(registry[key], f, **kwargs)
 
     emit("prim::layout : (Tensor) -> (int)")
-    emit("prim::TupleIndex : (Any, int) -> (Any)")
+    emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True)
     emit("prim::device : (Tensor) -> (Device)")
     emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
    emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)


@@ -27,59 +27,73 @@ def checkArgTypeIsSupported(ty):
     SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
     assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
 
 class RefBackendInvoker:
     def __init__(self, module):
         self.ee = ExecutionEngine(module)
         self.result = None
 
         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_i32_return(a):
+        def consume_return_mri32(a):
             self.result = unranked_memref_to_numpy(a, np.int32)
 
         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_i64_return(a):
+        def consume_return_mri64(a):
             self.result = unranked_memref_to_numpy(a, np.int64)
 
         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_f32_return(a):
+        def consume_return_mrf32(a):
             self.result = unranked_memref_to_numpy(a, np.float32)
 
         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
-        def consume_memref_f64_return(a):
+        def consume_return_mrf64(a):
             self.result = unranked_memref_to_numpy(a, np.float64)
 
         @ctypes.CFUNCTYPE(None, ctypes.c_int)
-        def consume_i64_return(a):
+        def consume_return_i64(a):
             self.result = a
 
         @ctypes.CFUNCTYPE(None, ctypes.c_float)
-        def consume_f32_return(a):
+        def consume_return_f32(a):
             self.result = a
 
         @ctypes.CFUNCTYPE(None, ctypes.c_double)
-        def consume_f64_return(a):
+        def consume_return_f64(a):
             self.result = a
 
+        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
+                          ctypes.POINTER(UnrankedMemRefDescriptor),
+                          ctypes.POINTER(UnrankedMemRefDescriptor))
+        def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2):
+            self.result = (unranked_memref_to_numpy(arg0, np.float32),
+                           unranked_memref_to_numpy(arg1, np.float32),
+                           unranked_memref_to_numpy(arg2, np.float32))
+
-        self.ee.register_runtime("refbackend_consume_memref_int32_func_return",
-                                 consume_memref_i32_return)
-        self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
-                                 consume_memref_i64_return)
-        self.ee.register_runtime("refbackend_consume_memref_float32_func_return",
-                                 consume_memref_f32_return)
-        self.ee.register_runtime("refbackend_consume_memref_float64_func_return",
-                                 consume_memref_f64_return)
-        self.ee.register_runtime("refbackend_consume_int64_func_return",
-                                 consume_i64_return)
-        self.ee.register_runtime("refbackend_consume_float32_func_return",
-                                 consume_f32_return)
-        self.ee.register_runtime("refbackend_consume_float64_func_return",
-                                 consume_f64_return)
+        self.ee.register_runtime("refbackend_consume_func_return_mri32",
+                                 consume_return_mri32)
+        self.ee.register_runtime("refbackend_consume_func_return_mri64",
+                                 consume_return_mri64)
+        self.ee.register_runtime("refbackend_consume_func_return_mrf32",
+                                 consume_return_mrf32)
+        self.ee.register_runtime("refbackend_consume_func_return_mrf64",
+                                 consume_return_mrf64)
+        self.ee.register_runtime("refbackend_consume_func_return_i64",
+                                 consume_return_i64)
+        self.ee.register_runtime("refbackend_consume_func_return_f32",
+                                 consume_return_f32)
+        self.ee.register_runtime("refbackend_consume_func_return_f64",
+                                 consume_return_f64)
+        self.ee.register_runtime(
+            "refbackend_consume_func_return_mrf32_mrf32_mrf32",
+            consume_return_mrf32_mrf32_mrf32)
 
     def __getattr__(self, function_name: str):
         def invoke(*args):
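The multi-result path hinges on registering a single C-callable that receives all return values in one invocation. A self-contained sketch of that ctypes pattern (standalone and illustrative; independent of the ExecutionEngine):

import ctypes

result = None

# One callback receives every return value at once, mirroring
# consume_return_mrf32_mrf32_mrf32 above.
@ctypes.CFUNCTYPE(None, ctypes.c_double, ctypes.c_double, ctypes.c_double)
def consume_three_doubles(a, b, c):
    global result
    result = (a, b, c)

# Simulate the runtime calling back with three return values.
consume_three_doubles(1.0, 2.0, 3.0)
assert result == (1.0, 2.0, 3.0)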


@@ -47,3 +47,53 @@ func @none_call_return() {
   "test.use"(%0) : (!torch.none) -> ()
   return
 }
+
+// CHECK-LABEL:   func @tuple_return(
+// CHECK-SAME:        %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK-SAME:        %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK:           %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK:           %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK:           %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK:           %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK:           %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] :
+// CHECK-SAME:          !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
+// CHECK:           %[[CST0:.*]] = torch.constant.int 0
+// CHECK:           %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
+// CHECK-SAME:          !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:           %[[CST1:.*]] = torch.constant.int 1
+// CHECK:           %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
+// CHECK-SAME:          !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:           return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                   %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
+  %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
+  return %1 : !torch.tuple<!torch.tensor, !torch.tensor>
+}
+
+// CHECK-LABEL:   func @call_tuple_return(
+// CHECK-SAME:        %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
+// CHECK-SAME:        %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
+// CHECK:           %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK:           %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
+// CHECK:           %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
+// CHECK:           %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
+// CHECK:           %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK:           %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK:           %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
+// CHECK:           %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
+// CHECK:           %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) :
+// CHECK-SAME:          (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
+// CHECK:           %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 :
+// CHECK-SAME:          !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
+// CHECK:           %[[CST0:.*]] = torch.constant.int 0
+// CHECK:           %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
+// CHECK-SAME:          !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:           %[[CST1:.*]] = torch.constant.int 1
+// CHECK:           %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
+// CHECK-SAME:          !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:           return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
+func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
+                        %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
+  %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor>
+  return %0 : !torch.tuple<!torch.tensor, !torch.tensor>
+}


@@ -568,3 +568,29 @@ func @torch.tensor_static_info_cast$upcast_first(%t: !torch.tensor<[?,?],f64>) -
   %downcast = torch.tensor_static_info_cast %upcast : !torch.tensor to !torch.tensor<[?,?],f64>
   return %downcast: !torch.tensor<[?,?],f64>
 }
+
+// CHECK-LABEL:   func @torch.prim.TupleIndex(
+// CHECK-SAME:        %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
+// CHECK:           return %[[T1]] : !torch.tensor
+func @torch.prim.TupleIndex(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
+  %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
+  %int1 = torch.constant.int 1
+  %1 = torch.prim.TupleIndex %0, %int1 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+  return %1 : !torch.tensor
+}
+
+// CHECK-LABEL:   func @torch.prim.TupleIndex$out_of_bound(
+// CHECK-SAME:        %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
+// CHECK:           %[[INDEX3:.*]] = torch.constant.int 3
+// CHECK:           %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]], %[[T2]] :
+// CHECK-SAME:          !torch.tensor, !torch.tensor, !torch.tensor ->
+// CHECK-SAME:          !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
+// CHECK:           %[[RET:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INDEX3]] :
+// CHECK-SAME:          !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+// CHECK:           return %[[RET]] : !torch.tensor
+func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
+  %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
+  %int3 = torch.constant.int 3
+  %1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
+  return %1 : !torch.tensor
+}


@@ -23,8 +23,6 @@ func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor
   return %2 : !torch.tensor
 }
 
-// -----
-// Call to public function.


@@ -4,7 +4,7 @@
 // CHECK-SAME:    %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
 // CHECK:         %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
 // CHECK:         %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
-// CHECK:         call @refbackend_consume_memref_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
+// CHECK:         call @refbackend_consume_func_return_mrf32(%[[RESULT]]) : (memref<*xf32>) -> ()
 // CHECK:         return
 func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
   return %arg0 : memref<?xf32>
@@ -16,7 +16,7 @@ func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
 // CHECK-SAME:    %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
 // CHECK:         %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
 // CHECK:         %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
-// CHECK:         call @refbackend_consume_memref_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
+// CHECK:         call @refbackend_consume_func_return_mri64(%[[RESULT]]) : (memref<*xi64>) -> ()
 // CHECK:         return
 func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
   return %arg0 : memref<?xi64>
@@ -28,9 +28,28 @@ func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
 // CHECK-SAME:    %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
 // CHECK:         %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<i64>
 // CHECK:         %[[RESULT:.*]] = memref.load %[[VAL]][] : memref<i64>
-// CHECK:         call @refbackend_consume_int64_func_return(%[[RESULT]]) : (i64) -> ()
+// CHECK:         call @refbackend_consume_func_return_i64(%[[RESULT]]) : (i64) -> ()
 // CHECK:         return
 func @elemental_type(%arg0: memref<i64>) -> i64 {
   %0 = memref.load %arg0[] : memref<i64>
   return %0 : i64
 }
+
+// -----
+
+// CHECK-LABEL:   func @multiple_return_values(
+// CHECK-SAME:        %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xf32>,
+// CHECK-SAME:        %[[ARG2:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
+// CHECK:           %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
+// CHECK:           %[[VAL1:.*]] = memref.cast %[[ARG1]] : memref<*xf32> to memref<?xf32>
+// CHECK:           %[[VAL2:.*]] = memref.cast %[[ARG2]] : memref<*xf32> to memref<?xf32>
+// CHECK:           %[[RET0:.*]] = memref.cast %[[VAL0]] : memref<?xf32> to memref<*xf32>
+// CHECK:           %[[RET1:.*]] = memref.cast %[[VAL1]] : memref<?xf32> to memref<*xf32>
+// CHECK:           %[[RET2:.*]] = memref.cast %[[VAL2]] : memref<?xf32> to memref<*xf32>
+// CHECK:           call @refbackend_consume_func_return_mrf32_mrf32_mrf32(%[[RET0]], %[[RET1]], %[[RET2]])
+// CHECK-SAME:          : (memref<*xf32>, memref<*xf32>, memref<*xf32>) -> ()
+// CHECK:           return
+func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
+  return %arg0, %arg1, %arg2 : memref<?xf32>, memref<?xf32>, memref<?xf32>
+}