mirror of https://github.com/llvm/torch-mlir
Add tracing suport to `torch_mlir.compile`.
This also has a fix for the adjustment of types of TupleConstruct inputs, which I found when using this new functionality on a model. Some scenarios in tracing create situations where the output of TupleConstruct has a more refined type than the inputs. This introduces a helper `adjustStaticInformationForValues` which subsumes the `derefineValues` helper and the tensor static information adjustment we were doing.pull/775/head
parent
641b0b3be2
commit
ab5ad7af09
|
@ -0,0 +1,39 @@
|
|||
//===-- torch-mlir-c/TorchOps.h - C API for torch ops -------------*- C -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
// Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIR_C_TORCHOPS_H
|
||||
#define TORCHMLIR_C_TORCHOPS_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utilities.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Adjusts the static information in the type of `value` to `desiredType`.
|
||||
///
|
||||
/// Returns null if such an adjustment is not possible.
|
||||
///
|
||||
/// If `userAllowsRefinement` is true, then the original value will be returned
|
||||
/// if it is a subtype of `desiredType`.
|
||||
MLIR_CAPI_EXPORTED MlirValue torchMlirAdjustStaticInformation(
|
||||
MlirBlock block, MlirOperation insertBefore, MlirValue value,
|
||||
MlirType desiredType, bool userAllowsRefinement);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // TORCHMLIR_C_TORCHOPS_H
|
|
@ -1,6 +1,7 @@
|
|||
add_mlir_public_c_api_library(TorchMLIRCAPI
|
||||
Dialects.cpp
|
||||
Registration.cpp
|
||||
TorchOps.cpp
|
||||
TorchTypes.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
//===- TorchOps.cpp - C Interface for torch ops ---------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir-c/TorchOps.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirValue torchMlirAdjustStaticInformation(MlirBlock block_,
|
||||
MlirOperation insertBefore_,
|
||||
MlirValue value_,
|
||||
MlirType desiredType_,
|
||||
bool userAllowsRefinement) {
|
||||
Block *block = unwrap(block_);
|
||||
Operation *insertBefore = unwrap(insertBefore_);
|
||||
OpBuilder builder(unwrap(mlirTypeGetContext(desiredType_)));
|
||||
builder.setInsertionPoint(block, insertBefore ? insertBefore->getIterator()
|
||||
: block->end());
|
||||
|
||||
Value value = unwrap(value_);
|
||||
Type type = value.getType();
|
||||
Type desiredType = unwrap(desiredType_);
|
||||
|
||||
// If the value is already of the desired type, we're done.
|
||||
if (type == desiredType)
|
||||
return wrap(value);
|
||||
|
||||
// If the type is a tensor, then adjust the static information.
|
||||
if ((type.isa<Torch::ValueTensorType>() &&
|
||||
desiredType.isa<Torch::ValueTensorType>()) ||
|
||||
(type.isa<Torch::NonValueTensorType>() &&
|
||||
desiredType.isa<Torch::NonValueTensorType>())) {
|
||||
Value adjusted = builder.create<Torch::TensorStaticInfoCastOp>(
|
||||
value.getLoc(), desiredType, value);
|
||||
return wrap(adjusted);
|
||||
}
|
||||
|
||||
// If the type is a subtype of desiredType, then we need to derefine it to
|
||||
// desiredType, unless the user allows refinement.
|
||||
if (Torch::isValidSubtype(type, desiredType)) {
|
||||
if (!userAllowsRefinement) {
|
||||
Value adjusted =
|
||||
builder.create<Torch::DerefineOp>(value.getLoc(), desiredType, value);
|
||||
return wrap(adjusted);
|
||||
} else {
|
||||
return wrap(value);
|
||||
}
|
||||
}
|
||||
|
||||
// If the desiredType is subtype of type, then we assume that the desiredType
|
||||
// is dynamically valid, so we do an unchecked cast.
|
||||
if (Torch::isValidSubtype(desiredType, type)) {
|
||||
Value adjusted = builder.create<Torch::PrimUncheckedCastOp>(
|
||||
value.getLoc(), desiredType, value);
|
||||
return wrap(adjusted);
|
||||
}
|
||||
|
||||
// No known adjustment.
|
||||
return {};
|
||||
}
|
|
@ -34,7 +34,8 @@ class OutputType(Enum):
|
|||
|
||||
def compile(model: torch.nn.Module,
|
||||
example_args: List[torch.Tensor],
|
||||
output_type: OutputType = OutputType.TORCH):
|
||||
output_type: OutputType = OutputType.TORCH,
|
||||
use_tracing=False):
|
||||
"""Convert a PyTorch model to MLIR.
|
||||
|
||||
Args:
|
||||
|
@ -44,6 +45,8 @@ def compile(model: torch.nn.Module,
|
|||
A single tensor is treated as a list of a single tensor.
|
||||
output_type: The kind of output to produce. See `OutputType` for more
|
||||
details.
|
||||
use_tracing: If True, use `torch.jit.trace` to convert the model to
|
||||
JIT IR rather than `torch.jit.script`.
|
||||
|
||||
Returns:
|
||||
An MLIR module that contains the converted model in the specified
|
||||
|
@ -55,8 +58,10 @@ def compile(model: torch.nn.Module,
|
|||
# TODO: Support dynamic dimension sizes. See `torch.onnx.export`'s
|
||||
# `dynamic_axes` for API inspiration, or do something more ergonomic
|
||||
# like a tensor wrapper possibly.
|
||||
# TODO: Support tracing the model instead of scripting it.
|
||||
scripted = torch.jit.script(model)
|
||||
if use_tracing:
|
||||
scripted = torch.jit.trace(model, tuple(example_args))
|
||||
else:
|
||||
scripted = torch.jit.script(model)
|
||||
|
||||
if isinstance(example_args, torch.Tensor):
|
||||
example_args = [example_args]
|
||||
|
|
|
@ -63,9 +63,10 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
|||
}
|
||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||
MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(
|
||||
appendToBlock, "func.return", loc,
|
||||
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||
createMlirOperationAtEnd(appendToBlock, "func.return", loc,
|
||||
adjustStaticInformationForValues(
|
||||
appendToBlock, loc, yieldedValues, resultTypes,
|
||||
/*userAllowsRefinement=*/false));
|
||||
};
|
||||
MlirBlock block = importBlock(
|
||||
context, torch::jit::toGraphFunction(*function).graph()->block(),
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "torch-mlir-c/TorchOps.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
@ -114,16 +115,40 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
}
|
||||
|
||||
// Builtin interpreter ops with no operator/schema.
|
||||
InputsTransformFn transformer =
|
||||
kind != c10::prim::DictConstruct ? nullptr : rearrangeDictConstructInputs;
|
||||
switch (kind) {
|
||||
case c10::prim::ListUnpack:
|
||||
case c10::prim::ListConstruct:
|
||||
case c10::prim::TupleConstruct:
|
||||
case c10::prim::DictConstruct:
|
||||
case c10::prim::CreateObject: {
|
||||
createAndMapTrivialNode(
|
||||
node, "torch.prim." + std::string(kind.toUnqualString()), transformer);
|
||||
node, "torch.prim." + std::string(kind.toUnqualString()), nullptr);
|
||||
return;
|
||||
}
|
||||
case c10::prim::TupleConstruct: {
|
||||
// TODO: We will probably need to adjust the static information for
|
||||
// ListConstruct and DictConstruct too.
|
||||
auto containedTypes = c10::fmap(
|
||||
node->output()->type()->cast<c10::TupleType>()->containedTypes(),
|
||||
[&](const c10::TypePtr &t) {
|
||||
MlirType type = getMlirTypeFromTorchType(loc, t);
|
||||
if (mlirTypeIsNull(type)) {
|
||||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
return type;
|
||||
});
|
||||
createAndMapTrivialNode(node,
|
||||
"torch.prim." + std::string(kind.toUnqualString()),
|
||||
[&](std::vector<MlirValue> &inputs) {
|
||||
assert(containedTypes.size() == inputs.size());
|
||||
return adjustStaticInformationForValues(
|
||||
appendToBlock, loc, inputs, containedTypes,
|
||||
/*userAllowsRefinement=*/true);
|
||||
});
|
||||
return;
|
||||
}
|
||||
case c10::prim::DictConstruct: {
|
||||
createAndMapTrivialNode(node,
|
||||
"torch.prim." + std::string(kind.toUnqualString()),
|
||||
rearrangeDictConstructInputs);
|
||||
return;
|
||||
}
|
||||
case c10::prim::GetAttr:
|
||||
|
@ -213,8 +238,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.Loop", loc, resultTypes,
|
||||
lookupMappedValues(node->inputs().slice(0, 2)),
|
||||
derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
|
||||
loc, appendToBlock),
|
||||
adjustStaticInformationForValues(
|
||||
appendToBlock, loc, lookupMappedValues(node->inputs().slice(2)),
|
||||
resultTypes, /*userAllowsRefinement=*/false),
|
||||
mlirRegionCreate());
|
||||
mapResults(node, operation);
|
||||
std::vector<MlirType> terminatorOperandTypes = {
|
||||
|
@ -223,10 +249,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
resultTypes.begin(), resultTypes.end());
|
||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||
MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
|
||||
derefineValues(yieldedValues,
|
||||
terminatorOperandTypes, loc,
|
||||
appendToBlock));
|
||||
createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.Loop.condition", loc,
|
||||
adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
|
||||
terminatorOperandTypes,
|
||||
/*userAllowsRefinement=*/false));
|
||||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
|
@ -245,7 +272,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.If.yield", loc,
|
||||
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||
adjustStaticInformationForValues(appendToBlock, loc, yieldedValues,
|
||||
resultTypes,
|
||||
/*userAllowsRefinement=*/false));
|
||||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
|
@ -269,8 +298,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.CallMethod", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
derefineValues(lookupMappedValues(node->inputs()), expectedTypes, loc,
|
||||
appendToBlock),
|
||||
adjustStaticInformationForValues(
|
||||
appendToBlock, loc, lookupMappedValues(node->inputs()),
|
||||
expectedTypes, /*userAllowsRefinement=*/false),
|
||||
toMlirNamedAttribute("name",
|
||||
importAttribute(loc, node, c10::attr::name)));
|
||||
mapResults(node, operation);
|
||||
|
@ -288,8 +318,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
appendToBlock, "func.call_indirect", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValue(node->input(0)),
|
||||
derefineValues(lookupMappedValues(node->inputs().slice(1)),
|
||||
expectedTypes, loc, appendToBlock));
|
||||
adjustStaticInformationForValues(
|
||||
appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)),
|
||||
expectedTypes, /*userAllowsRefinement=*/false));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
@ -315,36 +346,6 @@ MlirBlock NodeImporter::importBlock(
|
|||
return block;
|
||||
}
|
||||
|
||||
static MlirValue adjustBlockArgType(MlirContext context,
|
||||
MlirBlock appendToBlock, MlirValue value,
|
||||
MlirType expectedType, MlirLocation loc) {
|
||||
MlirType type = mlirValueGetType(value);
|
||||
if (mlirTypeEqual(type, expectedType)) {
|
||||
return value;
|
||||
}
|
||||
// For tensors, we might need to erase or add static type information.
|
||||
if (torchMlirTypeIsATorchNonValueTensor(type) ||
|
||||
torchMlirTypeIsATorchValueTensor(type)) {
|
||||
MlirOperation op =
|
||||
createMlirOperationAtEnd(appendToBlock, "torch.tensor_static_info_cast",
|
||||
loc, expectedType, value);
|
||||
return mlirOperationGetResult(op, 0);
|
||||
}
|
||||
{
|
||||
std::stringstream msg;
|
||||
MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
|
||||
std::stringstream *stream = static_cast<std::stringstream *>(userData);
|
||||
stream->write(str.data, str.length);
|
||||
};
|
||||
msg << "unhandled: could not adjust formal param type from ";
|
||||
mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
|
||||
msg << " to expected type ";
|
||||
mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
|
||||
mlirEmitError(loc, msg.str().c_str());
|
||||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
}
|
||||
|
||||
MlirBlock NodeImporter::createBlockFor(
|
||||
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
|
||||
Node *paramNode = jitBlock->param_node();
|
||||
|
@ -362,8 +363,9 @@ MlirBlock NodeImporter::createBlockFor(
|
|||
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
|
||||
Value *jitValue = paramNode->outputs()[i];
|
||||
MlirValue value = mlirBlockGetArgument(block, i);
|
||||
MlirValue adjusted =
|
||||
adjustBlockArgType(context, block, value, paramNodeTypes[i], loc);
|
||||
MlirValue adjusted = adjustStaticInformationForValues(
|
||||
block, loc, {value}, {paramNodeTypes[i]},
|
||||
/*userAllowsRefinement=*/false)[0];
|
||||
mapValue(jitValue, adjusted);
|
||||
}
|
||||
return block;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "torch-mlir-c/TorchOps.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
@ -380,24 +381,34 @@ torch_mlir::getMlirTypesFromValues(MlirLocation loc,
|
|||
return ret;
|
||||
}
|
||||
|
||||
std::vector<MlirValue>
|
||||
torch_mlir::derefineValues(c10::ArrayRef<MlirValue> values,
|
||||
c10::ArrayRef<MlirType> expectedTypes,
|
||||
MlirLocation loc, MlirBlock appendToBlock) {
|
||||
std::vector<MlirValue> torch_mlir::adjustStaticInformationForValues(
|
||||
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
|
||||
c10::ArrayRef<MlirType> desiredTypes, bool userAllowsRefinement) {
|
||||
std::vector<MlirValue> ret;
|
||||
assert(values.size() == expectedTypes.size());
|
||||
assert(values.size() == desiredTypes.size());
|
||||
for (int i = 0, e = values.size(); i != e; i++) {
|
||||
MlirValue value = values[i];
|
||||
MlirType expectedType = expectedTypes[i];
|
||||
MlirType expectedType = desiredTypes[i];
|
||||
MlirType type = mlirValueGetType(value);
|
||||
if (mlirTypeEqual(expectedType, type)) {
|
||||
// No need to derefine.
|
||||
ret.push_back(value);
|
||||
} else {
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.derefine", loc, expectedType, value);
|
||||
ret.push_back(mlirOperationGetResult(operation, 0));
|
||||
MlirValue adjusted = torchMlirAdjustStaticInformation(
|
||||
appendToBlock, mlirBlockGetTerminator(appendToBlock), value,
|
||||
expectedType, userAllowsRefinement);
|
||||
if (!mlirValueIsNull(adjusted)) {
|
||||
ret.push_back(adjusted);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::stringstream msg;
|
||||
MlirStringCallback printToStream = +[](MlirStringRef str, void *userData) {
|
||||
std::stringstream *stream = static_cast<std::stringstream *>(userData);
|
||||
stream->write(str.data, str.length);
|
||||
};
|
||||
msg << "unhandled: could not adjust static info for type from ";
|
||||
mlirTypePrint(type, printToStream, static_cast<void *>(&msg));
|
||||
msg << " to type ";
|
||||
mlirTypePrint(expectedType, printToStream, static_cast<void *>(&msg));
|
||||
mlirEmitError(loc, msg.str().c_str());
|
||||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -60,10 +60,9 @@ std::vector<MlirType>
|
|||
getMlirTypesFromValues(MlirLocation loc,
|
||||
c10::ArrayRef<torch::jit::Value *> values);
|
||||
|
||||
std::vector<MlirValue> derefineValues(c10::ArrayRef<MlirValue> values,
|
||||
c10::ArrayRef<MlirType> expectedTypes,
|
||||
MlirLocation loc,
|
||||
MlirBlock appendToBlock);
|
||||
std::vector<MlirValue> adjustStaticInformationForValues(
|
||||
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
|
||||
c10::ArrayRef<MlirType> desiredTypes, bool userAllowsRefinement);
|
||||
|
||||
/// Create the appropriate MLIR operation for the Torch operator with schema
|
||||
/// "schema".
|
||||
|
|
|
@ -2,31 +2,23 @@
|
|||
# This file is licensed under a pytorch-style license
|
||||
# See LICENSE.pytorch for license information.
|
||||
|
||||
import torch
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
|
||||
from torch._C import CompilationUnit
|
||||
|
||||
from utils import create_script_function
|
||||
|
||||
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
|
||||
|
||||
# Import TorchScript IR string as ScriptFunction.
|
||||
def create_script_function(func_name, ts_ir_str):
|
||||
cu = CompilationUnit()
|
||||
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
|
||||
mb = ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @__torch__.refined_block_arg(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
# CHECK: %[[REFINED:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.tensor to !torch.tensor<[1,384],f32>
|
||||
# CHECK: %[[RESULT:.*]] = torch.derefine %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor
|
||||
# CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[REFINED]] : !torch.tensor<[1,384],f32> to !torch.tensor
|
||||
# CHECK: return %[[RESULT]] : !torch.tensor
|
||||
script_function = create_script_function('__torch__.refined_block_arg', '''
|
||||
mb.import_function(create_script_function("__torch__.refined_block_arg", """
|
||||
graph(%0 : Float(1, 384)):
|
||||
return (%0)
|
||||
''')
|
||||
|
||||
mb = ModuleBuilder()
|
||||
mb.import_function(script_function)
|
||||
"""))
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -5,21 +5,16 @@
|
|||
import typing
|
||||
|
||||
import torch
|
||||
from torch._C import CompilationUnit
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
|
||||
from utils import create_script_function
|
||||
|
||||
import typing
|
||||
|
||||
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
|
||||
|
||||
mb = ModuleBuilder()
|
||||
|
||||
|
||||
# Import TorchScript IR string as ScriptFunction.
|
||||
def import_ts_ir(func_name, ts_ir_str):
|
||||
cu = CompilationUnit()
|
||||
mb.import_function(cu.create_function(func_name, torch._C.parse_ir(ts_ir_str)))
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_NumToTensor(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor
|
||||
|
@ -159,9 +154,11 @@ def prim_max(x: int):
|
|||
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[A]], %[[B]], %[[C]] :
|
||||
# CHECK-SAME: (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
# CHECK: return %[[RET]] : !torch.list<int>
|
||||
import_ts_ir('__torch__.prim_Constant_list', '''graph():
|
||||
mb.import_function(create_script_function("__torch__.prim_Constant_list", """
|
||||
graph():
|
||||
%list : int[] = prim::Constant[value=[1, 2, 3]]()
|
||||
return (%list)''')
|
||||
return (%list)
|
||||
"""))
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -4,9 +4,10 @@
|
|||
|
||||
import torch
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
|
||||
import collections
|
||||
from typing import Tuple, Optional, NamedTuple
|
||||
|
||||
from utils import create_script_function
|
||||
|
||||
# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
|
||||
|
||||
mb = ModuleBuilder()
|
||||
|
@ -20,8 +21,6 @@ NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
|
|||
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] :
|
||||
# CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor>
|
||||
# CHECK: return %[[RET]] : !torch.tuple<tensor, tensor>
|
||||
|
||||
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def tuple(t0, t1):
|
||||
|
@ -38,8 +37,6 @@ def tuple(t0, t1):
|
|||
# CHECK-SAME: !torch.tuple<tensor, tensor> to
|
||||
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>>
|
||||
# CHECK: return %[[RET]] : !torch.tuple<optional<tensor>, optional<tensor>>
|
||||
|
||||
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def tuple_optional(
|
||||
|
@ -55,8 +52,6 @@ def tuple_optional(
|
|||
# CHECK-SAME: !torch.tensor, !torch.tensor ->
|
||||
# CHECK-SAME: !torch.tuple<optional<tensor>, optional<tensor>>
|
||||
# CHECK: return %[[RET]] : !torch.tuple<optional<tensor>, optional<tensor>>
|
||||
# CHECK: }
|
||||
#
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def namedtuple_optional(
|
||||
|
@ -64,5 +59,20 @@ def namedtuple_optional(
|
|||
return NT(t0, t1)
|
||||
|
||||
|
||||
# CHECK-LABEL: func @__torch__.tuple_construct_arg_needs_refinement(
|
||||
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<tensor, tensor> {
|
||||
# CHECK: %[[T0_REFINED:.*]] = torch.tensor_static_info_cast %[[T1]] : !torch.tensor to !torch.tensor<[4],f32>
|
||||
# CHECK: %[[T1_REFINED:.*]] = torch.tensor_static_info_cast %[[T0]] : !torch.tensor to !torch.tensor<[3],f32>
|
||||
# CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[T0_REFINED]], %[[T1_REFINED]] : !torch.tensor<[4],f32>, !torch.tensor<[3],f32> -> !torch.tuple<tensor<[4],f32>, tensor<[3],f32>>
|
||||
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[TUPLE]] : !torch.tuple<tensor<[4],f32>, tensor<[3],f32>> to !torch.tuple<tensor, tensor>
|
||||
# CHECK: return %[[DEREFINED]] : !torch.tuple<tensor, tensor>
|
||||
mb.import_function(create_script_function("__torch__.tuple_construct_arg_needs_refinement", """
|
||||
graph(%t0 : Tensor,
|
||||
%t1 : Tensor):
|
||||
%10 : (Float(4), Float(3)) = prim::TupleConstruct(%t1, %t0)
|
||||
return (%10)
|
||||
"""))
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See LICENSE.pytorch for license information.
|
||||
|
||||
# Helpers for the other tests.
|
||||
|
||||
import torch
|
||||
from torch._C import CompilationUnit
|
||||
|
||||
# RUN: %PYTHON %s
|
||||
|
||||
# Import TorchScript IR string as ScriptFunction.
|
||||
def create_script_function(func_name, ts_ir_str):
|
||||
cu = CompilationUnit()
|
||||
return cu.create_function(func_name, torch._C.parse_ir(ts_ir_str))
|
Loading…
Reference in New Issue