Properly model "derefinement".

In terms of IR structure, TorchScript allows types to vary in many
circumstances where MLIR requires pointer-identical types. In particular,
it is valid to pass any subtype in place of a type. For example, if an
`Optional[int]` is required somewhere in the IR, it is legal to pass a
value of just `int` (but not the other way around; see
`torch.prim.unchecked_cast`). In effect, every *use* can have a different
type.

We introduce a new op `torch.derefine` that models that impedance
mismatch. This op allows casting a value from one type to a type that it
is a subtype of to model this behavior.

Recommended review order:
- TorchOps.td for new torch.derefine (and updated docs for
  `torch.prim.unchecked_cast`)
- new test code in if.py, loop.py, function-derefine.py
- new code in node_importer.cpp for handling derefinement insertion
- function_importer.cpp and utils changes in torch_to_mlir_utils.cpp

Properly handling derefinement on function boundaries required
relayering the code so that graph_importer.cpp/.h is now
function_importer.cpp/.h because only the `torch::jit::Function`
(actually the `c10::FunctionSchema` it holds) knows the derefined types that are
actually needed at the boundary (see `function-derefine.py` for a test).

Annoyingly, this churns all the functions which are now prefixed with
`__torch__.` but that is more correct anyway (that is their linkage name
in the `torch::jit::CompilationUnit`; the previous `mb.import_function`
was actually buggy in the case of functions calling each other as it
would reference their unqualified name).

With this change, we can import `resnet18` from `torchvision` :)
IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704
pull/178/head
Sean Silva 2021-03-01 17:24:15 -08:00
parent 1736ff0253
commit 43dba03afd
26 changed files with 371 additions and 154 deletions

View File

@ -15,7 +15,7 @@ add_library(NPCOMPTorchMLIRExt SHARED
builder/class_annotator.cpp
builder/debug.cpp
builder/func_builder.cpp
builder/graph_importer.cpp
builder/function_importer.cpp
builder/module_builder.cpp
builder/node_importer.cpp
builder/op_builder.cpp

View File

@ -0,0 +1,56 @@
//===- function_importer.cpp ----------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "function_importer.h"
#include <unordered_map>
#include "mlir_utils.h"
#include "torch_to_mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
namespace py = pybind11;
using namespace torch_mlir;
MlirOperation
torch_mlir::importJitFunctionAsFuncOp(MlirContext context,
torch::jit::Function *function) {
// Useful for debugging:
// graph->dump();
MlirLocation loc = mlirLocationUnknownGet(context);
MlirType functionType =
getFunctionTypeFromSchema(context, function->getSchema());
// Use the function's qualified name from the compilation unit.
// This is a stable linkage name that matches Python module lookup
// conventions (see compilation unit import in IValueImporter for more details
// on qualified names).
MlirAttribute symNameAttr = mlirStringAttrGet(
context, toMlirStringRef(function->qualname().qualifiedName()));
MlirOperation func = createMlirOperation(
"func", loc, mlirRegionCreate(),
toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)),
toMlirNamedAttribute("sym_name", symNameAttr));
MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);
std::vector<MlirType> resultTypes;
for (int i = 0, e = mlirFunctionTypeGetNumResults(functionType); i != e;
i++) {
resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i));
}
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) {
createMlirOperationAtEnd(
appendToBlock, "std.return", loc,
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
};
MlirBlock block =
importBlock(context, function->graph()->block(), createTerminator);
mlirRegionAppendOwnedBlock(bodyRegion, block);
return func;
}

View File

@ -0,0 +1,40 @@
//===- function_importer.h --------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "node_importer.h"
#include "mlir-c/IR.h"
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {
/// Main entry-point for importing torch::jit::Function instances.
///
/// This code doesn't handle importing of torch::jit::Module's. See
/// IValueImporter for that.
///
/// A torch::jit::Function holds a c10::FunctionSchema along with a
/// c10::QualifiedName and a torch::jit::Graph.
///
/// The torch::jit::Graph is a combination of an MLIR context, function, and
/// builder. See NodeImporter for importing of the core IR Node/Block
/// structure that is analogous to MLIR's Operation/Region/Block core structure.
MlirOperation importJitFunctionAsFuncOp(MlirContext context,
torch::jit::Function *function);
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H

View File

@ -1,38 +0,0 @@
//===- graph_importer.cpp -------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "graph_importer.h"
#include <unordered_map>
#include "mlir_utils.h"
#include "torch_to_mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
namespace py = pybind11;
using namespace torch_mlir;
MlirOperation torch_mlir::importGraphAsFuncOp(MlirContext context,
torch::jit::Graph *graph,
const std::string &name) {
// Useful for debugging:
// graph->dump();
MlirLocation loc = mlirLocationUnknownGet(context);
MlirAttribute typeAttr =
mlirTypeAttrGet(getFunctionTypeFromBlock(context, graph->block()));
MlirAttribute symNameAttr = mlirStringAttrGet(context, toMlirStringRef(name));
MlirOperation func = createMlirOperation(
"func", loc, mlirRegionCreate(), toMlirNamedAttribute("type", typeAttr),
toMlirNamedAttribute("sym_name", symNameAttr));
MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);
MlirBlock block = importBlock(context, graph->block(), "std.return");
mlirRegionAppendOwnedBlock(bodyRegion, block);
return func;
}

View File

@ -1,37 +0,0 @@
//===- graph_importer.h -----------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "node_importer.h"
#include "mlir-c/IR.h"
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {
/// Main entry-point for importing torch::jit::Graph instances.
///
/// This code doesn't handle importing of torch::jit::Module's. See
/// IValueImporter for that.
///
/// A Graph is a combination of an MLIR context, function, and builder.
/// See NodeImporter for importing of the core IR Node/Block
/// structure that is analogous to MLIR's Operation/Region/Block core structure.
MlirOperation importGraphAsFuncOp(MlirContext context, torch::jit::Graph *graph,
const std::string &name);
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H

View File

@ -7,7 +7,7 @@
#include "ivalue_importer.h"
#include "class_annotator.h"
#include "graph_importer.h"
#include "function_importer.h"
#include <unordered_map>
@ -393,8 +393,7 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
}
for (torch::jit::Function *function : cu->get_functions()) {
MlirOperation func = importGraphAsFuncOp(
context, function->graph().get(), function->qualname().qualifiedName());
MlirOperation func = importJitFunctionAsFuncOp(context, function);
// For IValue importing, the logical linkage structure of the module
// is determined by the object graph.
//

View File

@ -7,7 +7,7 @@
#include "module_builder.h"
#include "graph_importer.h"
#include "function_importer.h"
#include "ivalue_importer.h"
#include "mlir-c/Bindings/Python/Interop.h"
@ -128,8 +128,7 @@ torch::jit::StrongFunctionPtr
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
MlirBlock block = getBodyBlock();
MlirOperation terminator = this->terminator;
MlirOperation func = importGraphAsFuncOp(
context, function.function_->graph().get(), function.function_->name());
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_);
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
return function;
}

View File

@ -30,7 +30,7 @@ public:
NodeImporter(MlirContext context) : context(context) {}
void importNode(Node *node, MlirBlock appendToBlock);
MlirBlock importBlock(Block *jitBlock, const std::string &terminatorOpName);
MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);
private:
void importPrimNode(Node *node, MlirBlock appendToBlock);
@ -70,7 +70,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
const std::string &symName = function->qualname().qualifiedName();
op = createMlirOperation(
"std.constant", loc,
getFunctionTypeFromBlock(context, function->graph()->block()),
getFunctionTypeFromSchema(context, function->getSchema()),
toMlirNamedAttribute(
"value",
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
@ -141,14 +141,28 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
}
if (kind == c10::prim::Loop) {
std::vector<MlirType> resultTypes =
getMlirTypesFromValues(loc, node->outputs());
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.Loop", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()), mlirRegionCreate());
appendToBlock, "torch.prim.Loop", loc, resultTypes,
lookupMappedValues(node->inputs().slice(0, 2)),
derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
loc, appendToBlock),
mlirRegionCreate());
mapResults(node, operation);
std::vector<MlirType> terminatorOperandTypes = {npcompBoolTypeGet(context)};
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
resultTypes.begin(), resultTypes.end());
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) {
createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
derefineValues(yieldedValues,
terminatorOperandTypes, loc,
appendToBlock));
};
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], "torch.prim.Loop.condition"));
importBlock(node->blocks()[0], createTerminator));
return;
}
@ -158,15 +172,24 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
MlirOperation pred = createMlirOperationAtEnd(
appendToBlock, "basicpy.bool_cast", loc, mlirIntegerTypeGet(context, 1),
lookupMappedValue(node->input()));
std::vector<MlirType> resultTypes =
getMlirTypesFromValues(loc, node->outputs());
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
getMlirTypesFromValues(loc, node->outputs()), mlirRegionCreate(),
mlirRegionCreate());
resultTypes, mlirRegionCreate(), mlirRegionCreate());
mapResults(node, operation);
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], "scf.yield"));
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1),
importBlock(node->blocks()[1], "scf.yield"));
auto createTerminator =
[&](c10::ArrayRef<MlirValue> yieldedValues, MlirBlock appendToBlock) {
createMlirOperationAtEnd(
appendToBlock, "scf.yield", loc,
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
};
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], createTerminator));
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 1),
importBlock(node->blocks()[1], createTerminator));
return;
}
@ -180,10 +203,18 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
}
if (kind == c10::prim::CallFunction) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, "std.call_indirect", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs()));
auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
torch::jit::Block *calleeEntryBlock =
functionType->function()->graph()->block();
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
return typeMapper.mapFromTorchType(loc, v->type());
});
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "std.call_indirect", loc,
getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValue(node->input(0)),
derefineValues(lookupMappedValues(node->inputs().slice(1)),
expectedTypes, loc, appendToBlock));
mapResults(node, operation);
return;
}
@ -288,15 +319,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
}
MlirBlock NodeImporter::importBlock(Block *jitBlock,
const std::string &terminatorOpName) {
CreateTerminatorFn createTerminator) {
MlirBlock block = createBlockFor(jitBlock);
for (Node *node : jitBlock->nodes()) {
importNode(node, block);
}
Node *returnNode = jitBlock->return_node();
createMlirOperationAtEnd(block, terminatorOpName,
getMlirLocationFromNode(context, returnNode),
lookupMappedValues(returnNode->inputs()));
createTerminator(lookupMappedValues(returnNode->inputs()), block);
return block;
}
@ -343,7 +372,7 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
}
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
const std::string &terminatorOpName) {
CreateTerminatorFn createTerminator) {
NodeImporter importer(context);
return importer.importBlock(jitBlock, terminatorOpName);
return importer.importBlock(jitBlock, createTerminator);
}

View File

@ -20,8 +20,11 @@
namespace torch_mlir {
using CreateTerminatorFn =
std::function<void(c10::ArrayRef<MlirValue>, MlirBlock)>;
MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock,
const std::string &terminatorOpName);
CreateTerminatorFn createTerminator);
} // namespace torch_mlir

View File

@ -5,7 +5,7 @@
//
//===----------------------------------------------------------------------===//
#include "graph_importer.h"
#include "function_importer.h"
#include "ivalue_importer.h"
#include <unordered_map>
@ -163,17 +163,28 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
}
MlirType torch_mlir::getFunctionTypeFromBlock(MlirContext context,
torch::jit::Block *block) {
MlirLocation inputLoc = getMlirLocationFromNode(context, block->param_node());
MlirType
torch_mlir::getFunctionTypeFromSchema(MlirContext context,
const c10::FunctionSchema &schema) {
MlirLocation loc = mlirLocationUnknownGet(context);
TypeMapper typeMapper(context);
auto mapType = [&](const c10::TypePtr &torchType) {
MlirType type = typeMapper.mapFromTorchType(loc, torchType);
if (mlirTypeIsNull(type)) {
std::stringstream msg;
msg << "unsupported type in function schema: '"
<< c10::toString(torchType) << "'";
throw std::invalid_argument(msg.str());
}
return type;
};
std::vector<MlirType> inputTypes =
getMlirTypesFromValues(inputLoc, block->param_node()->outputs());
MlirLocation outputLoc =
getMlirLocationFromNode(context, block->return_node());
c10::fmap(schema.arguments(),
[&](const c10::Argument &arg) { return mapType(arg.type()); });
std::vector<MlirType> outputTypes =
getMlirTypesFromValues(outputLoc, block->return_node()->inputs());
c10::fmap(schema.returns(),
[&](const c10::Argument &arg) { return mapType(arg.type()); });
return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
outputTypes.size(), outputTypes.data());
}
@ -303,3 +314,25 @@ torch_mlir::getMlirTypesFromValues(MlirLocation loc,
}
return ret;
}
std::vector<MlirValue>
torch_mlir::derefineValues(c10::ArrayRef<MlirValue> values,
c10::ArrayRef<MlirType> expectedTypes,
MlirLocation loc, MlirBlock appendToBlock) {
std::vector<MlirValue> ret;
assert(values.size() == expectedTypes.size());
for (int i = 0, e = values.size(); i != e; i++) {
MlirValue value = values[i];
MlirType expectedType = expectedTypes[i];
MlirType type = mlirValueGetType(value);
if (mlirTypeEqual(expectedType, type)) {
// No need to derefine.
ret.push_back(value);
} else {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.derefine", loc, expectedType, value);
ret.push_back(mlirOperationGetResult(operation, 0));
}
}
return ret;
}

View File

@ -51,18 +51,12 @@ private:
MlirContext context;
};
/// Creates a FunctionType suitable for expressing the signature of `block`.
/// Creates a FunctionType suitable for expressing the signature of `schema`.
///
/// `mlir::Block` only has a formalized notion of argument types (bb args), but
/// the exact nature of the block's terminator is left opaque (for example, you
/// can have a weird terminator that "returns all but the first operand").
/// `torch::jit::Block` on the other hand has a formalized notion of a
/// `param_node` and `return_node`, which are effectively dummy operations at
/// the start and end of the block, which establish a formal signature for the
/// block and can be generically reasoned about -- that is what we anchor on
/// here.
MlirType getFunctionTypeFromBlock(MlirContext context,
torch::jit::Block *block);
/// This can differ from the type inferred from the block of a
/// torch::jit::Function due to derefinement.
MlirType getFunctionTypeFromSchema(MlirContext context,
const c10::FunctionSchema &schema);
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
MlirAttribute converTensorToMlirElementsAttr(at::Tensor tensor,
@ -78,6 +72,11 @@ std::vector<MlirType>
getMlirTypesFromValues(MlirLocation loc,
c10::ArrayRef<torch::jit::Value *> values);
std::vector<MlirValue> derefineValues(c10::ArrayRef<MlirValue> values,
c10::ArrayRef<MlirType> expectedTypes,
MlirLocation loc,
MlirBlock appendToBlock);
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H

View File

@ -10,7 +10,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder()
# Verify without debug info.
# CHECK-LABEL: func @add3
# CHECK-LABEL: func @__torch__.add3
# CHECK-SAME: (%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
# CHECK: %[[C1:.*]] = constant 1 : i64
# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @add3
# CHECK-LABEL: func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently
# this includes naked constants, return and the function itself. This heuristic

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: @f
# CHECK-LABEL: @__torch__.f
@mb.import_function
@torch.jit.script
def f(b: bool, i: int):

View File

@ -17,8 +17,8 @@ try:
@torch.jit.script
def import_class(x: typing.Any):
return x
except RuntimeError as e:
except Exception as e:
# TODO: Once diagnostics are enabled, verify the actual error emitted.
assert str(e) == "unsupported type"
assert str(e) == "unsupported type in function schema: 'Any'"
else:
assert False, "Expected exception"

View File

@ -0,0 +1,42 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import typing
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.optional<i64> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
# CHECK: return %[[RET]] : !torch.optional<i64>
@mb.import_function
@torch.jit.script
def optional_return(i: int) -> typing.Optional[int]:
return i
# CHECK-LABEL: func @__torch__.optional_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<i64>) -> !basicpy.NoneType {
@mb.import_function
@torch.jit.script
def optional_arg(i: typing.Optional[int]) -> None:
return
# CHECK-LABEL: func @__torch__.calls_optional_arg(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.NoneType {
# CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<i64>) -> !basicpy.NoneType
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
# CHECK: %{{.*}} = call_indirect %[[CALLEE]](%[[DEREFINED]]) : (!torch.optional<i64>) -> !basicpy.NoneType
@mb.import_function
@torch.jit.script
def calls_optional_arg(i: int):
optional_arg(i)
mb.module.operation.print()
print()

View File

@ -9,12 +9,12 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: @f(
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
# CHECK-SAME: %[[I:.*]]: i64) -> i64 {
# CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
# CHECK-SAME: %[[I:.*]]: i64) -> i64 {
@mb.import_function
@torch.jit.script
def f(b: bool, i: int):
def prim_If(b: bool, i: int):
# CHECK: %[[I1:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1
# CHECK: %[[RES:.*]] = scf.if %[[I1]] -> (i64) {
# CHECK: %[[ADD:.*]] = torch.kernel_call "aten::add" %[[I]], %[[I]]
@ -30,6 +30,25 @@ def f(b: bool, i: int):
return i * i
# elif is modeled as a nested if, so no need to specially test it here.
assert isinstance(f, torch.jit.ScriptFunction)
# CHECK-LABEL: func @__torch__.prim_If_derefine(
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
# CHECK-SAME: %[[I:.*]]: i64) -> !torch.optional<i64> {
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[PRED:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1
# CHECK: %[[RES:.*]] = scf.if %[[PRED]] -> (!torch.optional<i64>) {
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
# CHECK: scf.yield %[[NONE_DEREFINED]] : !torch.optional<i64>
# CHECK: } else {
# CHECK: %[[I_DEREFINED:.*]] = torch.derefine %[[I]] : i64 -> !torch.optional<i64>
# CHECK: scf.yield %[[I_DEREFINED]] : !torch.optional<i64>
# CHECK: }
# CHECK: return %[[RES:.*]] : !torch.optional<i64>
@mb.import_function
@torch.jit.script
def prim_If_derefine(b: bool, i: int):
if b:
return None
return i
mb.module.operation.print()
print()

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @f(
# CHECK-LABEL: func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType {
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType

View File

@ -5,11 +5,13 @@
import torch
import torch_mlir
import typing
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @prim_Loop_forlike(
# CHECK-LABEL: func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: i64) -> f64 {
# CHECK: %[[BOOL_TRUE:.*]] = basicpy.bool_constant true
# CHECK: %[[F_INIT:.*]] = constant 0.000000e+00 : f64
@ -27,18 +29,18 @@ def prim_Loop_forlike(n: int):
f += i
return f
# CHECK-LABEL: func @prim_Loop_whilelike(
# CHECK-LABEL: func @__torch__.prim_Loop_whilelike(
# CHECK-SAME: %[[VAL_0:.*]]: i64) -> f64 {
# CHECK: %[[F_INIT:.*]] = constant 3.200000e+00 : f64
# CHECK: %[[MAX_ITERATIONS:.*]] = constant 9223372036854775807 : i64
# CHECK: %[[COND_INIT:.*]] = torch.kernel_call "aten::lt" %[[F_INIT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
# CHECK: %[[IV:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
# CHECK: %[[RET:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
# CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
# CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
# CHECK: return %[[VAL_9:.*]] : f64
# CHECK: return %[[RET:.*]] : f64
@mb.import_function
@torch.jit.script
def prim_Loop_whilelike(n: int):
@ -47,5 +49,24 @@ def prim_Loop_whilelike(n: int):
f = f * f
return f
# CHECK-LABEL: func @__torch__.prim_Loop_derefine(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.optional<i64> {
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
# CHECK: %[[RET:.*]] = torch.prim.Loop %[[ARG]], %[[TRUE]], init(%[[NONE_DEREFINED]]) {
# CHECK: ^bb0(%[[IV:.*]]: i64, %[[X_ITER:.*]]: !torch.optional<i64>):
# CHECK: %[[X_NEXT:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
# CHECK: torch.prim.Loop.condition %[[TRUE]] iter(%[[X_NEXT]]) : !basicpy.BoolType, (!torch.optional<i64>)
# CHECK: } : (i64, !basicpy.BoolType, !torch.optional<i64>) -> !torch.optional<i64>
# CHECK: return %[[RET:.*]] : !torch.optional<i64>
@mb.import_function
@torch.jit.script
def prim_Loop_derefine(n: int):
x: typing.Optional[int] = None
for i in range(n):
x = n
return x
mb.module.operation.print()
print()

View File

@ -14,7 +14,7 @@ import typing
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @prim_NumToTensor(
# CHECK-LABEL: func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
@ -25,7 +25,7 @@ mb = torch_mlir.ModuleBuilder()
def prim_NumToTensor(i: int):
return _to_tensor(i)
# CHECK-LABEL: func @prim_Print(
# CHECK-LABEL: func @__torch__.prim_Print(
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
# CHECK: %[[STR:.*]] = basicpy.bytes_constant "x"
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !basicpy.BytesType, !numpy.ndarray<*:!numpy.any_dtype>
@ -34,7 +34,7 @@ def prim_NumToTensor(i: int):
def prim_Print(x):
print("x", x)
# CHECK-LABEL: func @prim_RaiseException() -> !basicpy.NoneType {
# CHECK-LABEL: func @__torch__.prim_RaiseException() -> !basicpy.NoneType {
# CHECK: %[[ERRORSTR:.*]] = basicpy.bytes_constant "Error"
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !basicpy.NoneType
# CHECK: torch.prim.RaiseException %[[ERRORSTR]]
@ -44,7 +44,7 @@ def prim_Print(x):
def prim_RaiseException():
raise Exception("Error")
# CHECK-LABEL: func @prim_unchecked_cast(
# CHECK-LABEL: func @__torch__.prim_unchecked_cast(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[C3:.*]] = constant 3 : i64
@ -64,7 +64,7 @@ def prim_unchecked_cast(i: typing.Optional[int]):
return 3
return i
# CHECK-LABEL: func @prim_TupleUnpack(
# CHECK-LABEL: func @__torch__.prim_TupleUnpack(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
# CHECK: return %[[RET]]#0 : i64
@ -74,7 +74,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
val, _ = tup
return val
# CHECK-LABEL: func @prim_TupleIndex(
# CHECK-LABEL: func @__torch__.prim_TupleIndex(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !basicpy.TupleType, i64 -> i64
# CHECK: return %[[RET]] : i64
@ -83,7 +83,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
def prim_TupleIndex(tup: typing.Tuple[int, int]):
return tup[0]
# CHECK-LABEL: func @prim_ListUnpack(
# CHECK-LABEL: func @__torch__.prim_ListUnpack(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
# CHECK: return %[[RET]]#1 : i64

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @f(
# CHECK-LABEL: func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType {
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder()
# CHECK: @returns_bool
# CHECK: @__torch__.returns_bool
@mb.import_function
@torch.jit.script
def returns_bool():

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder()
# CHECK: @returns_none
# CHECK: @__torch__.returns_none
@mb.import_function
@torch.jit.script
def returns_none():

View File

@ -14,6 +14,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"

View File

@ -13,6 +13,7 @@ include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Torch_Dialect, mnemonic, traits> {
@ -457,10 +458,59 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", []> {
}];
}
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", []> {
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", [
NoSideEffect
]> {
let summary = "TorchScript prim::unchecked_cast op";
// TODO: This seems to mostly be used for casting "optional" to the contained
// type. Verify that and tighten the verifier.
let description = [{
Refine a type to one of its subtypes.
For example, refine a type that was only statically known to be
Optional[T] to a T when we obtain static information that guarantees it.
The key observation here is that Optional[T] does not have a corresponding
runtime type (i.e. `c10::IValue` subclass). It represents a set of possible
concrete types which for `Optional[T]` is either `None` or a concrete
subtype of `T` (which in the simplest case is just `T`). In particular,
at runtime there is no way to distinguish `Optional[int]` from
`Optional[Optional[int]]`, because both are either `None` or `int`.
This differs from C++ std::optional.
The best documentation of this op is inspection of the code in
`torch/csrc/jit/frontend/ir_emitter.cpp`.
}];
// TODO: When we model PyTorch's notion of subtyping, verify the types here.
let arguments = (ins AnyTorchType:$operand);
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
//===----------------------------------------------------------------------===//
// Additional ops used to model TorchScript's Graph's / Node's.
//===----------------------------------------------------------------------===//
def Torch_DerefineOp : Torch_Op<"derefine", [
NoSideEffect
]> {
let summary = "De-refine a type";
let description = [{
In terms of IR structure, TorchScript allows types to vary in many
circumstances where MLIR requires pointer-identical types. In particular,
it is valid to pass any subtype in place of a type. For example, if an
`Optional[int]` is required somewhere in the IR, it is legal to pass a
value of just `int` (but not the other way around; see
`torch.prim.unchecked_cast`). In effect, every *use* can have a different
type.
This op bridges that impedance mismatch. This op allows casting a value
from one type to a type that it is a subtype of to model this behavior.
}];
// TODO: When we model PyTorch's notion of subtyping, verify the types here.
let arguments = (ins AnyTorchType:$operand);
let results = (outs AnyTorchType:$result);

View File

@ -17,6 +17,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
MLIRIR
MLIRSupport
MLIRControlFlowInterfaces
MLIRSideEffectInterfaces
NPCOMPBasicpyDialect
NPCOMPNumpyDialect
)