Ensure that imported function input type and block arg types are consistent.

I wasn't able to find exactly what frontend situation created it, but
`torch.jit.trace` will sometimes create functions where the
`jit::Block`'s param node has refined tensor types. So we need to adjust
the function's formal param types to those refined types.
pull/802/head
Sean Silva 2022-04-27 13:57:40 +00:00
parent 7f2577a848
commit 73cc2ac152
5 changed files with 119 additions and 22 deletions

View File

@ -57,6 +57,10 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
i++) {
resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i));
}
std::vector<MlirType> inputTypes;
for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) {
inputTypes.push_back(mlirFunctionTypeGetInput(functionType, i));
}
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) {
createMlirOperationAtEnd(
@ -65,7 +69,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
};
MlirBlock block = importBlock(
context, torch::jit::toGraphFunction(*function).graph()->block(),
createTerminator);
createTerminator, inputTypes);
mlirRegionAppendOwnedBlock(bodyRegion, block);
return func;
}

View File

@ -34,10 +34,14 @@ public:
NodeImporter(MlirContext context) : context(context) {}
void importNode(Node *node, MlirBlock appendToBlock);
MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);
MlirBlock importBlock(
Block *jitBlock, CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
private:
MlirBlock createBlockFor(Block *jitBlock);
MlirBlock
createBlockFor(Block *jitBlock,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes);
void mapValue(Value *jitValue, MlirValue value);
void mapResults(Node *node, MlirOperation operation);
MlirValue lookupMappedValue(Value *jitValue);
@ -256,11 +260,12 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
auto classType = node->input(0)->type()->cast<c10::ClassType>();
auto methodName = node->s(c10::attr::name);
torch::jit::Function *function = classType->findMethod(methodName);
torch::jit::Block *calleeEntryBlock =
torch::jit::toGraphFunction(*function).graph()->block();
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
return getMlirTypeFromTorchType(loc, v->type());
});
MlirType calleeType =
getFunctionTypeFromSchema(context, function->getSchema());
std::vector<MlirType> expectedTypes;
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
}
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.CallMethod", loc,
getMlirTypesFromValues(loc, node->outputs()),
@ -298,9 +303,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
}
}
MlirBlock NodeImporter::importBlock(Block *jitBlock,
CreateTerminatorFn createTerminator) {
MlirBlock block = createBlockFor(jitBlock);
MlirBlock NodeImporter::importBlock(
Block *jitBlock, CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
MlirBlock block = createBlockFor(jitBlock, blockArgTypes);
for (Node *node : jitBlock->nodes()) {
importNode(node, block);
}
@ -309,17 +315,56 @@ MlirBlock NodeImporter::importBlock(Block *jitBlock,
return block;
}
MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {
static MlirValue adjustBlockArgType(MlirContext context,
MlirBlock appendToBlock, MlirValue value,
MlirType expectedType, MlirLocation loc) {
MlirType type = mlirValueGetType(value);
if (mlirTypeEqual(type, expectedType)) {
return value;
}
// For tensors, we might need to erase or add static type information.
if (torchMlirTypeIsATorchNonValueTensor(type) ||
torchMlirTypeIsATorchValueTensor(type)) {
MlirOperation op =
createMlirOperationAtEnd(appendToBlock, "torch.tensor_static_info_cast",
loc, expectedType, value);
return mlirOperationGetResult(op, 0);
}
{
std::stringstream msg;
MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
std::stringstream *stream = static_cast<std::stringstream *>(userData);
stream->write(str.data, str.length);
};
msg << "unhandled: could not adjust formal param type from ";
mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
msg << " to expected type ";
mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
mlirEmitError(loc, msg.str().c_str());
throw mlir_diagnostic_emitted();
}
}
MlirBlock NodeImporter::createBlockFor(
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
Node *paramNode = jitBlock->param_node();
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
std::vector<MlirType> blockArgTypes =
std::vector<MlirType> paramNodeTypes =
getMlirTypesFromValues(loc, paramNode->outputs());
std::vector<MlirLocation> blockArgLocs(blockArgTypes.size(), loc);
MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data(), blockArgLocs.data());
if (!blockArgTypes)
blockArgTypes = paramNodeTypes;
else
assert(blockArgTypes->size() == paramNodeTypes.size());
std::vector<MlirLocation> blockArgLocs(paramNodeTypes.size(), loc);
MlirBlock block =
mlirBlockCreate(blockArgTypes.value().size(),
blockArgTypes.value().data(), blockArgLocs.data());
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
Value *jitValue = paramNode->outputs()[i];
MlirValue value = mlirBlockGetArgument(block, i);
mapValue(jitValue, value);
MlirValue adjusted =
adjustBlockArgType(context, block, value, paramNodeTypes[i], loc);
mapValue(jitValue, adjusted);
}
return block;
}
@ -352,8 +397,10 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
return ret;
}
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
CreateTerminatorFn createTerminator) {
MlirBlock
torch_mlir::importBlock(MlirContext context, Block *jitBlock,
CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
NodeImporter importer(context);
return importer.importBlock(jitBlock, createTerminator);
return importer.importBlock(jitBlock, createTerminator, blockArgTypes);
}

View File

@ -24,8 +24,22 @@ namespace torch_mlir {
using CreateTerminatorFn =
std::function<void(c10::ArrayRef<MlirValue>, MlirBlock)>;
MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock,
CreateTerminatorFn createTerminator);
/// Import `jitBlock` into a corresponding `MlirBlock`.
///
/// Because `jit::Block` does not have a concept of terminator in the MLIR sense
/// (it is kind of "built-in" to the block, and not a free op chosen by the
/// enclosing op), the `createTerminator` function will be used to create the
/// terminator for the created block. Type adjustments like handling
/// derefinement can be handled there as well.
///
/// `blockArgTypes`, if present, gives a set of types that the block arguments
/// are required to be for correctness. The code will internally attempt to
/// adjust the types to the block argument types.
/// TODO: Formalize what type conversions are allowed here.
MlirBlock importBlock(
MlirContext context, torch::jit::Block *jitBlock,
CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
} // namespace torch_mlir

View File

@ -42,7 +42,7 @@ MlirType getMlirTypeFromTorchType(MlirLocation loc,
/// Creates a FunctionType suitable for expressing the signature of `schema`.
///
/// This can differ from the type inferred from the block of a
/// torch::jit::Function due to derefinement.
/// torch::jit::Function due to derefinement and refinement of tensor types.
MlirType getFunctionTypeFromSchema(MlirContext context,
const c10::FunctionSchema &schema);

View File

@ -0,0 +1,32 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE.pytorch for license information.
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
from torch._C import CompilationUnit
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
# Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str):
cu = CompilationUnit()
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
# CHECK-LABEL: func @__torch__.refined_block_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32>
# CHECK: %[[RESULT:.*]] = torch.derefine %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor
# CHECK: return %[[RESULT]] : !torch.tensor
script_function = create_script_function('__torch__.refined_block_arg', '''
graph(%0 : Float(1, 384)):
return (%0)
''')
mb = ModuleBuilder()
mb.import_function(script_function)
mb.module.operation.print()
print()