mirror of https://github.com/llvm/torch-mlir
Add support for multiple return values
This change is to unblock the work on some backprop ops that return more than one tensor. We will need to think of a more scalable approach in the future if more flexible return type combinations are needed.
parent 6e8d39642e · commit 0fe70994e5
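For context, the kind of program this change unblocks is one whose top-level function returns more than one tensor, as backprop ops do. The following is an illustrative PyTorch sketch only (the module and names are hypothetical, not part of the change):

import torch

class ExampleBackward(torch.nn.Module):
    # A stand-in for a backprop-style op that produces two gradient tensors.
    def forward(self, grad_out, x):
        return grad_out * x, grad_out.sum(dim=0)

scripted = torch.jit.script(ExampleBackward())
grad_input, grad_bias = scripted(torch.randn(3, 4), torch.randn(3, 4))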
@@ -11,11 +11,11 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

class SoftmaxBackwardModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,

@@ -33,6 +33,8 @@ class SoftmaxBackwardModule(torch.nn.Module):
def SoftmaxBackwardModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))


# ==============================================================================
class TanhBackwardModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -43,10 +45,11 @@ class TanhBackwardModule(torch.nn.Module):
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, grad_out, output):
        return torch.ops.aten.tanh_backward(grad_out, output)

    def forward(self, out_grad, output):
        return torch.ops.aten.tanh_backward(out_grad, output)

@register_test_case(module_factory=lambda: TanhBackwardModule())
def TanhBackward_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 3), torch.randn(3, 3))
@@ -543,7 +543,7 @@ class TensorToInt(torch.nn.Module):
@register_test_case(module_factory=lambda: TensorToInt())
def TensorToInt_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, []))


class LogSoftmaxIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -577,3 +577,24 @@ class NumToTensorModule(torch.nn.Module):
@register_test_case(module_factory=lambda: NumToTensorModule())
def NumToTensorModule_basic(module, tu: TestUtils):
    module.forward()


# This test can be removed once we have one real op returning 3 float32 tensors
class ReturnThreeTensorFloat32(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a, b, c):
        return a, b, c

@register_test_case(module_factory=lambda: ReturnThreeTensorFloat32())
def ReturnThreeTensorFloat32_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3), tu.rand(2, 3), tu.rand(2, 3))
@@ -29,4 +29,5 @@ TOSA_PASS_SET = {
    "ElementwiseFloorModule_basic",
    "ElementwiseLogModule_basic",
    "TanhBackward_basic",
    "ReturnThreeTensorFloat32_basic",
}
@@ -43,6 +43,7 @@ def Torch_PrimTupleIndexOp : Torch_Op<"prim.TupleIndex", [
    AnyTorchType:$result
  );
  let assemblyFormat = "$tup `,` $i attr-dict `:` type($tup) `,` type($i) `->` type($result)";
  let hasCanonicalizer = 1;
}

def Torch_PrimDeviceOp : Torch_Op<"prim.device", [
@@ -121,8 +121,8 @@ def AdjustCallingConventions
    function arguments, which should be `!numpy.ndarray<...>`'s.
    - Python-isms are rewritten to MLIR-isms
      - NoneType return is rewritten to the absence of a return value.
      - (Not implemented yet) Tuple return is rewritten to multiple return
        values
      - Tuple return is rewritten to multiple return values.
  }];
}
@@ -219,14 +219,14 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "FuncOp"> {
  let summary = "Decompose complicated torch operations";
  let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
  let description = [{
    Decompose torch operations that are losslessly represented as combinations of
    other operations, modulo appropriate compiler fusion. Note that this pass
    is similar in spirit to ReduceOpVariants, but ReduceOpVariants is about
    systematic reductions of a large number of ops at once, guided mostly by
    Decompose torch operations that are losslessly represented as combinations of
    other operations, modulo appropriate compiler fusion. Note that this pass
    is similar in spirit to ReduceOpVariants, but ReduceOpVariants is about
    systematic reductions of a large number of ops at once, guided mostly by
    traits.

    An example of the transformations done in this pass is:
    - convert aten.softmax to softmax(x, dim)
    - convert aten.softmax to softmax(x, dim)
        => tmp=exp(x); tmp / sum(tmp, dim, keepdim=True)
  }];
}
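The softmax decomposition the pass description refers to can be checked numerically in plain PyTorch. This is a small illustrative sketch of the rewrite's semantics, not the pass implementation:

import torch

x = torch.randn(3, 4)
dim = 1

# The decomposition: softmax(x, dim) => tmp = exp(x); tmp / sum(tmp, dim, keepdim=True)
tmp = torch.exp(x)
decomposed = tmp / torch.sum(tmp, dim, keepdim=True)

assert torch.allclose(torch.softmax(x, dim), decomposed, atol=1e-6)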
@@ -916,6 +916,29 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
  });
}

//===----------------------------------------------------------------------===//
// PrimTupleIndexOp
//===----------------------------------------------------------------------===//

void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
    if (!tupleConstruct)
      return failure();

    int64_t i;
    if (!matchPattern(op.i(), m_TorchConstantInt(&i)))
      return failure();

    if (i >= (int64_t)tupleConstruct.elements().size())
      return failure();

    rewriter.replaceOp(op, tupleConstruct.elements()[i]);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// PrimTupleUnpackOp
//===----------------------------------------------------------------------===//

@@ -923,9 +946,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
    auto torchTuple = op.tup();
    auto tupleConstruct =
        torchTuple.getDefiningOp<Torch::PrimTupleConstructOp>();
    auto tupleConstruct = op.tup().getDefiningOp<Torch::PrimTupleConstructOp>();
    if (!tupleConstruct)
      return failure();
@@ -66,15 +66,20 @@ public:
      // TODO: add tuple type.
      conversion.addInputs(type.index(), type.value());
    }
    rewriter.applySignatureConversion(&func.getBody(), conversion,
                                      typeConverter);

    SmallVector<Type> newResultTypes;
    for (auto type : func.getType().getResults()) {
      if (auto none = type.dyn_cast<Torch::NoneType>()) {
        continue;
      }
      if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
        llvm::append_range(newResultTypes, tuple.getContainedTypes());
        continue;
      }
      newResultTypes.push_back(type);
    }
    rewriter.applySignatureConversion(&func.getBody(), conversion,
                                      typeConverter);
    rewriter.updateRootInPlace(func, [&] {
      func.setType(FunctionType::get(
          getContext(), conversion.getConvertedTypes(), newResultTypes));

@@ -131,6 +136,11 @@ public:
          rewriter.create<ConstantNoneOp>(call.getLoc(), type));
      continue;
    }
    if (type.isa<Torch::TupleType>()) {
      newResults.push_back(rewriter.create<PrimTupleConstructOp>(
          call.getLoc(), type, newCall.getResults()));
      continue;
    }
    newResults.push_back(newCall.getResult(newOpResultIdx++));
  }
  rewriter.replaceOp(call, newResults);

@@ -151,12 +161,22 @@ public:
                  ConversionPatternRewriter &rewriter) const override {

    SmallVector<Value> newOperands;
    for (auto operand : llvm::enumerate(adaptor.getOperands())) {
      if (!operand.value())
    for (auto operand : adaptor.getOperands()) {
      if (!operand)
        continue;
      if (operand.value().getType().isa<Torch::NoneType>())
      if (operand.getType().isa<Torch::NoneType>())
        continue;
      newOperands.push_back(operand.value());
      if (auto tuple = operand.getType().dyn_cast<Torch::TupleType>()) {
        Location loc = op.getLoc();
        for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
          auto i = rewriter.create<ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(en.index()));
          newOperands.push_back(
              rewriter.create<PrimTupleIndexOp>(loc, en.value(), operand, i));
        }
        continue;
      }
      newOperands.push_back(operand);
    }
    rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
    return success();

@@ -168,9 +188,14 @@ static LogicalResult adjustCallingConventions(FuncOp func,
                                              TypeBoundMap &typeBoundMap) {
  MLIRContext *context = func.getContext();
  RewritePatternSet patterns(context);
  // TODO: TupleTypes
  TypeConverter typeConverter;
  typeConverter.addConversion([](Type type) { return type; });
  typeConverter.addConversion(
      [](Torch::TupleType type,
         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {
        llvm::append_range(types, type.getContainedTypes());
        return success();
      });
  typeConverter.addConversion(
      [](Torch::NoneType type,
         SmallVectorImpl<Type> &types) -> Optional<LogicalResult> {

@@ -220,6 +245,9 @@ static LogicalResult adjustCallingConventions(FuncOp func,
  target.addLegalOp<CopyToNonValueTensorOp, CopyToValueTensorOp>();
  target.addLegalOp<TensorStaticInfoCastOp>();
  target.addLegalOp<ConstantNoneOp>();
  target.addLegalOp<ConstantIntOp>();
  target.addLegalOp<PrimTupleIndexOp>();
  target.addLegalOp<PrimTupleConstructOp>();
  // We don't know how to rewrite it, so mark it as illegal.
  target.addIllegalOp<CallIndirectOp>();
  if (failed(applyPartialConversion(func.getOperation(), target,
@@ -89,6 +89,9 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline(
  pm.addPass(createAdjustCallingConventionsPass());

  if (options.optimize) {
    // Eliminate the PrimTupleIndexOp generated from the
    // adjustCallingConventions.
    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
    // Inline global slots, which for most inference scenarios deletes them.
    // This also exposes more information to intraprocedural transformations
    // below like MaximizeValueSemantics and RefineTypes.
@@ -21,7 +21,9 @@
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "set"
#include "torch-mlir/RefBackend/Passes.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
@@ -67,33 +69,47 @@ static Type getAbiTypeForMemRef(Type type) {
  return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(), 0);
}

// Passes the return op operand `val` to `funcOp`. Also adds the op to the
// `toErase` vector.
static void replaceCallToFunction(OpBuilder b, ReturnOp op, FuncOp funcOp,
                                  Value val,
// Helper function to get the type string for one return value like i32, f64,
// mri32 etc. The strings from multiple return values are concatenated to get
// the consumeFuncReturnFunc name.
static std::string getTypeToken(Type type) {
  if (type.isSignlessInteger())
    return ("i" + Twine(type.getIntOrFloatBitWidth())).str();
  else if (type.isa<mlir::FloatType>())
    return ("f" + Twine(type.getIntOrFloatBitWidth())).str();
  else if (auto memRefType = type.dyn_cast<UnrankedMemRefType>())
    return "mr" + getTypeToken(memRefType.getElementType());

  llvm_unreachable(
      "Type token should handle all types: memref, float and int type");
}

// Systematically derive the consumeFuncReturnFunc name from return value types.
static std::string getConsumeReturnFunctionNameForReturnTypes(TypeRange types) {
  SmallVector<std::string> tokens = {"refbackend_consume_func_return"};
  for (auto type : types)
    tokens.push_back(getTypeToken(type));

  return std::accumulate(tokens.begin(), tokens.end(), std::string(),
                         [](std::string &a, std::string &b) {
                           return a.empty() ? b : (a + "_" + b);
                         });
}

// Replace the original returnOp with a call to consumeFuncReturnFunc and add
// the op to the `toErase` vector.
static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName,
                                  TypeRange retTypes,
                                  SmallVectorImpl<Value> &vals,
                                  SmallVectorImpl<Operation *> &toErase) {
  b.create<mlir::CallOp>(op.getLoc(), funcOp, val);
  b.create<mlir::CallOp>(op.getLoc(), funcName, TypeRange({}), vals);
  b.create<mlir::ReturnOp>(op.getLoc());
  toErase.push_back(op);
}

// Checks whether the return op is munge-compatible and the respective calling
// function is defined.
static bool isReturnOpCompatible(ReturnOp op,
                                 DenseMap<Type, FuncOp> &consumeFuncReturnFuncs,
                                 Type returnType) {
  auto it = consumeFuncReturnFuncs.find(returnType);
  if (op.getNumOperands() != 1 || it == consumeFuncReturnFuncs.end()) {
    op.emitError("must have one return value of Memref type or Elemental types "
                 "of i64, f64, f32");
    return false;
  }
  return true;
}

static LogicalResult mungeFunction(
    FuncOp func,
    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs) {
    FuncOp func, std::set<std::string> &supportedConsumeFuncReturnFuncs,
    std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
  // Add `llvm.emit_c_interface`.
  // This allows ExecutionEngine to resolve the symbol properly.
  addEmitCInterfaceAttr(func);
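As an illustration of the naming scheme implemented by getTypeToken and getConsumeReturnFunctionNameForReturnTypes above, here is a small Python sketch that derives the same strings; the helper names and the string-based type shorthand are hypothetical, used only to show how the tokens are concatenated:

def type_token(ty: str) -> str:
    # ty is a shorthand for an MLIR type: "i64", "f32", or "memref<*xf32>".
    if ty.startswith("memref"):
        return "mr" + type_token(ty.split("x")[-1].rstrip(">"))
    return ty  # already a scalar token like "i32", "i64", "f32", "f64"

def consume_func_name(return_types):
    return "_".join(["refbackend_consume_func_return"] +
                    [type_token(t) for t in return_types])

# One unranked f32 memref:
assert consume_func_name(["memref<*xf32>"]) == "refbackend_consume_func_return_mrf32"
# Three f32 memrefs, as exercised by the multi-return test in this change:
assert consume_func_name(["memref<*xf32>"] * 3) == \
    "refbackend_consume_func_return_mrf32_mrf32_mrf32"
# A scalar i64:
assert consume_func_name(["i64"]) == "refbackend_consume_func_return_i64"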
@@ -120,37 +136,43 @@ static LogicalResult mungeFunction(
  }

  SmallVector<Operation *> toErase;
  bool isCompatible = false;
  bool isSupported = true;
  func.walk([&](ReturnOp op) {
    auto returnType = op.getOperandTypes()[0];

    auto types = op.getOperandTypes();
    b.setInsertionPoint(op);
    // Memref Types.
    if (auto memrefReturnType = returnType.dyn_cast<MemRefType>()) {
      auto elemType = memrefReturnType.getElementType();
      auto unRankedType = UnrankedMemRefType::get(elemType, 0);
      isCompatible =
          isReturnOpCompatible(op, consumeFuncReturnFuncs, unRankedType);
      if (!isCompatible)
        return;

      // Cast to unranked memref type before sending it as a function argument.
      auto cast = b.create<memref::CastOp>(
          op.getLoc(), op.getOperand(0),
          getAbiTypeForMemRef(op.getOperandTypes()[0]));
      replaceCallToFunction(b, op, consumeFuncReturnFuncs[unRankedType],
                            cast.getResult(), toErase);
    // Elemental types.
    } else if (returnType.isa<IntegerType>() || returnType.isa<FloatType>()) {
      isCompatible =
          isReturnOpCompatible(op, consumeFuncReturnFuncs, returnType);
      if (!isCompatible)
        return;
      replaceCallToFunction(b, op, consumeFuncReturnFuncs[returnType],
                            op->getOperand(0), toErase);
    std::vector<Type> retTypes;
    SmallVector<Value> retVals;
    for (auto en : llvm::enumerate(types)) {
      Type retType = en.value();
      Value retVal = op.getOperand(en.index());
      if (auto memrefReturnType = retType.dyn_cast<MemRefType>()) {
        auto elemType = memrefReturnType.getElementType();
        retType = UnrankedMemRefType::get(elemType, 0);
        // Cast to unranked memref type before sending it as a function
        // argument.
        retVal = b.create<memref::CastOp>(
            op.getLoc(), retVal, getAbiTypeForMemRef(types[en.index()]));
      }
      retTypes.push_back(retType);
      retVals.push_back(retVal);
    }

    auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end();
    std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
    if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
      op.emitError(
          "must have one return value of memref types or scalar types "
          "of i32, i64, f32, f64 or three return values of memref f32");
      isSupported = false;
    }

    auto invokedFuncsEnd = invokedConsumeFuncReturnFuncs.end();
    if (invokedConsumeFuncReturnFuncs.find(funcName) == invokedFuncsEnd)
      invokedConsumeFuncReturnFuncs.insert({funcName, retTypes});
    replaceReturnWithCall(b, op, funcName, retTypes, retVals, toErase);
  });
  if (!isCompatible)
  if (!isSupported)
    return failure();
  func.setType(FunctionType::get(func.getContext(), newArgTypes, {}));
  for (Operation *op : toErase)
@@ -158,50 +180,47 @@ static LogicalResult mungeFunction(
  return success();
}

static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
  std::set<std::string> funcNames;
  Type mri32 = UnrankedMemRefType::get(b.getI32Type(), 0);
  Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
  Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
  Type mrf64 = UnrankedMemRefType::get(b.getF64Type(), 0);
  Type i64 = b.getI64Type();
  Type f32 = b.getF32Type();
  Type f64 = b.getF64Type();

  SmallVector<TypeRange> supportedReturnTypes = {
      mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};

  llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
    funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
  });
  return funcNames;
}

namespace {
class MungeCallingConventions
    : public MungeCallingConventionsBase<MungeCallingConventions> {
  void runOnOperation() override {
    auto module = getOperation();
    OpBuilder b(module.getBodyRegion());
    DenseMap</*returnElementType*/ Type, FuncOp> consumeFuncReturnFuncs;
    DenseSet<FuncOp> consumeFuncReturnFuncsSet;
    auto createConsumeFuncReturnFunc = [&](Type returnType,
                                           std::string funcName) {
      auto consumeFuncReturnFunc = b.create<FuncOp>(
          module.getLoc(), funcName,
          FunctionType::get(module.getContext(), returnType, {}),
          b.getStringAttr("private"));
      addEmitCInterfaceAttr(consumeFuncReturnFunc);
      consumeFuncReturnFuncs[returnType] = consumeFuncReturnFunc;
      consumeFuncReturnFuncsSet.insert(consumeFuncReturnFunc);
    };

    // Memref return types.
    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI32Type(), 0),
                                "refbackend_consume_memref_int32_func_return");
    createConsumeFuncReturnFunc(UnrankedMemRefType::get(b.getI64Type(), 0),
                                "refbackend_consume_memref_int64_func_return");
    createConsumeFuncReturnFunc(
        UnrankedMemRefType::get(b.getF32Type(), 0),
        "refbackend_consume_memref_float32_func_return");
    createConsumeFuncReturnFunc(
        UnrankedMemRefType::get(b.getF64Type(), 0),
        "refbackend_consume_memref_float64_func_return");

    // Elemental return types.
    createConsumeFuncReturnFunc(b.getI64Type(),
                                "refbackend_consume_int64_func_return");
    createConsumeFuncReturnFunc(b.getF32Type(),
                                "refbackend_consume_float32_func_return");
    createConsumeFuncReturnFunc(b.getF64Type(),
                                "refbackend_consume_float64_func_return");
    static std::set<std::string> supported =
        getSupportedConsumeFuncReturnFuncs(b);
    std::map<std::string, std::vector<Type>> invokedConsumeFuncReturnFuncs;
    for (auto func : module.getOps<FuncOp>()) {
      if (consumeFuncReturnFuncsSet.contains(func))
        continue;
      if (failed(mungeFunction(func, consumeFuncReturnFuncs)))
      if (failed(mungeFunction(func, supported, invokedConsumeFuncReturnFuncs)))
        return signalPassFailure();
    }

    // Create FuncOp for consumeFuncReturnFuncs that are used.
    for (auto &p : invokedConsumeFuncReturnFuncs) {
      auto consumeFuncReturnFunc =
          b.create<FuncOp>(module.getLoc(), p.first,
                           FunctionType::get(module.getContext(), p.second, {}),
                           b.getStringAttr("private"));
      addEmitCInterfaceAttr(consumeFuncReturnFunc);
    }
  }
};
} // namespace
@@ -399,7 +399,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
        emit_op(registry[key], f, **kwargs)

    emit("prim::layout : (Tensor) -> (int)")
    emit("prim::TupleIndex : (Any, int) -> (Any)")
    emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True)
    emit("prim::device : (Tensor) -> (Device)")
    emit("prim::dtype : (Tensor) -> (int)", has_folder=True)
    emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
@@ -27,59 +27,73 @@ def checkArgTypeIsSupported(ty):
    SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
    assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"


class RefBackendInvoker:
    def __init__(self, module):
        self.ee = ExecutionEngine(module)
        self.result = None

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_memref_i32_return(a):
        def consume_return_mri32(a):
            self.result = unranked_memref_to_numpy(a, np.int32)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_memref_i64_return(a):
        def consume_return_mri64(a):
            self.result = unranked_memref_to_numpy(a, np.int64)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_memref_f32_return(a):
        def consume_return_mrf32(a):
            self.result = unranked_memref_to_numpy(a, np.float32)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_memref_f64_return(a):
        def consume_return_mrf64(a):
            self.result = unranked_memref_to_numpy(a, np.float64)

        @ctypes.CFUNCTYPE(None, ctypes.c_int)
        def consume_i64_return(a):
        def consume_return_i64(a):
            self.result = a

        @ctypes.CFUNCTYPE(None, ctypes.c_float)
        def consume_f32_return(a):
        def consume_return_f32(a):
            self.result = a

        @ctypes.CFUNCTYPE(None, ctypes.c_double)
        def consume_f64_return(a):
        def consume_return_f64(a):
            self.result = a

        self.ee.register_runtime("refbackend_consume_memref_int32_func_return",
                                 consume_memref_i32_return)
        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2):
            self.result = unranked_memref_to_numpy(
                arg0, np.float32), unranked_memref_to_numpy(
                    arg1,
                    np.float32), unranked_memref_to_numpy(arg2, np.float32)

        self.ee.register_runtime("refbackend_consume_memref_int64_func_return",
                                 consume_memref_i64_return)
        self.ee.register_runtime("refbackend_consume_func_return_mri32",
                                 consume_return_mri32)

        self.ee.register_runtime("refbackend_consume_memref_float32_func_return",
                                 consume_memref_f32_return)
        self.ee.register_runtime("refbackend_consume_func_return_mri64",
                                 consume_return_mri64)

        self.ee.register_runtime("refbackend_consume_memref_float64_func_return",
                                 consume_memref_f64_return)
        self.ee.register_runtime("refbackend_consume_func_return_mrf32",
                                 consume_return_mrf32)

        self.ee.register_runtime("refbackend_consume_int64_func_return",
                                 consume_i64_return)
        self.ee.register_runtime("refbackend_consume_func_return_mrf64",
                                 consume_return_mrf64)

        self.ee.register_runtime("refbackend_consume_float32_func_return",
                                 consume_f32_return)
        self.ee.register_runtime("refbackend_consume_func_return_i64",
                                 consume_return_i64)

        self.ee.register_runtime("refbackend_consume_float64_func_return",
                                 consume_f64_return)
        self.ee.register_runtime("refbackend_consume_func_return_f32",
                                 consume_return_f32)

        self.ee.register_runtime("refbackend_consume_func_return_f64",
                                 consume_return_f64)

        self.ee.register_runtime(
            "refbackend_consume_func_return_mrf32_mrf32_mrf32",
            consume_return_mrf32_mrf32_mrf32)

    def __getattr__(self, function_name: str):
        def invoke(*args):
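For readers unfamiliar with the ctypes callback mechanism the invoker relies on above, here is a small self-contained illustration of how a multi-value consume callback captures results on the Python side. It is simplified to plain floats and is not part of the patch:

import ctypes

results = {}

# The invoker registers C-callable callbacks like this one; when the compiled
# function returns, the runtime calls back into Python with each return value.
@ctypes.CFUNCTYPE(None, ctypes.c_float, ctypes.c_float, ctypes.c_float)
def consume_three_floats(a, b, c):
    results["value"] = (a, b, c)

# Simulate the runtime invoking the callback with three return values.
consume_three_floats(1.0, 2.0, 3.0)
assert results["value"] == (1.0, 2.0, 3.0)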
@@ -47,3 +47,53 @@ func @none_call_return() {
  "test.use"(%0) : (!torch.none) -> ()
  return
}

// CHECK-LABEL: func @tuple_return(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] :
// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
                   %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
  %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
  return %1 : !torch.tuple<!torch.tensor, !torch.tensor>
}

// CHECK-LABEL: func @call_tuple_return(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) {
// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor
// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor
// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor
// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
// CHECK: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32>
// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32>
// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) :
// CHECK-SAME: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor)
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 :
// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor>
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor
func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>},
                        %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple<!torch.tensor, !torch.tensor> {
  %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor>
  return %0 : !torch.tuple<!torch.tensor, !torch.tensor>
}
@@ -568,3 +568,29 @@ func @torch.tensor_static_info_cast$upcast_first(%t: !torch.tensor<[?,?],f64>) -
  %downcast = torch.tensor_static_info_cast %upcast : !torch.tensor to !torch.tensor<[?,?],f64>
  return %downcast: !torch.tensor<[?,?],f64>
}

// CHECK-LABEL: func @torch.prim.TupleIndex(
// CHECK-SAME: %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[T1]] : !torch.tensor
func @torch.prim.TupleIndex(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
  %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
  %int1 = torch.constant.int 1
  %1 = torch.prim.TupleIndex %0, %int1 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
  return %1 : !torch.tensor
}

// CHECK-LABEL: func @torch.prim.TupleIndex$out_of_bound(
// CHECK-SAME: %[[T0:.*]]: !torch.tensor, %[[T1:.*]]: !torch.tensor, %[[T2:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[INDEX3:.*]] = torch.constant.int 3
// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]], %[[T2]] :
// CHECK-SAME: !torch.tensor, !torch.tensor, !torch.tensor ->
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
// CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[INDEX3]] :
// CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @torch.prim.TupleIndex$out_of_bound(%t0: !torch.tensor, %t1: !torch.tensor, %t2: !torch.tensor) -> !torch.tensor {
  %0 = torch.prim.TupleConstruct %t0, %t1, %t2 : !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>
  %int3 = torch.constant.int 3
  %1 = torch.prim.TupleIndex %0, %int3 : !torch.tuple<!torch.tensor, !torch.tensor, !torch.tensor>, !torch.int -> !torch.tensor
  return %1 : !torch.tensor
}
@@ -23,8 +23,6 @@ func private @basic_private(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.tensor
  return %2 : !torch.tensor
}



// -----

// Call to public function.
@@ -4,7 +4,7 @@
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xf32> to memref<*xf32>
// CHECK: call @refbackend_consume_memref_float32_func_return(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: call @refbackend_consume_func_return_mrf32(%[[RESULT]]) : (memref<*xf32>) -> ()
// CHECK: return
func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
  return %arg0 : memref<?xf32>

@@ -16,7 +16,7 @@ func @f(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<?xi64>
// CHECK: %[[RESULT:.*]] = memref.cast %[[VAL]] : memref<?xi64> to memref<*xi64>
// CHECK: call @refbackend_consume_memref_int64_func_return(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: call @refbackend_consume_func_return_mri64(%[[RESULT]]) : (memref<*xi64>) -> ()
// CHECK: return
func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
  return %arg0 : memref<?xi64>

@@ -28,9 +28,28 @@ func @i(%arg0: memref<?xi64>) -> memref<?xi64> {
// CHECK-SAME: %[[ARG0:.*]]: memref<*xi64>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL:.*]] = memref.cast %[[ARG0]] : memref<*xi64> to memref<i64>
// CHECK: %[[RESULT:.*]] = memref.load %[[VAL]][] : memref<i64>
// CHECK: call @refbackend_consume_int64_func_return(%[[RESULT]]) : (i64) -> ()
// CHECK: call @refbackend_consume_func_return_i64(%[[RESULT]]) : (i64) -> ()
// CHECK: return
func @elemental_type(%arg0: memref<i64>) -> i64 {
  %0 = memref.load %arg0[] : memref<i64>
  return %0 : i64
}

// -----

// CHECK-LABEL: func @multiple_return_values(
// CHECK-SAME: %[[ARG0:.*]]: memref<*xf32>, %[[ARG1:.*]]: memref<*xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<*xf32>) attributes {llvm.emit_c_interface} {
// CHECK: %[[VAL0:.*]] = memref.cast %[[ARG0]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[VAL1:.*]] = memref.cast %[[ARG1]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[VAL2:.*]] = memref.cast %[[ARG2]] : memref<*xf32> to memref<?xf32>
// CHECK: %[[RET0:.*]] = memref.cast %[[VAL0]] : memref<?xf32> to memref<*xf32>
// CHECK: %[[RET1:.*]] = memref.cast %[[VAL1]] : memref<?xf32> to memref<*xf32>
// CHECK: %[[RET2:.*]] = memref.cast %[[VAL2]] : memref<?xf32> to memref<*xf32>
// CHECK: call @refbackend_consume_func_return_mrf32_mrf32_mrf32(%[[RET0]], %[[RET1]], %[[RET2]])
// CHECK-SAME: : (memref<*xf32>, memref<*xf32>, memref<*xf32>) -> ()
// CHECK: return

func @multiple_return_values(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
  return %arg0, %arg1, %arg2 : memref<?xf32>, memref<?xf32>, memref<?xf32>
}