From ab5ad7af09ac1fade2eb27dd3ed5bd1a744e0c67 Mon Sep 17 00:00:00 2001
From: Sean Silva
Date: Tue, 3 May 2022 08:43:13 +0000
Subject: [PATCH] Add tracing support to `torch_mlir.compile`.

This also has a fix for the adjustment of types of TupleConstruct
inputs, which I found when using this new functionality on a model.
Some scenarios in tracing create situations where the output of
TupleConstruct has a more refined type than the inputs.

This introduces a helper `adjustStaticInformationForValues`, which
subsumes the `derefineValues` helper and the tensor static information
adjustment we were doing.
---
 include/torch-mlir-c/TorchOps.h               | 39 ++++++++
 lib/CAPI/CMakeLists.txt                       |  1 +
 lib/CAPI/TorchOps.cpp                         | 75 ++++++++++++++
 python/torch_mlir/__init__.py                 | 11 ++-
 .../jit_ir/csrc/function_importer.cpp         |  7 +-
 .../importer/jit_ir/csrc/node_importer.cpp    | 98 ++++++++++---------
 .../jit_ir/csrc/torch_to_mlir_utils.cpp       | 37 ++++---
 .../jit_ir/csrc/torch_to_mlir_utils.h         |  7 +-
 .../function-block-arg-adjustment.py          | 18 +---
 .../importer/jit_ir/node_import/prim.py       | 15 ++-
 .../importer/jit_ir/node_import/tuple.py      | 24 +++--
 .../importer/jit_ir/node_import/utils.py      | 15 +++
 12 files changed, 247 insertions(+), 100 deletions(-)
 create mode 100644 include/torch-mlir-c/TorchOps.h
 create mode 100644 lib/CAPI/TorchOps.cpp
 create mode 100644 test/python/importer/jit_ir/node_import/utils.py

diff --git a/include/torch-mlir-c/TorchOps.h b/include/torch-mlir-c/TorchOps.h
new file mode 100644
index 000000000..26b030f63
--- /dev/null
+++ b/include/torch-mlir-c/TorchOps.h
@@ -0,0 +1,39 @@
+//===-- torch-mlir-c/TorchOps.h - C API for torch ops -------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_C_TORCHOPS_H
+#define TORCHMLIR_C_TORCHOPS_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// Utilities.
+//===----------------------------------------------------------------------===//
+
+/// Adjusts the static information in the type of `value` to `desiredType`.
+///
+/// Returns null if such an adjustment is not possible.
+///
+/// If `userAllowsRefinement` is true, then the original value will be returned
+/// if it is a subtype of `desiredType`.
+MLIR_CAPI_EXPORTED MlirValue torchMlirAdjustStaticInformation(
+    MlirBlock block, MlirOperation insertBefore, MlirValue value,
+    MlirType desiredType, bool userAllowsRefinement);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // TORCHMLIR_C_TORCHOPS_H
diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt
index 275f288cf..87977a86f 100644
--- a/lib/CAPI/CMakeLists.txt
+++ b/lib/CAPI/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_public_c_api_library(TorchMLIRCAPI
   Dialects.cpp
   Registration.cpp
+  TorchOps.cpp
   TorchTypes.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/lib/CAPI/TorchOps.cpp b/lib/CAPI/TorchOps.cpp
new file mode 100644
index 000000000..b67e4a894
--- /dev/null
+++ b/lib/CAPI/TorchOps.cpp
@@ -0,0 +1,75 @@
+//===- TorchOps.cpp - C Interface for torch ops ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir-c/TorchOps.h"
+
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+MlirValue torchMlirAdjustStaticInformation(MlirBlock block_,
+                                           MlirOperation insertBefore_,
+                                           MlirValue value_,
+                                           MlirType desiredType_,
+                                           bool userAllowsRefinement) {
+  Block *block = unwrap(block_);
+  Operation *insertBefore = unwrap(insertBefore_);
+  OpBuilder builder(unwrap(mlirTypeGetContext(desiredType_)));
+  builder.setInsertionPoint(block, insertBefore ? insertBefore->getIterator()
+                                                : block->end());
+
+  Value value = unwrap(value_);
+  Type type = value.getType();
+  Type desiredType = unwrap(desiredType_);
+
+  // If the value is already of the desired type, we're done.
+  if (type == desiredType)
+    return wrap(value);
+
+  // If the type is a tensor, then adjust the static information.
+  if ((type.isa<Torch::ValueTensorType>() &&
+       desiredType.isa<Torch::ValueTensorType>()) ||
+      (type.isa<Torch::NonValueTensorType>() &&
+       desiredType.isa<Torch::NonValueTensorType>())) {
+    Value adjusted = builder.create<Torch::TensorStaticInfoCastOp>(
+        value.getLoc(), desiredType, value);
+    return wrap(adjusted);
+  }
+
+  // If the type is a subtype of desiredType, then we need to derefine it to
+  // desiredType, unless the user allows refinement.
+  if (Torch::isValidSubtype(type, desiredType)) {
+    if (!userAllowsRefinement) {
+      Value adjusted =
+          builder.create<Torch::DerefineOp>(value.getLoc(), desiredType, value);
+      return wrap(adjusted);
+    } else {
+      return wrap(value);
+    }
+  }
+
+  // If the desiredType is a subtype of type, then we assume that the
+  // desiredType is dynamically valid, so we do an unchecked cast.
+  if (Torch::isValidSubtype(desiredType, type)) {
+    Value adjusted = builder.create<Torch::PrimUncheckedCastOp>(
+        value.getLoc(), desiredType, value);
+    return wrap(adjusted);
+  }
+
+  // No known adjustment.
+  return {};
+}
diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py
index 687469574..16997274e 100644
--- a/python/torch_mlir/__init__.py
+++ b/python/torch_mlir/__init__.py
@@ -34,7 +34,8 @@ class OutputType(Enum):
 
 def compile(model: torch.nn.Module,
             example_args: List[torch.Tensor],
-            output_type: OutputType = OutputType.TORCH):
+            output_type: OutputType = OutputType.TORCH,
+            use_tracing=False):
     """Convert a PyTorch model to MLIR.
 
     Args:
@@ -44,6 +45,8 @@ def compile(model: torch.nn.Module,
             A single tensor is treated as a list of a single tensor.
         output_type: The kind of output to produce. See `OutputType` for more
             details.
+        use_tracing: If True, use `torch.jit.trace` to convert the model to
+            JIT IR rather than `torch.jit.script`.
 
     Returns:
         An MLIR module that contains the converted model in the specified
@@ -55,8 +58,11 @@ def compile(model: torch.nn.Module,
     # TODO: Support dynamic dimension sizes. See `torch.onnx.export`'s
    # `dynamic_axes` for API inspiration, or do something more ergonomic
     # like a tensor wrapper possibly.
-    # TODO: Support tracing the model instead of scripting it.
-    scripted = torch.jit.script(model)
 
     if isinstance(example_args, torch.Tensor):
         example_args = [example_args]
+
+    if use_tracing:
+        scripted = torch.jit.trace(model, tuple(example_args))
+    else:
+        scripted = torch.jit.script(model)
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp
index ec65e4d74..dcda400c7 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp
@@ -63,9 +63,10 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
   }
   auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                               MlirBlock appendToBlock) {
-    createMlirOperationAtEnd(
-        appendToBlock, "func.return", loc,
-        derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
+    createMlirOperationAtEnd(appendToBlock, "func.return", loc,
+                             adjustStaticInformationForValues(
+                                 appendToBlock, loc, yieldedValues, resultTypes,
+                                 /*userAllowsRefinement=*/false));
   };
   MlirBlock block = importBlock(
       context, torch::jit::toGraphFunction(*function).graph()->block(),
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp
index 39a852f61..a879d46ad 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp
@@ -19,6 +19,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
+#include "torch-mlir-c/TorchOps.h"
 #include "torch-mlir-c/TorchTypes.h"
 
 namespace py = pybind11;
@@ -114,16 +115,40 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
   }
 
   // Builtin interpreter ops with no operator/schema.
-  InputsTransformFn transformer =
-      kind != c10::prim::DictConstruct ? nullptr : rearrangeDictConstructInputs;
   switch (kind) {
   case c10::prim::ListUnpack:
   case c10::prim::ListConstruct:
-  case c10::prim::TupleConstruct:
-  case c10::prim::DictConstruct:
   case c10::prim::CreateObject: {
     createAndMapTrivialNode(
-        node, "torch.prim." + std::string(kind.toUnqualString()), transformer);
+        node, "torch.prim." + std::string(kind.toUnqualString()), nullptr);
+    return;
+  }
+  case c10::prim::TupleConstruct: {
+    // TODO: We will probably need to adjust the static information for
+    // ListConstruct and DictConstruct too.
+    auto containedTypes = c10::fmap(
+        node->output()->type()->cast<c10::TupleType>()->containedTypes(),
+        [&](const c10::TypePtr &t) {
+          MlirType type = getMlirTypeFromTorchType(loc, t);
+          if (mlirTypeIsNull(type)) {
+            throw mlir_diagnostic_emitted();
+          }
+          return type;
+        });
+    createAndMapTrivialNode(node,
+                            "torch.prim." + std::string(kind.toUnqualString()),
+                            [&](std::vector<MlirValue> &inputs) {
+                              assert(containedTypes.size() == inputs.size());
+                              return adjustStaticInformationForValues(
+                                  appendToBlock, loc, inputs, containedTypes,
+                                  /*userAllowsRefinement=*/true);
+                            });
+    return;
+  }
+  case c10::prim::DictConstruct: {
+    createAndMapTrivialNode(node,
+                            "torch.prim." + std::string(kind.toUnqualString()),
+                            rearrangeDictConstructInputs);
     return;
   }
   case c10::prim::GetAttr:
@@ -213,8 +238,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     MlirOperation operation = createMlirOperationAtEnd(
         appendToBlock, "torch.prim.Loop", loc, resultTypes,
         lookupMappedValues(node->inputs().slice(0, 2)),
-        derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
-                       loc, appendToBlock),
+        adjustStaticInformationForValues(
+            appendToBlock, loc, lookupMappedValues(node->inputs().slice(2)),
+            resultTypes, /*userAllowsRefinement=*/false),
         mlirRegionCreate());
     mapResults(node, operation);
     std::vector<MlirType> terminatorOperandTypes = {
@@ -223,10 +249,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
                                   resultTypes.begin(), resultTypes.end());
     auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
                                 MlirBlock appendToBlock) {
-      createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
-                               derefineValues(yieldedValues,
-                                              terminatorOperandTypes, loc,
-                                              appendToBlock));
+      createMlirOperationAtEnd(
+          appendToBlock, "torch.prim.Loop.condition", loc,
+          adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
+                                           terminatorOperandTypes,
+                                           /*userAllowsRefinement=*/false));
     };
     mlirRegionAppendOwnedBlock(
         mlirOperationGetRegion(operation, 0),
@@ -245,7 +272,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
                                  MlirBlock appendToBlock) {
       createMlirOperationAtEnd(
           appendToBlock, "torch.prim.If.yield", loc,
-          derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
+          adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
+                                           resultTypes,
+                                           /*userAllowsRefinement=*/false));
     };
     mlirRegionAppendOwnedBlock(
         mlirOperationGetRegion(operation, 0),
@@ -269,8 +298,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     MlirOperation operation = createMlirOperationAtEnd(
         appendToBlock, "torch.prim.CallMethod", loc,
         getMlirTypesFromValues(loc, node->outputs()),
-        derefineValues(lookupMappedValues(node->inputs()), expectedTypes, loc,
-                       appendToBlock),
+        adjustStaticInformationForValues(
+            appendToBlock, loc, lookupMappedValues(node->inputs()),
+            expectedTypes, /*userAllowsRefinement=*/false),
         toMlirNamedAttribute("name",
                              importAttribute(loc, node, c10::attr::name)));
     mapResults(node, operation);
@@ -288,8 +318,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
         appendToBlock, "func.call_indirect", loc,
         getMlirTypesFromValues(loc, node->outputs()),
         lookupMappedValue(node->input(0)),
-        derefineValues(lookupMappedValues(node->inputs().slice(1)),
-                       expectedTypes, loc, appendToBlock));
+        adjustStaticInformationForValues(
+            appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)),
+            expectedTypes, /*userAllowsRefinement=*/false));
     mapResults(node, operation);
     return;
   }
@@ -315,36 +346,6 @@ MlirBlock NodeImporter::importBlock(
   return block;
 }
 
-static MlirValue adjustBlockArgType(MlirContext context,
-                                    MlirBlock appendToBlock, MlirValue value,
-                                    MlirType expectedType, MlirLocation loc) {
-  MlirType type = mlirValueGetType(value);
-  if (mlirTypeEqual(type, expectedType)) {
-    return value;
-  }
-  // For tensors, we might need to erase or add static type information.
-  if (torchMlirTypeIsATorchNonValueTensor(type) ||
-      torchMlirTypeIsATorchValueTensor(type)) {
-    MlirOperation op =
-        createMlirOperationAtEnd(appendToBlock, "torch.tensor_static_info_cast",
-                                 loc, expectedType, value);
-    return mlirOperationGetResult(op, 0);
-  }
-  {
-    std::stringstream msg;
-    MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
-      std::stringstream *stream = static_cast<std::stringstream *>(userData);
-      stream->write(str.data, str.length);
-    };
-    msg << "unhandled: could not adjust formal param type from ";
-    mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
-    msg << " to expected type ";
-    mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
-    mlirEmitError(loc, msg.str().c_str());
-    throw mlir_diagnostic_emitted();
-  }
-}
-
 MlirBlock NodeImporter::createBlockFor(
     Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
   Node *paramNode = jitBlock->param_node();
@@ -362,8 +363,9 @@ MlirBlock NodeImporter::createBlockFor(
   for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
     Value *jitValue = paramNode->outputs()[i];
     MlirValue value = mlirBlockGetArgument(block, i);
-    MlirValue adjusted =
-        adjustBlockArgType(context, block, value, paramNodeTypes[i], loc);
+    MlirValue adjusted = adjustStaticInformationForValues(
+        block, loc, {value}, {paramNodeTypes[i]},
+        /*userAllowsRefinement=*/false)[0];
     mapValue(jitValue, adjusted);
   }
   return block;
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
index 9b325bd05..8a69a73a5 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
@@ -19,6 +19,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
+#include "torch-mlir-c/TorchOps.h"
 #include "torch-mlir-c/TorchTypes.h"
 
 using namespace torch_mlir;
@@ -380,24 +381,34 @@ torch_mlir::getMlirTypesFromValues(MlirLocation loc,
   return ret;
 }
 
-std::vector<MlirValue>
-torch_mlir::derefineValues(c10::ArrayRef<MlirValue> values,
-                           c10::ArrayRef<MlirType> expectedTypes,
-                           MlirLocation loc, MlirBlock appendToBlock) {
+std::vector<MlirValue> torch_mlir::adjustStaticInformationForValues(
+    MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
+    c10::ArrayRef<MlirType> desiredTypes, bool userAllowsRefinement) {
   std::vector<MlirValue> ret;
-  assert(values.size() == expectedTypes.size());
+  assert(values.size() == desiredTypes.size());
   for (int i = 0, e = values.size(); i != e; i++) {
     MlirValue value = values[i];
-    MlirType expectedType = expectedTypes[i];
+    MlirType expectedType = desiredTypes[i];
     MlirType type = mlirValueGetType(value);
-    if (mlirTypeEqual(expectedType, type)) {
-      // No need to derefine.
-      ret.push_back(value);
-    } else {
-      MlirOperation operation = createMlirOperationAtEnd(
-          appendToBlock, "torch.derefine", loc, expectedType, value);
-      ret.push_back(mlirOperationGetResult(operation, 0));
+    MlirValue adjusted = torchMlirAdjustStaticInformation(
+        appendToBlock, mlirBlockGetTerminator(appendToBlock), value,
+        expectedType, userAllowsRefinement);
+    if (!mlirValueIsNull(adjusted)) {
+      ret.push_back(adjusted);
+      continue;
     }
+
+    std::stringstream msg;
+    MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
+      std::stringstream *stream = static_cast<std::stringstream *>(userData);
+      stream->write(str.data, str.length);
+    };
+    msg << "unhandled: could not adjust static info for type from ";
+    mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
+    msg << " to type ";
+    mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
+    mlirEmitError(loc, msg.str().c_str());
+    throw mlir_diagnostic_emitted();
   }
   return ret;
 }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h
index 7b1207422..328be1291 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h
@@ -60,10 +60,9 @@
 std::vector<MlirType>
 getMlirTypesFromValues(MlirLocation loc,
                        c10::ArrayRef<torch::jit::Value *> values);
 
-std::vector<MlirValue> derefineValues(c10::ArrayRef<MlirValue> values,
-                                      c10::ArrayRef<MlirType> expectedTypes,
-                                      MlirLocation loc,
-                                      MlirBlock appendToBlock);
+std::vector<MlirValue> adjustStaticInformationForValues(
+    MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
+    c10::ArrayRef<MlirType> desiredTypes, bool userAllowsRefinement);
 
 /// Create the appropriate MLIR operation for the Torch operator with schema
 /// "schema".
diff --git a/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py b/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py
index 33b862720..f70ad3db5 100644
--- a/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py
+++ b/test/python/importer/jit_ir/node_import/function-block-arg-adjustment.py
@@ -2,31 +2,23 @@
 # This file is licensed under a pytorch-style license
 # See LICENSE.pytorch for license information.
 
-import torch
 from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
-from torch._C import CompilationUnit
-
+from utils import create_script_function
 
 # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
 
-# Import TorchScript IR string as ScriptFunction.
-def create_script_function(func_name, ts_ir_str):
-    cu = CompilationUnit()
-    return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
+mb = ModuleBuilder()
 
 # CHECK-LABEL: func @__torch__.refined_block_arg(
 # CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
 # CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32>
-# CHECK: %[[RESULT:.*]] = torch.derefine %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor
+# CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor
 # CHECK: return %[[RESULT]] : !torch.tensor
-script_function = create_script_function('__torch__.refined_block_arg', '''
+mb.import_function(create_script_function("__torch__.refined_block_arg", """
 graph(%0 : Float(1, 384)):
   return (%0)
-''')
-
-mb = ModuleBuilder()
-mb.import_function(script_function)
+"""))
 
 mb.module.operation.print()
 print()
diff --git a/test/python/importer/jit_ir/node_import/prim.py b/test/python/importer/jit_ir/node_import/prim.py
index ecde6307d..885d35f00 100644
--- a/test/python/importer/jit_ir/node_import/prim.py
+++ b/test/python/importer/jit_ir/node_import/prim.py
@@ -5,21 +5,16 @@
 import typing
 
 import torch
-from torch._C import CompilationUnit
 from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
 
+from utils import create_script_function
+
 import typing
 
 # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
 
 mb = ModuleBuilder()
 
-
-# Import TorchScript IR string as ScriptFunction.
-def import_ts_ir(func_name, ts_ir_str):
-    cu = CompilationUnit()
-    mb.import_function(cu.create_function(func_name, torch._C.parse_ir(ts_ir_str)))
-
 # CHECK-LABEL: func @__torch__.prim_NumToTensor(
 # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
 # CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
@@ -159,9 +154,11 @@ def prim_max(x: int):
 # CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[A]], %[[B]], %[[C]] :
 # CHECK-SAME: (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
 # CHECK: return %[[RET]] : !torch.list<int>
-import_ts_ir('__torch__.prim_Constant_list', '''graph():
+mb.import_function(create_script_function("__torch__.prim_Constant_list", """
+graph():
   %list : int[] = prim::Constant[value=[1, 2, 3]]()
-  return (%list)''')
+  return (%list)
+"""))
 
 mb.module.operation.print()
 print()
diff --git a/test/python/importer/jit_ir/node_import/tuple.py b/test/python/importer/jit_ir/node_import/tuple.py
index 162f04d50..2bd66d610 100644
--- a/test/python/importer/jit_ir/node_import/tuple.py
+++ b/test/python/importer/jit_ir/node_import/tuple.py
@@ -4,9 +4,10 @@
 
 import torch
 from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
-import collections
 from typing import Tuple, Optional, NamedTuple
 
+from utils import create_script_function
+
 # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
 
 mb = ModuleBuilder()
@@ -20,8 +21,6 @@ NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
 # CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
 # CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
 # CHECK: return %[[RET]] : !torch.tuple<tensor, tensor>
-
-
 @mb.import_function
 @torch.jit.script
 def tuple(t0, t1):
@@ -38,8 +37,6 @@ def tuple(t0, t1):
 # CHECK-SAME: !torch.tuple<tensor, tensor> to
 # CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>>
 # CHECK: return %[[RET]] : !torch.tuple<optional<tensor>, optional<tensor>>
-
-
 @mb.import_function
 @torch.jit.script
 def tuple_optional(
@@ -55,8 +52,6 @@ def tuple_optional(
 # CHECK-SAME: !torch.tensor, !torch.tensor ->
 # CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>>
 # CHECK: return %[[RET]] : !torch.tuple<optional<tensor>, optional<tensor>>
-# CHECK: }
-#
 @mb.import_function
 @torch.jit.script
 def namedtuple_optional(
@@ -63,6 +58,21 @@ def namedtuple_optional(
     t0: torch.Tensor, t1: torch.Tensor):
   return NT(t0, t1)
 
 
+# CHECK-LABEL: func @__torch__.tuple_construct_arg_needs_refinement(
+# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
+# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<tensor, tensor> {
+# CHECK: %[[T0_REFINED:.*]] = torch.tensor_static_info_cast %[[T1]] : !torch.tensor to !torch.tensor<[4],f32>
+# CHECK: %[[T1_REFINED:.*]] = torch.tensor_static_info_cast %[[T0]] : !torch.tensor to !torch.tensor<[3],f32>
+# CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0_REFINED]], %[[T1_REFINED]] : !torch.tensor<[4],f32>, !torch.tensor<[3],f32> -> !torch.tuple<tensor<[4],f32>, tensor<[3],f32>>
+# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[TUPLE]] : !torch.tuple<tensor<[4],f32>, tensor<[3],f32>> to !torch.tuple<tensor, tensor>
+# CHECK: return %[[DEREFINED]] : !torch.tuple<tensor, tensor>
+mb.import_function(create_script_function("__torch__.tuple_construct_arg_needs_refinement", """
+graph(%t0 : Tensor,
+      %t1 : Tensor):
+  %10 : (Float(4), Float(3)) = prim::TupleConstruct(%t1, %t0)
+  return (%10)
+"""))
+
 mb.module.operation.print()
 print()
diff --git a/test/python/importer/jit_ir/node_import/utils.py b/test/python/importer/jit_ir/node_import/utils.py
new file mode 100644
index 000000000..6e8d1ac45
--- /dev/null
+++ b/test/python/importer/jit_ir/node_import/utils.py
@@ -0,0 +1,15 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See LICENSE.pytorch for license information.
+
+# Helpers for the other tests.
+
+import torch
+from torch._C import CompilationUnit
+
+# RUN: %PYTHON %s
+
+# Import TorchScript IR string as ScriptFunction.
+def create_script_function(func_name, ts_ir_str):
+    cu = CompilationUnit()
+    return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
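
Usage sketch (not part of the patch itself): with the change applied, a model
can be compiled via tracing by passing the new flag. The module, shapes, and
output handling below are illustrative assumptions rather than code from the
commit; any traceable `torch.nn.Module` that returns a tuple will exercise the
new TupleConstruct static-information adjustment, since traced graphs carry
refined tensor types.

    import torch
    import torch_mlir

    class TwoOutputs(torch.nn.Module):
        # Hypothetical example module: the traced graph ends in a
        # prim::TupleConstruct whose output type is more refined than its
        # inputs, which is the case this patch handles.
        def forward(self, x):
            return x * 2.0, x + 1.0

    # `use_tracing=True` selects torch.jit.trace over torch.jit.script.
    module = torch_mlir.compile(TwoOutputs(),
                                [torch.ones(1, 384)],
                                use_tracing=True)
    print(module)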