From 73cc2ac15238de1aa174d4984b4dc3ed04365cc3 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Wed, 27 Apr 2022 13:57:40 +0000 Subject: [PATCH] Ensure that imported function input type and block arg types are consistent. I wasn't able to find exactly what frontend situation created it, but `torch.jit.trace` will sometimes create functions where the `jit::Block`'s param node has refined tensor types. So we need to adjust the function's formal param types to those refined types. --- .../jit_ir/csrc/function_importer.cpp | 6 +- .../importer/jit_ir/csrc/node_importer.cpp | 83 +++++++++++++++---- .../importer/jit_ir/csrc/node_importer.h | 18 +++- .../jit_ir/csrc/torch_to_mlir_utils.h | 2 +- .../function-block-arg-adjustment.py | 32 +++++++ 5 files changed, 119 insertions(+), 22 deletions(-) create mode 100644 test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp index 6abebc2c8..ec65e4d74 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp @@ -57,6 +57,10 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp( i++) { resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i)); } + std::vector inputTypes; + for (int i = 0, e = mlirFunctionTypeGetNumInputs(functionType); i != e; i++) { + inputTypes.push_back(mlirFunctionTypeGetInput(functionType, i)); + } auto createTerminator = [&](c10::ArrayRef yieldedValues, MlirBlock appendToBlock) { createMlirOperationAtEnd( @@ -65,7 +69,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp( }; MlirBlock block = importBlock( context, torch::jit::toGraphFunction(*function).graph()->block(), - createTerminator); + createTerminator, inputTypes); mlirRegionAppendOwnedBlock(bodyRegion, block); return func; } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp index 6cc6833fd..39a852f61 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp @@ -34,10 +34,14 @@ public: NodeImporter(MlirContext context) : context(context) {} void importNode(Node *node, MlirBlock appendToBlock); - MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator); + MlirBlock importBlock( + Block *jitBlock, CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes = c10::nullopt); private: - MlirBlock createBlockFor(Block *jitBlock); + MlirBlock + createBlockFor(Block *jitBlock, + c10::optional> blockArgTypes); void mapValue(Value *jitValue, MlirValue value); void mapResults(Node *node, MlirOperation operation); MlirValue lookupMappedValue(Value *jitValue); @@ -256,11 +260,12 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { auto classType = node->input(0)->type()->cast(); auto methodName = node->s(c10::attr::name); torch::jit::Function *function = classType->findMethod(methodName); - torch::jit::Block *calleeEntryBlock = - torch::jit::toGraphFunction(*function).graph()->block(); - auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) { - return getMlirTypeFromTorchType(loc, v->type()); - }); + MlirType calleeType = + getFunctionTypeFromSchema(context, function->getSchema()); + std::vector expectedTypes; + for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) { + expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i)); + } MlirOperation operation = createMlirOperationAtEnd( appendToBlock, "torch.prim.CallMethod", loc, getMlirTypesFromValues(loc, node->outputs()), @@ -298,9 +303,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { } } -MlirBlock NodeImporter::importBlock(Block *jitBlock, - CreateTerminatorFn createTerminator) { - MlirBlock block = createBlockFor(jitBlock); +MlirBlock NodeImporter::importBlock( + Block *jitBlock, CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes) { + MlirBlock block = createBlockFor(jitBlock, blockArgTypes); for (Node *node : jitBlock->nodes()) { importNode(node, block); } @@ -309,17 +315,56 @@ MlirBlock NodeImporter::importBlock(Block *jitBlock, return block; } -MlirBlock NodeImporter::createBlockFor(Block *jitBlock) { +static MlirValue adjustBlockArgType(MlirContext context, + MlirBlock appendToBlock, MlirValue value, + MlirType expectedType, MlirLocation loc) { + MlirType type = mlirValueGetType(value); + if (mlirTypeEqual(type, expectedType)) { + return value; + } + // For tensors, we might need to erase or add static type information. + if (torchMlirTypeIsATorchNonValueTensor(type) || + torchMlirTypeIsATorchValueTensor(type)) { + MlirOperation op = + createMlirOperationAtEnd(appendToBlock, "torch.tensor_static_info_cast", + loc, expectedType, value); + return mlirOperationGetResult(op, 0); + } + { + std::stringstream msg; + MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) { + std::stringstream *stream = static_cast(userData); + stream->write(str.data, str.length); + }; + msg << "unhandled: could not adjust formal param type from "; + mlirTypePrint(type, printToStream, static_cast(&msg)); + msg << " to expected type "; + mlirTypePrint(expectedType, printToStream, static_cast(&msg)); + mlirEmitError(loc, msg.str().c_str()); + throw mlir_diagnostic_emitted(); + } +} + +MlirBlock NodeImporter::createBlockFor( + Block *jitBlock, c10::optional> blockArgTypes) { Node *paramNode = jitBlock->param_node(); MlirLocation loc = getMlirLocationFromNode(context, paramNode); - std::vector blockArgTypes = + std::vector paramNodeTypes = getMlirTypesFromValues(loc, paramNode->outputs()); - std::vector blockArgLocs(blockArgTypes.size(), loc); - MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data(), blockArgLocs.data()); + if (!blockArgTypes) + blockArgTypes = paramNodeTypes; + else + assert(blockArgTypes->size() == paramNodeTypes.size()); + std::vector blockArgLocs(paramNodeTypes.size(), loc); + MlirBlock block = + mlirBlockCreate(blockArgTypes.value().size(), + blockArgTypes.value().data(), blockArgLocs.data()); for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) { Value *jitValue = paramNode->outputs()[i]; MlirValue value = mlirBlockGetArgument(block, i); - mapValue(jitValue, value); + MlirValue adjusted = + adjustBlockArgType(context, block, value, paramNodeTypes[i], loc); + mapValue(jitValue, adjusted); } return block; } @@ -352,8 +397,10 @@ NodeImporter::lookupMappedValues(c10::ArrayRef values) { return ret; } -MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock, - CreateTerminatorFn createTerminator) { +MlirBlock +torch_mlir::importBlock(MlirContext context, Block *jitBlock, + CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes) { NodeImporter importer(context); - return importer.importBlock(jitBlock, createTerminator); + return importer.importBlock(jitBlock, createTerminator, blockArgTypes); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h index 866750f03..c4893fdb9 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h @@ -24,8 +24,22 @@ namespace torch_mlir { using CreateTerminatorFn = std::function, MlirBlock)>; -MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock, - CreateTerminatorFn createTerminator); +/// Import `jitBlock` into a corresponding `MlirBlock`. +/// +/// Because `jit::Block` does not have a concept of terminator in the MLIR sense +/// (it is kind of "built-in" to the block, and not a free op chosen by the +/// enclosing op), the `createTerminator` function will be used to create the +/// terminator for the created block. Type adjustments like handling +/// derefinement can be handled there as well. +/// +/// `blockArgTypes`, if present, gives a set of types that the block arguments +/// are required to be for correctness. The code will internally attempt to +/// adjust the types to the block argument types. +/// TODO: Formalize what type conversions are allowed here. +MlirBlock importBlock( + MlirContext context, torch::jit::Block *jitBlock, + CreateTerminatorFn createTerminator, + c10::optional> blockArgTypes = c10::nullopt); } // namespace torch_mlir diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h index ebe54d51c..7b1207422 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h @@ -42,7 +42,7 @@ MlirType getMlirTypeFromTorchType(MlirLocation loc, /// Creates a FunctionType suitable for expressing the signature of `schema`. /// /// This can differ from the type inferred from the block of a -/// torch::jit::Function due to derefinement. +/// torch::jit::Function due to derefinement and refinement of tensor types. MlirType getFunctionTypeFromSchema(MlirContext context, const c10::FunctionSchema &schema); diff --git a/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py b/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py new file mode 100644 index 000000000..33b862720 --- /dev/null +++ b/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py @@ -0,0 +1,32 @@ +# -*- Python -*- +# This file is licensed under a pytorch-style license +# See LICENSE.pytorch for license information. + +import torch +from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder + +from torch._C import CompilationUnit + + +# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s + +# Import TorchScript IR string as ScriptFunction. +def create_script_function(func_name, ts_ir_str): + cu = CompilationUnit() + return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str)) + +# CHECK-LABEL: func @__torch__.refined_block_arg( +# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor { +# CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32> +# CHECK: %[[RESULT:.*]] = torch.derefine %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor +# CHECK: return %[[RESULT]] : !torch.tensor +script_function = create_script_function('__torch__.refined_block_arg', ''' +graph(%0 : Float(1, 384)): + return (%0) +''') + +mb = ModuleBuilder() +mb.import_function(script_function) + +mb.module.operation.print() +print()