mirror of https://github.com/llvm/torch-mlir
Ensure that imported function input type and block arg types are consistent.
I wasn't able to find exactly what frontend situation created it, but `torch.jit.trace` will sometimes create functions where the `jit::Block`'s param node has refined tensor types. So we need to adjust the function's formal param types to those refined types.pull/802/head
parent
7f2577a848
commit
73cc2ac152
|
@ -57,6 +57,10 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
|||
i++) {
|
||||
resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i));
|
||||
}
|
||||
std::vector<MlirType> inputTypes;
|
||||
for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) {
|
||||
inputTypes.push_back(mlirFunctionTypeGetInput(functionType, i));
|
||||
}
|
||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||
MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(
|
||||
|
@ -65,7 +69,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
|||
};
|
||||
MlirBlock block = importBlock(
|
||||
context, torch::jit::toGraphFunction(*function).graph()->block(),
|
||||
createTerminator);
|
||||
createTerminator, inputTypes);
|
||||
mlirRegionAppendOwnedBlock(bodyRegion, block);
|
||||
return func;
|
||||
}
|
||||
|
|
|
@ -34,10 +34,14 @@ public:
|
|||
NodeImporter(MlirContext context) : context(context) {}
|
||||
|
||||
void importNode(Node *node, MlirBlock appendToBlock);
|
||||
MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);
|
||||
MlirBlock importBlock(
|
||||
Block *jitBlock, CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
|
||||
|
||||
private:
|
||||
MlirBlock createBlockFor(Block *jitBlock);
|
||||
MlirBlock
|
||||
createBlockFor(Block *jitBlock,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes);
|
||||
void mapValue(Value *jitValue, MlirValue value);
|
||||
void mapResults(Node *node, MlirOperation operation);
|
||||
MlirValue lookupMappedValue(Value *jitValue);
|
||||
|
@ -256,11 +260,12 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
auto classType = node->input(0)->type()->cast<c10::ClassType>();
|
||||
auto methodName = node->s(c10::attr::name);
|
||||
torch::jit::Function *function = classType->findMethod(methodName);
|
||||
torch::jit::Block *calleeEntryBlock =
|
||||
torch::jit::toGraphFunction(*function).graph()->block();
|
||||
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
||||
return getMlirTypeFromTorchType(loc, v->type());
|
||||
});
|
||||
MlirType calleeType =
|
||||
getFunctionTypeFromSchema(context, function->getSchema());
|
||||
std::vector<MlirType> expectedTypes;
|
||||
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
|
||||
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
|
||||
}
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.CallMethod", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
|
@ -298,9 +303,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
}
|
||||
}
|
||||
|
||||
MlirBlock NodeImporter::importBlock(Block *jitBlock,
|
||||
CreateTerminatorFn createTerminator) {
|
||||
MlirBlock block = createBlockFor(jitBlock);
|
||||
MlirBlock NodeImporter::importBlock(
|
||||
Block *jitBlock, CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
|
||||
MlirBlock block = createBlockFor(jitBlock, blockArgTypes);
|
||||
for (Node *node : jitBlock->nodes()) {
|
||||
importNode(node, block);
|
||||
}
|
||||
|
@ -309,17 +315,56 @@ MlirBlock NodeImporter::importBlock(Block *jitBlock,
|
|||
return block;
|
||||
}
|
||||
|
||||
MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {
|
||||
static MlirValue adjustBlockArgType(MlirContext context,
|
||||
MlirBlock appendToBlock, MlirValue value,
|
||||
MlirType expectedType, MlirLocation loc) {
|
||||
MlirType type = mlirValueGetType(value);
|
||||
if (mlirTypeEqual(type, expectedType)) {
|
||||
return value;
|
||||
}
|
||||
// For tensors, we might need to erase or add static type information.
|
||||
if (torchMlirTypeIsATorchNonValueTensor(type) ||
|
||||
torchMlirTypeIsATorchValueTensor(type)) {
|
||||
MlirOperation op =
|
||||
createMlirOperationAtEnd(appendToBlock, "torch.tensor_static_info_cast",
|
||||
loc, expectedType, value);
|
||||
return mlirOperationGetResult(op, 0);
|
||||
}
|
||||
{
|
||||
std::stringstream msg;
|
||||
MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
|
||||
std::stringstream *stream = static_cast<std::stringstream *>(userData);
|
||||
stream->write(str.data, str.length);
|
||||
};
|
||||
msg << "unhandled: could not adjust formal param type from ";
|
||||
mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
|
||||
msg << " to expected type ";
|
||||
mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
|
||||
mlirEmitError(loc, msg.str().c_str());
|
||||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
}
|
||||
|
||||
MlirBlock NodeImporter::createBlockFor(
|
||||
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
|
||||
Node *paramNode = jitBlock->param_node();
|
||||
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
|
||||
std::vector<MlirType> blockArgTypes =
|
||||
std::vector<MlirType> paramNodeTypes =
|
||||
getMlirTypesFromValues(loc, paramNode->outputs());
|
||||
std::vector<MlirLocation> blockArgLocs(blockArgTypes.size(), loc);
|
||||
MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data(), blockArgLocs.data());
|
||||
if (!blockArgTypes)
|
||||
blockArgTypes = paramNodeTypes;
|
||||
else
|
||||
assert(blockArgTypes->size() == paramNodeTypes.size());
|
||||
std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc);
|
||||
MlirBlock block =
|
||||
mlirBlockCreate(blockArgTypes.value().size(),
|
||||
blockArgTypes.value().data(), blockArgLocs.data());
|
||||
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
|
||||
Value *jitValue = paramNode->outputs()[i];
|
||||
MlirValue value = mlirBlockGetArgument(block, i);
|
||||
mapValue(jitValue, value);
|
||||
MlirValue adjusted =
|
||||
adjustBlockArgType(context, block, value, paramNodeTypes[i], loc);
|
||||
mapValue(jitValue, adjusted);
|
||||
}
|
||||
return block;
|
||||
}
|
||||
|
@ -352,8 +397,10 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
||||
CreateTerminatorFn createTerminator) {
|
||||
MlirBlock
|
||||
torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
||||
CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
|
||||
NodeImporter importer(context);
|
||||
return importer.importBlock(jitBlock, createTerminator);
|
||||
return importer.importBlock(jitBlock, createTerminator, blockArgTypes);
|
||||
}
|
||||
|
|
|
@ -24,8 +24,22 @@ namespace torch_mlir {
|
|||
using CreateTerminatorFn =
|
||||
std::function<void(c10::ArrayRef<MlirValue>, MlirBlock)>;
|
||||
|
||||
MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock,
|
||||
CreateTerminatorFn createTerminator);
|
||||
/// Import `jitBlock` into a corresponding `MlirBlock`.
|
||||
///
|
||||
/// Because `jit::Block` does not have a concept of terminator in the MLIR sense
|
||||
/// (it is kind of "built-in" to the block, and not a free op chosen by the
|
||||
/// enclosing op), the `createTerminator` function will be used to create the
|
||||
/// terminator for the created block. Type adjustments like handling
|
||||
/// derefinement can be handled there as well.
|
||||
///
|
||||
/// `blockArgTypes`, if present, gives a set of types that the block arguments
|
||||
/// are required to be for correctness. The code will internally attempt to
|
||||
/// adjust the types to the block argument types.
|
||||
/// TODO: Formalize what type conversions are allowed here.
|
||||
MlirBlock importBlock(
|
||||
MlirContext context, torch::jit::Block *jitBlock,
|
||||
CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ MlirType getMlirTypeFromTorchType(MlirLocation loc,
|
|||
/// Creates a FunctionType suitable for expressing the signature of `schema`.
|
||||
///
|
||||
/// This can differ from the type inferred from the block of a
|
||||
/// torch::jit::Function due to derefinement.
|
||||
/// torch::jit::Function due to derefinement and refinement of tensor types.
|
||||
MlirType getFunctionTypeFromSchema(MlirContext context,
|
||||
const c10::FunctionSchema &schema);
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See LICENSE.pytorch for license information.
|
||||
|
||||
import torch
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
|
||||
from torch._C import CompilationUnit
|
||||
|
||||
|
||||
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
|
||||
|
||||
# Import TorchScript IR string as ScriptFunction.
|
||||
def create_script_function(func_name, ts_ir_str):
|
||||
cu = CompilationUnit()
|
||||
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
|
||||
|
||||
# CHECK-LABEL: func @__torch__.refined_block_arg(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
# CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32>
|
||||
# CHECK: %[[RESULT:.*]] = torch.derefine %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor
|
||||
# CHECK: return %[[RESULT]] : !torch.tensor
|
||||
script_function = create_script_function('__torch__.refined_block_arg', '''
|
||||
graph(%0 : Float(1, 384)):
|
||||
return (%0)
|
||||
''')
|
||||
|
||||
mb = ModuleBuilder()
|
||||
mb.import_function(script_function)
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
Loading…
Reference in New Issue