mirror of https://github.com/llvm/torch-mlir
Ensure that imported function input types and block arg types are consistent.
I wasn't able to pin down exactly which frontend situation creates this, but `torch.jit.trace` will sometimes create functions where the `jit::Block`'s param node has refined tensor types. So we need to adjust the function's formal param types to those refined types.
parent 7f2577a848, commit 73cc2ac152
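For context, a minimal sketch of the mismatch this change handles (the helper `f` and the `(1, 384)` example shape are illustrative, chosen to mirror the new test added below; they are not code from this commit): tracing records the example input's concrete shape and dtype on the traced graph's block inputs, while the function's schema-level formal parameter type stays a plain `Tensor`.

import torch

def f(x):
    return torch.tanh(x)

# Tracing with a concrete example input records refined tensor types
# (shape/dtype) on the traced graph's block inputs.
traced = torch.jit.trace(f, torch.ones(1, 384))

# The printed graph shows the block argument typed as something like
# Float(1, 384, ...), even though the schema's formal parameter type is a
# plain Tensor; the importer has to reconcile the two when building the
# function's entry block.
print(traced.graph)
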
@@ -57,6 +57,10 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
        i++) {
     resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i));
   }
+  std::vector<MlirType> inputTypes;
+  for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) {
+    inputTypes.push_back(mlirFunctionTypeGetInput(functionType, i));
+  }
   auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                               MlirBlock appendToBlock) {
     createMlirOperationAtEnd(
@@ -65,7 +69,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
   };
   MlirBlock block = importBlock(
       context, torch::jit::toGraphFunction(*function).graph()->block(),
-      createTerminator);
+      createTerminator, inputTypes);
   mlirRegionAppendOwnedBlock(bodyRegion, block);
   return func;
 }

@@ -34,10 +34,14 @@ public:
   NodeImporter(MlirContext context) : context(context) {}

   void importNode(Node *node, MlirBlock appendToBlock);
-  MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);
+  MlirBlock importBlock(
+      Block *jitBlock, CreateTerminatorFn createTerminator,
+      c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);

 private:
-  MlirBlock createBlockFor(Block *jitBlock);
+  MlirBlock
+  createBlockFor(Block *jitBlock,
+                 c10::optional<c10::ArrayRef<MlirType>> blockArgTypes);
   void mapValue(Value *jitValue, MlirValue value);
   void mapResults(Node *node, MlirOperation operation);
   MlirValue lookupMappedValue(Value *jitValue);
@@ -256,11 +260,12 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     auto classType = node->input(0)->type()->cast<c10::ClassType>();
     auto methodName = node->s(c10::attr::name);
     torch::jit::Function *function = classType->findMethod(methodName);
-    torch::jit::Block *calleeEntryBlock =
-        torch::jit::toGraphFunction(*function).graph()->block();
-    auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
-      return getMlirTypeFromTorchType(loc, v->type());
-    });
+    MlirType calleeType =
+        getFunctionTypeFromSchema(context, function->getSchema());
+    std::vector<MlirType> expectedTypes;
+    for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
+      expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
+    }
     MlirOperation operation = createMlirOperationAtEnd(
         appendToBlock, "torch.prim.CallMethod", loc,
         getMlirTypesFromValues(loc, node->outputs()),
@@ -298,9 +303,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
   }
 }

-MlirBlock NodeImporter::importBlock(Block *jitBlock,
-                                    CreateTerminatorFn createTerminator) {
-  MlirBlock block = createBlockFor(jitBlock);
+MlirBlock NodeImporter::importBlock(
+    Block *jitBlock, CreateTerminatorFn createTerminator,
+    c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
+  MlirBlock block = createBlockFor(jitBlock, blockArgTypes);
   for (Node *node : jitBlock->nodes()) {
     importNode(node, block);
   }
@@ -309,17 +315,56 @@ MlirBlock NodeImporter::importBlock(Block *jitBlock,
   return block;
 }

-MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {
+static MlirValue adjustBlockArgType(MlirContext context,
+                                    MlirBlock appendToBlock, MlirValue value,
+                                    MlirType expectedType, MlirLocation loc) {
+  MlirType type = mlirValueGetType(value);
+  if (mlirTypeEqual(type, expectedType)) {
+    return value;
+  }
+  // For tensors, we might need to erase or add static type information.
+  if (torchMlirTypeIsATorchNonValueTensor(type) ||
+      torchMlirTypeIsATorchValueTensor(type)) {
+    MlirOperation op =
+        createMlirOperationAtEnd(appendToBlock, "torch.tensor_static_info_cast",
+                                 loc, expectedType, value);
+    return mlirOperationGetResult(op, 0);
+  }
+  {
+    std::stringstream msg;
+    MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
+      std::stringstream *stream = static_cast<std::stringstream *>(userData);
+      stream->write(str.data, str.length);
+    };
+    msg << "unhandled: could not adjust formal param type from ";
+    mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
+    msg << " to expected type ";
+    mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
+    mlirEmitError(loc, msg.str().c_str());
+    throw mlir_diagnostic_emitted();
+  }
+}
+
+MlirBlock NodeImporter::createBlockFor(
+    Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
   Node *paramNode = jitBlock->param_node();
   MlirLocation loc = getMlirLocationFromNode(context, paramNode);
-  std::vector<MlirType> blockArgTypes =
+  std::vector<MlirType> paramNodeTypes =
       getMlirTypesFromValues(loc, paramNode->outputs());
-  std::vector<MlirLocation> blockArgLocs(blockArgTypes.size(), loc);
-  MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data(), blockArgLocs.data());
+  if (!blockArgTypes)
+    blockArgTypes = paramNodeTypes;
+  else
+    assert(blockArgTypes->size() == paramNodeTypes.size());
+  std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc);
+  MlirBlock block =
+      mlirBlockCreate(blockArgTypes.value().size(),
+                      blockArgTypes.value().data(), blockArgLocs.data());
   for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
     Value *jitValue = paramNode->outputs()[i];
     MlirValue value = mlirBlockGetArgument(block, i);
-    mapValue(jitValue, value);
+    MlirValue adjusted =
+        adjustBlockArgType(context, block, value, paramNodeTypes[i], loc);
+    mapValue(jitValue, adjusted);
   }
   return block;
 }
@@ -352,8 +397,10 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
   return ret;
 }

-MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
-                                  CreateTerminatorFn createTerminator) {
+MlirBlock
+torch_mlir::importBlock(MlirContext context, Block *jitBlock,
+                        CreateTerminatorFn createTerminator,
+                        c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
   NodeImporter importer(context);
-  return importer.importBlock(jitBlock, createTerminator);
+  return importer.importBlock(jitBlock, createTerminator, blockArgTypes);
 }

@@ -24,8 +24,22 @@ namespace torch_mlir {
 using CreateTerminatorFn =
     std::function<void(c10::ArrayRef<MlirValue>, MlirBlock)>;

-MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock,
-                      CreateTerminatorFn createTerminator);
+/// Import `jitBlock` into a corresponding `MlirBlock`.
+///
+/// Because `jit::Block` does not have a concept of terminator in the MLIR sense
+/// (it is kind of "built-in" to the block, and not a free op chosen by the
+/// enclosing op), the `createTerminator` function will be used to create the
+/// terminator for the created block. Type adjustments like handling
+/// derefinement can be handled there as well.
+///
+/// `blockArgTypes`, if present, gives a set of types that the block arguments
+/// are required to be for correctness. The code will internally attempt to
+/// adjust the types to the block argument types.
+/// TODO: Formalize what type conversions are allowed here.
+MlirBlock importBlock(
+    MlirContext context, torch::jit::Block *jitBlock,
+    CreateTerminatorFn createTerminator,
+    c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);

 } // namespace torch_mlir


@@ -42,7 +42,7 @@ MlirType getMlirTypeFromTorchType(MlirLocation loc,
 /// Creates a FunctionType suitable for expressing the signature of `schema`.
 ///
 /// This can differ from the type inferred from the block of a
-/// torch::jit::Function due to derefinement.
+/// torch::jit::Function due to derefinement and refinement of tensor types.
 MlirType getFunctionTypeFromSchema(MlirContext context,
                                    const c10::FunctionSchema &schema);


@@ -0,0 +1,32 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See LICENSE.pytorch for license information.
+
+import torch
+from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
+
+from torch._C import CompilationUnit
+
+
+# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
+
+# Import TorchScript IR string as ScriptFunction.
+def create_script_function(func_name, ts_ir_str):
+    cu = CompilationUnit()
+    return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
+
+# CHECK-LABEL: func @__torch__.refined_block_arg(
+# CHECK-SAME:    %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
+# CHECK:         %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32>
+# CHECK:         %[[RESULT:.*]] = torch.derefine %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor
+# CHECK:         return %[[RESULT]] : !torch.tensor
+script_function = create_script_function('__torch__.refined_block_arg', '''
+graph(%0 : Float(1, 384)):
+  return (%0)
+''')
+
+mb = ModuleBuilder()
+mb.import_function(script_function)
+
+mb.module.operation.print()
+print()