Add tracing support to `torch_mlir.compile`.

This also includes a fix for adjusting the types of TupleConstruct
inputs, which I found when using this new functionality on a model.

Some scenarios in tracing create situations where the output of
TupleConstruct has a more refined type than its inputs. For example, a
traced graph can contain `%10 : (Float(4), Float(3)) = prim::TupleConstruct(%t1, %t0)`
where `%t1` and `%t0` are only known to be `Tensor`.

This introduces a helper `adjustStaticInformationForValues` which
subsumes the `derefineValues` helper and the tensor static information
adjustment we were doing.
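
For example, the new flag can be used like this (a minimal sketch; the
module, names, and shapes are illustrative, not part of this change):

import torch
import torch_mlir

class Scale(torch.nn.Module):
    def forward(self, x):
        return 2.0 * x

# use_tracing=True routes the conversion through torch.jit.trace
# instead of the default torch.jit.script.
compiled = torch_mlir.compile(Scale(), [torch.ones(1, 384)], use_tracing=True)
print(compiled)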
pull/775/head
Sean Silva 2022-05-03 08:43:13 +00:00
parent 641b0b3be2
commit ab5ad7af09
12 changed files with 247 additions and 100 deletions


@@ -0,0 +1,39 @@
//===-- torch-mlir-c/TorchOps.h - C API for torch ops -------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_C_TORCHOPS_H
#define TORCHMLIR_C_TORCHOPS_H
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
#endif
//===----------------------------------------------------------------------===//
// Utilities.
//===----------------------------------------------------------------------===//
/// Adjusts the static information in the type of `value` to `desiredType`.
///
/// Returns null if such an adjustment is not possible.
///
/// If `userAllowsRefinement` is true, then the original value will be returned
/// if it is a subtype of `desiredType`.
MLIR_CAPI_EXPORTED MlirValue torchMlirAdjustStaticInformation(
MlirBlock block, MlirOperation insertBefore, MlirValue value,
MlirType desiredType, bool userAllowsRefinement);
#ifdef __cplusplus
}
#endif
#endif // TORCHMLIR_C_TORCHOPS_H


@@ -1,6 +1,7 @@
add_mlir_public_c_api_library(TorchMLIRCAPI
Dialects.cpp
Registration.cpp
TorchOps.cpp
TorchTypes.cpp
ADDITIONAL_HEADER_DIRS


@@ -0,0 +1,75 @@
//===- TorchOps.cpp - C Interface for torch ops ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir-c/TorchOps.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
using namespace mlir;
using namespace mlir::torch;
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
MlirValue torchMlirAdjustStaticInformation(MlirBlock block_,
MlirOperation insertBefore_,
MlirValue value_,
MlirType desiredType_,
bool userAllowsRefinement) {
Block *block = unwrap(block_);
Operation *insertBefore = unwrap(insertBefore_);
OpBuilder builder(unwrap(mlirTypeGetContext(desiredType_)));
builder.setInsertionPoint(block, insertBefore ? insertBefore->getIterator()
: block->end());
Value value = unwrap(value_);
Type type = value.getType();
Type desiredType = unwrap(desiredType_);
// If the value is already of the desired type, we're done.
if (type == desiredType)
return wrap(value);
// If the type is a tensor, then adjust the static information.
if ((type.isa<Torch::ValueTensorType>() &&
desiredType.isa<Torch::ValueTensorType>()) ||
(type.isa<Torch::NonValueTensorType>() &&
desiredType.isa<Torch::NonValueTensorType>())) {
Value adjusted = builder.create<Torch::TensorStaticInfoCastOp>(
value.getLoc(), desiredType, value);
return wrap(adjusted);
}
// If the type is a subtype of desiredType, then we need to derefine it to
// desiredType, unless the user allows refinement.
if (Torch::isValidSubtype(type, desiredType)) {
if (!userAllowsRefinement) {
Value adjusted =
builder.create<Torch::DerefineOp>(value.getLoc(), desiredType, value);
return wrap(adjusted);
} else {
return wrap(value);
}
}
// If the desiredType is a subtype of type, then we assume that the desiredType
// is dynamically valid, so we do an unchecked cast.
if (Torch::isValidSubtype(desiredType, type)) {
Value adjusted = builder.create<Torch::PrimUncheckedCastOp>(
value.getLoc(), desiredType, value);
return wrap(adjusted);
}
// No known adjustment.
return {};
}
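
In Python terms, the derefine case corresponds to yielding a value whose
type is a strict subtype of the declared type. The importer below routes
through this entry point; here is a small illustration mirroring the
tuple_optional test further down (a sketch; the function name is
illustrative):

import typing
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder

mb = ModuleBuilder()

@mb.import_function
@torch.jit.script
def to_optional(t: torch.Tensor) -> typing.Optional[torch.Tensor]:
    # The value is a Tensor but the declared result type is Optional[Tensor],
    # so importing this should insert a torch.derefine on the return path.
    return t

mb.module.operation.print()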


@@ -34,7 +34,8 @@ class OutputType(Enum):
def compile(model: torch.nn.Module,
example_args: List[torch.Tensor],
output_type: OutputType = OutputType.TORCH):
output_type: OutputType = OutputType.TORCH,
use_tracing=False):
"""Convert a PyTorch model to MLIR.
Args:
@@ -44,6 +45,8 @@ def compile(model: torch.nn.Module,
A single tensor is treated as a list of a single tensor.
output_type: The kind of output to produce. See `OutputType` for more
details.
use_tracing: If True, use `torch.jit.trace` to convert the model to
JIT IR rather than `torch.jit.script`.
Returns:
An MLIR module that contains the converted model in the specified
@@ -55,8 +58,10 @@ def compile(model: torch.nn.Module,
# TODO: Support dynamic dimension sizes. See `torch.onnx.export`'s
# `dynamic_axes` for API inspiration, or do something more ergonomic
# like a tensor wrapper possibly.
# TODO: Support tracing the model instead of scripting it.
scripted = torch.jit.script(model)
if use_tracing:
scripted = torch.jit.trace(model, tuple(example_args))
else:
scripted = torch.jit.script(model)
if isinstance(example_args, torch.Tensor):
example_args = [example_args]
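
This tracing path is what exposed the TupleConstruct issue fixed above:
torch.jit.trace records the observed tensor shapes, so a traced graph can
give a tuple a more refined type than the values feeding its
prim::TupleConstruct. A sketch of a model that hits this (names and
shapes are illustrative):

import torch
import torch_mlir

class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        # Tracing observes concrete shapes for both results, so the traced
        # graph types the tuple as, e.g., (Float(2, 3), Float(2, 3)).
        return x + 1, x * 2

compiled = torch_mlir.compile(TwoOutputs(), [torch.ones(2, 3)], use_tracing=True)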


@@ -63,9 +63,10 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
}
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) {
createMlirOperationAtEnd(
appendToBlock, "func.return", loc,
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
createMlirOperationAtEnd(appendToBlock, "func.return", loc,
adjustStaticInformationForValues(
appendToBlock, loc, yieldedValues, resultTypes,
/*userAllowsRefinement=*/false));
};
MlirBlock block = importBlock(
context, torch::jit::toGraphFunction(*function).graph()->block(),


@@ -19,6 +19,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "torch-mlir-c/TorchOps.h"
#include "torch-mlir-c/TorchTypes.h"
namespace py = pybind11;
@@ -114,16 +115,40 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
}
// Builtin interpreter ops with no operator/schema.
InputsTransformFn transformer =
kind != c10::prim::DictConstruct ? nullptr : rearrangeDictConstructInputs;
switch (kind) {
case c10::prim::ListUnpack:
case c10::prim::ListConstruct:
case c10::prim::TupleConstruct:
case c10::prim::DictConstruct:
case c10::prim::CreateObject: {
createAndMapTrivialNode(
node, "torch.prim." + std::string(kind.toUnqualString()), transformer);
node, "torch.prim." + std::string(kind.toUnqualString()), nullptr);
return;
}
case c10::prim::TupleConstruct: {
// TODO: We will probably need to adjust the static information for
// ListConstruct and DictConstruct too.
auto containedTypes = c10::fmap(
node->output()->type()->cast<c10::TupleType>()->containedTypes(),
[&](const c10::TypePtr &t) {
MlirType type = getMlirTypeFromTorchType(loc, t);
if (mlirTypeIsNull(type)) {
throw mlir_diagnostic_emitted();
}
return type;
});
createAndMapTrivialNode(node,
"torch.prim." + std::string(kind.toUnqualString()),
[&](std::vector<MlirValue> &inputs) {
assert(containedTypes.size() == inputs.size());
return adjustStaticInformationForValues(
appendToBlock, loc, inputs, containedTypes,
/*userAllowsRefinement=*/true);
});
return;
}
case c10::prim::DictConstruct: {
createAndMapTrivialNode(node,
"torch.prim." + std::string(kind.toUnqualString()),
rearrangeDictConstructInputs);
return;
}
case c10::prim::GetAttr:
@@ -213,8 +238,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.Loop", loc, resultTypes,
lookupMappedValues(node->inputs().slice(0, 2)),
derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
loc, appendToBlock),
adjustStaticInformationForValues(
appendToBlock, loc, lookupMappedValues(node->inputs().slice(2)),
resultTypes, /*userAllowsRefinement=*/false),
mlirRegionCreate());
mapResults(node, operation);
std::vector<MlirType> terminatorOperandTypes = {
@@ -223,10 +249,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
resultTypes.begin(), resultTypes.end());
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) {
createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
derefineValues(yieldedValues,
terminatorOperandTypes, loc,
appendToBlock));
createMlirOperationAtEnd(
appendToBlock, "torch.prim.Loop.condition", loc,
adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
terminatorOperandTypes,
/*userAllowsRefinement=*/false));
};
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0),
@@ -245,7 +272,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
MlirBlock appendToBlock) {
createMlirOperationAtEnd(
appendToBlock, "torch.prim.If.yield", loc,
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
resultTypes,
/*userAllowsRefinement=*/false));
};
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0),
@@ -269,8 +298,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.CallMethod", loc,
getMlirTypesFromValues(loc, node->outputs()),
derefineValues(lookupMappedValues(node->inputs()), expectedTypes, loc,
appendToBlock),
adjustStaticInformationForValues(
appendToBlock, loc, lookupMappedValues(node->inputs()),
expectedTypes, /*userAllowsRefinement=*/false),
toMlirNamedAttribute("name",
importAttribute(loc, node, c10::attr::name)));
mapResults(node, operation);
@@ -288,8 +318,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
appendToBlock, "func.call_indirect", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValue(node->input(0)),
derefineValues(lookupMappedValues(node->inputs().slice(1)),
expectedTypes, loc, appendToBlock));
adjustStaticInformationForValues(
appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)),
expectedTypes, /*userAllowsRefinement=*/false));
mapResults(node, operation);
return;
}
@@ -315,36 +346,6 @@ MlirBlock NodeImporter::importBlock(
return block;
}
static MlirValue adjustBlockArgType(MlirContext context,
MlirBlock appendToBlock, MlirValue value,
MlirType expectedType, MlirLocation loc) {
MlirType type = mlirValueGetType(value);
if (mlirTypeEqual(type, expectedType)) {
return value;
}
// For tensors, we might need to erase or add static type information.
if (torchMlirTypeIsATorchNonValueTensor(type) ||
torchMlirTypeIsATorchValueTensor(type)) {
MlirOperation op =
createMlirOperationAtEnd(appendToBlock, "torch.tensor_static_info_cast",
loc, expectedType, value);
return mlirOperationGetResult(op, 0);
}
{
std::stringstream msg;
MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
std::stringstream *stream = static_cast<std::stringstream *>(userData);
stream->write(str.data, str.length);
};
msg << "unhandled: could not adjust formal param type from ";
mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
msg << " to expected type ";
mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
mlirEmitError(loc, msg.str().c_str());
throw mlir_diagnostic_emitted();
}
}
MlirBlock NodeImporter::createBlockFor(
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
Node *paramNode = jitBlock->param_node();
@@ -362,8 +363,9 @@ MlirBlock NodeImporter::createBlockFor(
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
Value *jitValue = paramNode->outputs()[i];
MlirValue value = mlirBlockGetArgument(block, i);
MlirValue adjusted =
adjustBlockArgType(context, block, value, paramNodeTypes[i], loc);
MlirValue adjusted = adjustStaticInformationForValues(
block, loc, {value}, {paramNodeTypes[i]},
/*userAllowsRefinement=*/false)[0];
mapValue(jitValue, adjusted);
}
return block;


@@ -19,6 +19,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "torch-mlir-c/TorchOps.h"
#include "torch-mlir-c/TorchTypes.h"
using namespace torch_mlir;
@@ -380,24 +381,34 @@ torch_mlir::getMlirTypesFromValues(MlirLocation loc,
return ret;
}
std::vector<MlirValue>
torch_mlir::derefineValues(c10::ArrayRef<MlirValue> values,
c10::ArrayRef<MlirType> expectedTypes,
MlirLocation loc, MlirBlock appendToBlock) {
std::vector<MlirValue> torch_mlir::adjustStaticInformationForValues(
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
c10::ArrayRef<MlirType> desiredTypes, bool userAllowsRefinement) {
std::vector<MlirValue> ret;
assert(values.size() == expectedTypes.size());
assert(values.size() == desiredTypes.size());
for (int i = 0, e = values.size(); i != e; i++) {
MlirValue value = values[i];
MlirType expectedType = expectedTypes[i];
MlirType expectedType = desiredTypes[i];
MlirType type = mlirValueGetType(value);
if (mlirTypeEqual(expectedType, type)) {
// No need to derefine.
ret.push_back(value);
} else {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.derefine", loc, expectedType, value);
ret.push_back(mlirOperationGetResult(operation, 0));
MlirValue adjusted = torchMlirAdjustStaticInformation(
appendToBlock, mlirBlockGetTerminator(appendToBlock), value,
expectedType, userAllowsRefinement);
if (!mlirValueIsNull(adjusted)) {
ret.push_back(adjusted);
continue;
}
std::stringstream msg;
MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
std::stringstream *stream = static_cast<std::stringstream *>(userData);
stream->write(str.data, str.length);
};
msg << "unhandled: could not adjust static info for type from ";
mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
msg << " to type ";
mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
mlirEmitError(loc, msg.str().c_str());
throw mlir_diagnostic_emitted();
}
return ret;
}


@@ -60,10 +60,9 @@ std::vector<MlirType>
getMlirTypesFromValues(MlirLocation loc,
c10::ArrayRef<torch::jit::Value *> values);
std::vector<MlirValue> derefineValues(c10::ArrayRef<MlirValue> values,
c10::ArrayRef<MlirType> expectedTypes,
MlirLocation loc,
MlirBlock appendToBlock);
std::vector<MlirValue> adjustStaticInformationForValues(
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
c10::ArrayRef<MlirType> desiredTypes, bool userAllowsRefinement);
/// Create the appropriate MLIR operation for the Torch operator with schema
/// "schema".


@@ -2,31 +2,23 @@
# This file is licensed under a pytorch-style license
# See LICENSE.pytorch for license information.
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
from torch._C import CompilationUnit
from utils import create_script_function
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
# Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str):
cu = CompilationUnit()
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
mb = ModuleBuilder()
# CHECK-LABEL: func @__torch__.refined_block_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
# CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32>
# CHECK: %[[RESULT:.*]] = torch.derefine %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor
# CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor
# CHECK: return %[[RESULT]] : !torch.tensor
script_function = create_script_function('__torch__.refined_block_arg', '''
mb.import_function(create_script_function("__torch__.refined_block_arg", """
graph(%0 : Float(1, 384)):
return (%0)
''')
mb = ModuleBuilder()
mb.import_function(script_function)
"""))
mb.module.operation.print()
print()


@@ -5,21 +5,16 @@
import typing
import torch
from torch._C import CompilationUnit
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
from utils import create_script_function
import typing
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
mb = ModuleBuilder()
# Import TorchScript IR string as ScriptFunction.
def import_ts_ir(func_name, ts_ir_str):
cu = CompilationUnit()
mb.import_function(cu.create_function(func_name, torch._C.parse_ir(ts_ir_str)))
# CHECK-LABEL: func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
@@ -159,9 +154,11 @@ def prim_max(x: int):
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[A]], %[[B]], %[[C]] :
# CHECK-SAME: (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
# CHECK: return %[[RET]] : !torch.list<int>
import_ts_ir('__torch__.prim_Constant_list', '''graph():
mb.import_function(create_script_function("__torch__.prim_Constant_list", """
graph():
%list : int[] = prim::Constant[value=[1, 2, 3]]()
return (%list)''')
return (%list)
"""))
mb.module.operation.print()
print()


@@ -4,9 +4,10 @@
import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
import collections
from typing import Tuple, Optional, NamedTuple
from utils import create_script_function
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
mb = ModuleBuilder()
@@ -20,8 +21,6 @@ NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
# CHECK: return %[[RET]] : !torch.tuple<tensor, tensor>
@mb.import_function
@torch.jit.script
def tuple(t0, t1):
@@ -38,8 +37,6 @@ def tuple(t0, t1):
# CHECK-SAME: !torch.tuple<tensor, tensor> to
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>>
# CHECK: return %[[RET]] : !torch.tuple<optional<tensor>, optional<tensor>>
@mb.import_function
@torch.jit.script
def tuple_optional(
@@ -55,8 +52,6 @@ def tuple_optional(
# CHECK-SAME: !torch.tensor, !torch.tensor ->
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>>
# CHECK: return %[[RET]] : !torch.tuple<optional<tensor>, optional<tensor>>
# CHECK: }
#
@mb.import_function
@torch.jit.script
def namedtuple_optional(
@@ -64,5 +59,20 @@ def namedtuple_optional(
return NT(t0, t1)
# CHECK-LABEL: func @__torch__.tuple_construct_arg_needs_refinement(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<tensor, tensor> {
# CHECK: %[[T0_REFINED:.*]] = torch.tensor_static_info_cast %[[T1]] : !torch.tensor to !torch.tensor<[4],f32>
# CHECK: %[[T1_REFINED:.*]] = torch.tensor_static_info_cast %[[T0]] : !torch.tensor to !torch.tensor<[3],f32>
# CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0_REFINED]], %[[T1_REFINED]] : !torch.tensor<[4],f32>, !torch.tensor<[3],f32> -> !torch.tuple<tensor<[4],f32>, tensor<[3],f32>>
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[TUPLE]] : !torch.tuple<tensor<[4],f32>, tensor<[3],f32>> to !torch.tuple<tensor, tensor>
# CHECK: return %[[DEREFINED]] : !torch.tuple<tensor, tensor>
mb.import_function(create_script_function("__torch__.tuple_construct_arg_needs_refinement", """
graph(%t0 : Tensor,
%t1 : Tensor):
%10 : (Float(4), Float(3)) = prim::TupleConstruct(%t1, %t0)
return (%10)
"""))
mb.module.operation.print()
print()


@@ -0,0 +1,15 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE.pytorch for license information.
# Helpers for the other tests.
import torch
from torch._C import CompilationUnit
# RUN: %PYTHON %s
# Import TorchScript IR string as ScriptFunction.
def create_script_function(func_name, ts_ir_str):
cu = CompilationUnit()
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
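
The tests above use this helper as follows (a small usage sketch; the
function name and IR string are illustrative):

import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
from utils import create_script_function

mb = ModuleBuilder()
mb.import_function(create_script_function("__torch__.identity", """
graph(%x : Tensor):
  return (%x)
"""))
mb.module.operation.print()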