Properly model "derefinement".

In terms of IR structure, TorchScript allows types to vary in many
circumstances where MLIR requires pointer-identical types. In particular,
it is valid to pass any subtype in place of a type. For example, if an
`Optional[int]` is required somewhere in the IR, it is legal to pass a
value of just `int` (but not the other way around; see
`torch.prim.unchecked_cast`). In effect, every *use* can have a different
type.

We introduce a new op `torch.derefine` that models that impedance
mismatch. This op allows casting a value from one type to a type that it
is a subtype of to model this behavior.

Recommended review order:
- TorchOps.td for new torch.derefine (and updated docs for
  `torch.prim.unchecked_cast`)
- new test code in if.py, loop.py, function-derefine.py
- new code in node_importer.cpp for handling derefinement insertion
- function_importer.cpp and utils changes in torch_to_mlir_utils.cpp

Properly handling derefinement on function boundaries required
relayering the code so that graph_importer.cpp/.h is now
function_importer.cpp/.h because only the `torch::jit::Function`
(actually the `c10::FunctionSchema` it holds) knows the derefined types that are
actually needed at the boundary (see `function-derefine.py` for a test).

Annoyingly, this churns all the functions which are now prefixed with
`__torch__.` but that is more correct anyway (that is their linkage name
in the `torch::jit::CompilationUnit`; the previous `mb.import_function`
was actually buggy in the case of functions calling each other as it
would reference their unqualified name).

With this change, we can import `resnet18` from `torchvision` :)
IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704
pull/178/head
Sean Silva 2021-03-01 17:24:15 -08:00
parent 1736ff0253
commit 43dba03afd
26 changed files with 371 additions and 154 deletions

View File

@ -15,7 +15,7 @@ add_library(NPCOMPTorchMLIRExt SHARED
builder/class_annotator.cpp builder/class_annotator.cpp
builder/debug.cpp builder/debug.cpp
builder/func_builder.cpp builder/func_builder.cpp
builder/graph_importer.cpp builder/function_importer.cpp
builder/module_builder.cpp builder/module_builder.cpp
builder/node_importer.cpp builder/node_importer.cpp
builder/op_builder.cpp builder/op_builder.cpp

View File

@ -0,0 +1,56 @@
//===- function_importer.cpp ----------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "function_importer.h"
#include <unordered_map>
#include "mlir_utils.h"
#include "torch_to_mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
namespace py = pybind11;
using namespace torch_mlir;
MlirOperation
torch_mlir::importJitFunctionAsFuncOp(MlirContext context,
torch::jit::Function *function) {
// Useful for debugging:
// graph->dump();
MlirLocation loc = mlirLocationUnknownGet(context);
MlirType functionType =
getFunctionTypeFromSchema(context, function->getSchema());
// Use the function's qualified name from the compilation unit.
// This is a stable linkage name that matches Python module lookup
// conventions (see compilation unit import in IValueImporter for more details
// on qualified names).
MlirAttribute symNameAttr = mlirStringAttrGet(
context, toMlirStringRef(function->qualname().qualifiedName()));
MlirOperation func = createMlirOperation(
"func", loc, mlirRegionCreate(),
toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)),
toMlirNamedAttribute("sym_name", symNameAttr));
MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);
std::vector<MlirType> resultTypes;
for (int i = 0, e = mlirFunctionTypeGetNumResults(functionType); i != e;
i++) {
resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i));
}
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) {
createMlirOperationAtEnd(
appendToBlock, "std.return", loc,
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
};
MlirBlock block =
importBlock(context, function->graph()->block(), createTerminator);
mlirRegionAppendOwnedBlock(bodyRegion, block);
return func;
}

View File

@ -0,0 +1,40 @@
//===- function_importer.h --------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "node_importer.h"
#include "mlir-c/IR.h"
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {
/// Main entry-point for importing torch::jit::Function instances.
///
/// This code doesn't handle importing of torch::jit::Module's. See
/// IValueImporter for that.
///
/// A torch::jit::Function holds a c10::FunctionSchema along with a
/// c10::QualifiedName and a torch::jit::Graph.
///
/// The torch::jit::Graph is a combination of an MLIR context, function, and
/// builder. See NodeImporter for importing of the core IR Node/Block
/// structure that is analogous to MLIR's Operation/Region/Block core structure.
MlirOperation importJitFunctionAsFuncOp(MlirContext context,
torch::jit::Function *function);
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H

View File

@ -1,38 +0,0 @@
//===- graph_importer.cpp -------------------------------------------------===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#include "graph_importer.h"
#include <unordered_map>
#include "mlir_utils.h"
#include "torch_to_mlir_utils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
namespace py = pybind11;
using namespace torch_mlir;
MlirOperation torch_mlir::importGraphAsFuncOp(MlirContext context,
torch::jit::Graph *graph,
const std::string &name) {
// Useful for debugging:
// graph->dump();
MlirLocation loc = mlirLocationUnknownGet(context);
MlirAttribute typeAttr =
mlirTypeAttrGet(getFunctionTypeFromBlock(context, graph->block()));
MlirAttribute symNameAttr = mlirStringAttrGet(context, toMlirStringRef(name));
MlirOperation func = createMlirOperation(
"func", loc, mlirRegionCreate(), toMlirNamedAttribute("type", typeAttr),
toMlirNamedAttribute("sym_name", symNameAttr));
MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);
MlirBlock block = importBlock(context, graph->block(), "std.return");
mlirRegionAppendOwnedBlock(bodyRegion, block);
return func;
}

View File

@ -1,37 +0,0 @@
//===- graph_importer.h -----------------------------------------*- C++ -*-===//
//
// This file is licensed under a pytorch-style license
// See frontends/pytorch/LICENSE for license information.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
#include <memory>
#include "../pybind.h"
#include "func_builder.h"
#include "node_importer.h"
#include "mlir-c/IR.h"
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch_mlir {
/// Main entry-point for importing torch::jit::Graph instances.
///
/// This code doesn't handle importing of torch::jit::Module's. See
/// IValueImporter for that.
///
/// A Graph is a combination of an MLIR context, function, and builder.
/// See NodeImporter for importing of the core IR Node/Block
/// structure that is analogous to MLIR's Operation/Region/Block core structure.
MlirOperation importGraphAsFuncOp(MlirContext context, torch::jit::Graph *graph,
const std::string &name);
} // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H

View File

@ -7,7 +7,7 @@
#include "ivalue_importer.h" #include "ivalue_importer.h"
#include "class_annotator.h" #include "class_annotator.h"
#include "graph_importer.h" #include "function_importer.h"
#include <unordered_map> #include <unordered_map>
@ -393,8 +393,7 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
} }
for (torch::jit::Function *function : cu->get_functions()) { for (torch::jit::Function *function : cu->get_functions()) {
MlirOperation func = importGraphAsFuncOp( MlirOperation func = importJitFunctionAsFuncOp(context, function);
context, function->graph().get(), function->qualname().qualifiedName());
// For IValue importing, the logical linkage structure of the module // For IValue importing, the logical linkage structure of the module
// is determined by the object graph. // is determined by the object graph.
// //

View File

@ -7,7 +7,7 @@
#include "module_builder.h" #include "module_builder.h"
#include "graph_importer.h" #include "function_importer.h"
#include "ivalue_importer.h" #include "ivalue_importer.h"
#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Bindings/Python/Interop.h"
@ -128,8 +128,7 @@ torch::jit::StrongFunctionPtr
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) { ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
MlirBlock block = getBodyBlock(); MlirBlock block = getBodyBlock();
MlirOperation terminator = this->terminator; MlirOperation terminator = this->terminator;
MlirOperation func = importGraphAsFuncOp( MlirOperation func = importJitFunctionAsFuncOp(context, function.function_);
context, function.function_->graph().get(), function.function_->name());
mlirBlockInsertOwnedOperationBefore(block, terminator, func); mlirBlockInsertOwnedOperationBefore(block, terminator, func);
return function; return function;
} }

View File

@ -30,7 +30,7 @@ public:
NodeImporter(MlirContext context) : context(context) {} NodeImporter(MlirContext context) : context(context) {}
void importNode(Node *node, MlirBlock appendToBlock); void importNode(Node *node, MlirBlock appendToBlock);
MlirBlock importBlock(Block *jitBlock, const std::string &terminatorOpName); MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);
private: private:
void importPrimNode(Node *node, MlirBlock appendToBlock); void importPrimNode(Node *node, MlirBlock appendToBlock);
@ -70,7 +70,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
const std::string &symName = function->qualname().qualifiedName(); const std::string &symName = function->qualname().qualifiedName();
op = createMlirOperation( op = createMlirOperation(
"std.constant", loc, "std.constant", loc,
getFunctionTypeFromBlock(context, function->graph()->block()), getFunctionTypeFromSchema(context, function->getSchema()),
toMlirNamedAttribute( toMlirNamedAttribute(
"value", "value",
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName)))); mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
@ -141,14 +141,28 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
} }
if (kind == c10::prim::Loop) { if (kind == c10::prim::Loop) {
std::vector<MlirType> resultTypes =
getMlirTypesFromValues(loc, node->outputs());
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.Loop", loc, appendToBlock, "torch.prim.Loop", loc, resultTypes,
getMlirTypesFromValues(loc, node->outputs()), lookupMappedValues(node->inputs().slice(0, 2)),
lookupMappedValues(node->inputs()), mlirRegionCreate()); derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
loc, appendToBlock),
mlirRegionCreate());
mapResults(node, operation); mapResults(node, operation);
std::vector<MlirType> terminatorOperandTypes = {npcompBoolTypeGet(context)};
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
resultTypes.begin(), resultTypes.end());
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
MlirBlock appendToBlock) {
createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
derefineValues(yieldedValues,
terminatorOperandTypes, loc,
appendToBlock));
};
mlirRegionAppendOwnedBlock( mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0), mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], "torch.prim.Loop.condition")); importBlock(node->blocks()[0], createTerminator));
return; return;
} }
@ -158,15 +172,24 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
MlirOperation pred = createMlirOperationAtEnd( MlirOperation pred = createMlirOperationAtEnd(
appendToBlock, "basicpy.bool_cast", loc, mlirIntegerTypeGet(context, 1), appendToBlock, "basicpy.bool_cast", loc, mlirIntegerTypeGet(context, 1),
lookupMappedValue(node->input())); lookupMappedValue(node->input()));
std::vector<MlirType> resultTypes =
getMlirTypesFromValues(loc, node->outputs());
MlirOperation operation = createMlirOperationAtEnd( MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0), appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
getMlirTypesFromValues(loc, node->outputs()), mlirRegionCreate(), resultTypes, mlirRegionCreate(), mlirRegionCreate());
mlirRegionCreate());
mapResults(node, operation); mapResults(node, operation);
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0), auto createTerminator =
importBlock(node->blocks()[0], "scf.yield")); [&](c10::ArrayRef<MlirValue> yieldedValues, MlirBlock appendToBlock) {
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1), createMlirOperationAtEnd(
importBlock(node->blocks()[1], "scf.yield")); appendToBlock, "scf.yield", loc,
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
};
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], createTerminator));
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 1),
importBlock(node->blocks()[1], createTerminator));
return; return;
} }
@ -180,10 +203,18 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
} }
if (kind == c10::prim::CallFunction) { if (kind == c10::prim::CallFunction) {
MlirOperation operation = auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
createMlirOperationAtEnd(appendToBlock, "std.call_indirect", loc, torch::jit::Block *calleeEntryBlock =
functionType->function()->graph()->block();
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
return typeMapper.mapFromTorchType(loc, v->type());
});
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "std.call_indirect", loc,
getMlirTypesFromValues(loc, node->outputs()), getMlirTypesFromValues(loc, node->outputs()),
lookupMappedValues(node->inputs())); lookupMappedValue(node->input(0)),
derefineValues(lookupMappedValues(node->inputs().slice(1)),
expectedTypes, loc, appendToBlock));
mapResults(node, operation); mapResults(node, operation);
return; return;
} }
@ -288,15 +319,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
} }
MlirBlock NodeImporter::importBlock(Block *jitBlock, MlirBlock NodeImporter::importBlock(Block *jitBlock,
const std::string &terminatorOpName) { CreateTerminatorFn createTerminator) {
MlirBlock block = createBlockFor(jitBlock); MlirBlock block = createBlockFor(jitBlock);
for (Node *node : jitBlock->nodes()) { for (Node *node : jitBlock->nodes()) {
importNode(node, block); importNode(node, block);
} }
Node *returnNode = jitBlock->return_node(); Node *returnNode = jitBlock->return_node();
createMlirOperationAtEnd(block, terminatorOpName, createTerminator(lookupMappedValues(returnNode->inputs()), block);
getMlirLocationFromNode(context, returnNode),
lookupMappedValues(returnNode->inputs()));
return block; return block;
} }
@ -343,7 +372,7 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
} }
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock, MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
const std::string &terminatorOpName) { CreateTerminatorFn createTerminator) {
NodeImporter importer(context); NodeImporter importer(context);
return importer.importBlock(jitBlock, terminatorOpName); return importer.importBlock(jitBlock, createTerminator);
} }

View File

@ -20,8 +20,11 @@
namespace torch_mlir { namespace torch_mlir {
using CreateTerminatorFn =
std::function<void(c10::ArrayRef<MlirValue>, MlirBlock)>;
MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock, MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock,
const std::string &terminatorOpName); CreateTerminatorFn createTerminator);
} // namespace torch_mlir } // namespace torch_mlir

View File

@ -5,7 +5,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "graph_importer.h" #include "function_importer.h"
#include "ivalue_importer.h" #include "ivalue_importer.h"
#include <unordered_map> #include <unordered_map>
@ -163,17 +163,28 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType); return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
} }
MlirType torch_mlir::getFunctionTypeFromBlock(MlirContext context, MlirType
torch::jit::Block *block) { torch_mlir::getFunctionTypeFromSchema(MlirContext context,
MlirLocation inputLoc = getMlirLocationFromNode(context, block->param_node()); const c10::FunctionSchema &schema) {
MlirLocation loc = mlirLocationUnknownGet(context);
TypeMapper typeMapper(context);
auto mapType = [&](const c10::TypePtr &torchType) {
MlirType type = typeMapper.mapFromTorchType(loc, torchType);
if (mlirTypeIsNull(type)) {
std::stringstream msg;
msg << "unsupported type in function schema: '"
<< c10::toString(torchType) << "'";
throw std::invalid_argument(msg.str());
}
return type;
};
std::vector<MlirType> inputTypes = std::vector<MlirType> inputTypes =
getMlirTypesFromValues(inputLoc, block->param_node()->outputs()); c10::fmap(schema.arguments(),
[&](const c10::Argument &arg) { return mapType(arg.type()); });
MlirLocation outputLoc =
getMlirLocationFromNode(context, block->return_node());
std::vector<MlirType> outputTypes = std::vector<MlirType> outputTypes =
getMlirTypesFromValues(outputLoc, block->return_node()->inputs()); c10::fmap(schema.returns(),
[&](const c10::Argument &arg) { return mapType(arg.type()); });
return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(), return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
outputTypes.size(), outputTypes.data()); outputTypes.size(), outputTypes.data());
} }
@ -303,3 +314,25 @@ torch_mlir::getMlirTypesFromValues(MlirLocation loc,
} }
return ret; return ret;
} }
std::vector<MlirValue>
torch_mlir::derefineValues(c10::ArrayRef<MlirValue> values,
c10::ArrayRef<MlirType> expectedTypes,
MlirLocation loc, MlirBlock appendToBlock) {
std::vector<MlirValue> ret;
assert(values.size() == expectedTypes.size());
for (int i = 0, e = values.size(); i != e; i++) {
MlirValue value = values[i];
MlirType expectedType = expectedTypes[i];
MlirType type = mlirValueGetType(value);
if (mlirTypeEqual(expectedType, type)) {
// No need to derefine.
ret.push_back(value);
} else {
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.derefine", loc, expectedType, value);
ret.push_back(mlirOperationGetResult(operation, 0));
}
}
return ret;
}

View File

@ -51,18 +51,12 @@ private:
MlirContext context; MlirContext context;
}; };
/// Creates a FunctionType suitable for expressing the signature of `block`. /// Creates a FunctionType suitable for expressing the signature of `schema`.
/// ///
/// `mlir::Block` only has a formalized notion of argument types (bb args), but /// This can differ from the type inferred from the block of a
/// the exact nature of the block's terminator is left opaque (for example, you /// torch::jit::Function due to derefinement.
/// can have a weird terminator that "returns all but the first operand"). MlirType getFunctionTypeFromSchema(MlirContext context,
/// `torch::jit::Block` on the other hand has a formalized notion of a const c10::FunctionSchema &schema);
/// `param_node` and `return_node`, which are effectively dummy operations at
/// the start and end of the block, which establish a formal signature for the
/// block and can be generically reasoned about -- that is what we anchor on
/// here.
MlirType getFunctionTypeFromBlock(MlirContext context,
torch::jit::Block *block);
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`. /// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
MlirAttribute converTensorToMlirElementsAttr(at::Tensor tensor, MlirAttribute converTensorToMlirElementsAttr(at::Tensor tensor,
@ -78,6 +72,11 @@ std::vector<MlirType>
getMlirTypesFromValues(MlirLocation loc, getMlirTypesFromValues(MlirLocation loc,
c10::ArrayRef<torch::jit::Value *> values); c10::ArrayRef<torch::jit::Value *> values);
std::vector<MlirValue> derefineValues(c10::ArrayRef<MlirValue> values,
c10::ArrayRef<MlirType> expectedTypes,
MlirLocation loc,
MlirBlock appendToBlock);
} // namespace torch_mlir } // namespace torch_mlir
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H #endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H

View File

@ -10,7 +10,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# Verify without debug info. # Verify without debug info.
# CHECK-LABEL: func @add3 # CHECK-LABEL: func @__torch__.add3
# CHECK-SAME: (%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> { # CHECK-SAME: (%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
# CHECK: %[[C1:.*]] = constant 1 : i64 # CHECK: %[[C1:.*]] = constant 1 : i64
# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]} # CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @add3 # CHECK-LABEL: func @__torch__.add3
# Note that line-level debug information for parts unannotated in the Torch # Note that line-level debug information for parts unannotated in the Torch
# graph are ascribed to the first op that carries source information. Presently # graph are ascribed to the first op that carries source information. Presently
# this includes naked constants, return and the function itself. This heuristic # this includes naked constants, return and the function itself. This heuristic

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: @f # CHECK-LABEL: @__torch__.f
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script
def f(b: bool, i: int): def f(b: bool, i: int):

View File

@ -17,8 +17,8 @@ try:
@torch.jit.script @torch.jit.script
def import_class(x: typing.Any): def import_class(x: typing.Any):
return x return x
except RuntimeError as e: except Exception as e:
# TODO: Once diagnostics are enabled, verify the actual error emitted. # TODO: Once diagnostics are enabled, verify the actual error emitted.
assert str(e) == "unsupported type" assert str(e) == "unsupported type in function schema: 'Any'"
else: else:
assert False, "Expected exception" assert False, "Expected exception"

View File

@ -0,0 +1,42 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import typing
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @__torch__.optional_return(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.optional<i64> {
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
# CHECK: return %[[RET]] : !torch.optional<i64>
@mb.import_function
@torch.jit.script
def optional_return(i: int) -> typing.Optional[int]:
return i
# CHECK-LABEL: func @__torch__.optional_arg(
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<i64>) -> !basicpy.NoneType {
@mb.import_function
@torch.jit.script
def optional_arg(i: typing.Optional[int]) -> None:
return
# CHECK-LABEL: func @__torch__.calls_optional_arg(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.NoneType {
# CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<i64>) -> !basicpy.NoneType
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
# CHECK: %{{.*}} = call_indirect %[[CALLEE]](%[[DEREFINED]]) : (!torch.optional<i64>) -> !basicpy.NoneType
@mb.import_function
@torch.jit.script
def calls_optional_arg(i: int):
optional_arg(i)
mb.module.operation.print()
print()

View File

@ -9,12 +9,12 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: @f( # CHECK-LABEL: @__torch__.prim_If(
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType, # CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
# CHECK-SAME: %[[I:.*]]: i64) -> i64 { # CHECK-SAME: %[[I:.*]]: i64) -> i64 {
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script
def f(b: bool, i: int): def prim_If(b: bool, i: int):
# CHECK: %[[I1:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1 # CHECK: %[[I1:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1
# CHECK: %[[RES:.*]] = scf.if %[[I1]] -> (i64) { # CHECK: %[[RES:.*]] = scf.if %[[I1]] -> (i64) {
# CHECK: %[[ADD:.*]] = torch.kernel_call "aten::add" %[[I]], %[[I]] # CHECK: %[[ADD:.*]] = torch.kernel_call "aten::add" %[[I]], %[[I]]
@ -30,6 +30,25 @@ def f(b: bool, i: int):
return i * i return i * i
# elif is modeled as a nested if, so no need to specially test it here. # elif is modeled as a nested if, so no need to specially test it here.
assert isinstance(f, torch.jit.ScriptFunction) # CHECK-LABEL: func @__torch__.prim_If_derefine(
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
# CHECK-SAME: %[[I:.*]]: i64) -> !torch.optional<i64> {
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[PRED:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1
# CHECK: %[[RES:.*]] = scf.if %[[PRED]] -> (!torch.optional<i64>) {
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
# CHECK: scf.yield %[[NONE_DEREFINED]] : !torch.optional<i64>
# CHECK: } else {
# CHECK: %[[I_DEREFINED:.*]] = torch.derefine %[[I]] : i64 -> !torch.optional<i64>
# CHECK: scf.yield %[[I_DEREFINED]] : !torch.optional<i64>
# CHECK: }
# CHECK: return %[[RES:.*]] : !torch.optional<i64>
@mb.import_function
@torch.jit.script
def prim_If_derefine(b: bool, i: int):
if b:
return None
return i
mb.module.operation.print() mb.module.operation.print()
print() print()

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @f( # CHECK-LABEL: func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>, # CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType { # CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType {
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType # CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType

View File

@ -5,11 +5,13 @@
import torch import torch
import torch_mlir import torch_mlir
import typing
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s # RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @prim_Loop_forlike( # CHECK-LABEL: func @__torch__.prim_Loop_forlike(
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: i64) -> f64 { # CHECK-SAME: %[[MAX_ITERATIONS:.*]]: i64) -> f64 {
# CHECK: %[[BOOL_TRUE:.*]] = basicpy.bool_constant true # CHECK: %[[BOOL_TRUE:.*]] = basicpy.bool_constant true
# CHECK: %[[F_INIT:.*]] = constant 0.000000e+00 : f64 # CHECK: %[[F_INIT:.*]] = constant 0.000000e+00 : f64
@ -27,18 +29,18 @@ def prim_Loop_forlike(n: int):
f += i f += i
return f return f
# CHECK-LABEL: func @prim_Loop_whilelike( # CHECK-LABEL: func @__torch__.prim_Loop_whilelike(
# CHECK-SAME: %[[VAL_0:.*]]: i64) -> f64 { # CHECK-SAME: %[[VAL_0:.*]]: i64) -> f64 {
# CHECK: %[[F_INIT:.*]] = constant 3.200000e+00 : f64 # CHECK: %[[F_INIT:.*]] = constant 3.200000e+00 : f64
# CHECK: %[[MAX_ITERATIONS:.*]] = constant 9223372036854775807 : i64 # CHECK: %[[MAX_ITERATIONS:.*]] = constant 9223372036854775807 : i64
# CHECK: %[[COND_INIT:.*]] = torch.kernel_call "aten::lt" %[[F_INIT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]} # CHECK: %[[COND_INIT:.*]] = torch.kernel_call "aten::lt" %[[F_INIT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
# CHECK: %[[IV:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) { # CHECK: %[[RET:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
# CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64): # CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]} # CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
# CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]} # CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64) # CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64 # CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
# CHECK: return %[[VAL_9:.*]] : f64 # CHECK: return %[[RET:.*]] : f64
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script
def prim_Loop_whilelike(n: int): def prim_Loop_whilelike(n: int):
@ -47,5 +49,24 @@ def prim_Loop_whilelike(n: int):
f = f * f f = f * f
return f return f
# CHECK-LABEL: func @__torch__.prim_Loop_derefine(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.optional<i64> {
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
# CHECK: %[[RET:.*]] = torch.prim.Loop %[[ARG]], %[[TRUE]], init(%[[NONE_DEREFINED]]) {
# CHECK: ^bb0(%[[IV:.*]]: i64, %[[X_ITER:.*]]: !torch.optional<i64>):
# CHECK: %[[X_NEXT:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
# CHECK: torch.prim.Loop.condition %[[TRUE]] iter(%[[X_NEXT]]) : !basicpy.BoolType, (!torch.optional<i64>)
# CHECK: } : (i64, !basicpy.BoolType, !torch.optional<i64>) -> !torch.optional<i64>
# CHECK: return %[[RET:.*]] : !torch.optional<i64>
@mb.import_function
@torch.jit.script
def prim_Loop_derefine(n: int):
x: typing.Optional[int] = None
for i in range(n):
x = n
return x
mb.module.operation.print() mb.module.operation.print()
print() print()

View File

@ -14,7 +14,7 @@ import typing
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @prim_NumToTensor( # CHECK-LABEL: func @__torch__.prim_NumToTensor(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> { # CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype> # CHECK: %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype> # CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
@ -25,7 +25,7 @@ mb = torch_mlir.ModuleBuilder()
def prim_NumToTensor(i: int): def prim_NumToTensor(i: int):
return _to_tensor(i) return _to_tensor(i)
# CHECK-LABEL: func @prim_Print( # CHECK-LABEL: func @__torch__.prim_Print(
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType { # CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
# CHECK: %[[STR:.*]] = basicpy.bytes_constant "x" # CHECK: %[[STR:.*]] = basicpy.bytes_constant "x"
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !basicpy.BytesType, !numpy.ndarray<*:!numpy.any_dtype> # CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !basicpy.BytesType, !numpy.ndarray<*:!numpy.any_dtype>
@ -34,7 +34,7 @@ def prim_NumToTensor(i: int):
def prim_Print(x): def prim_Print(x):
print("x", x) print("x", x)
# CHECK-LABEL: func @prim_RaiseException() -> !basicpy.NoneType { # CHECK-LABEL: func @__torch__.prim_RaiseException() -> !basicpy.NoneType {
# CHECK: %[[ERRORSTR:.*]] = basicpy.bytes_constant "Error" # CHECK: %[[ERRORSTR:.*]] = basicpy.bytes_constant "Error"
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !basicpy.NoneType # CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !basicpy.NoneType
# CHECK: torch.prim.RaiseException %[[ERRORSTR]] # CHECK: torch.prim.RaiseException %[[ERRORSTR]]
@ -44,7 +44,7 @@ def prim_Print(x):
def prim_RaiseException(): def prim_RaiseException():
raise Exception("Error") raise Exception("Error")
# CHECK-LABEL: func @prim_unchecked_cast( # CHECK-LABEL: func @__torch__.prim_unchecked_cast(
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 { # CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType # CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
# CHECK: %[[C3:.*]] = constant 3 : i64 # CHECK: %[[C3:.*]] = constant 3 : i64
@ -64,7 +64,7 @@ def prim_unchecked_cast(i: typing.Optional[int]):
return 3 return 3
return i return i
# CHECK-LABEL: func @prim_TupleUnpack( # CHECK-LABEL: func @__torch__.prim_TupleUnpack(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 { # CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64 # CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
# CHECK: return %[[RET]]#0 : i64 # CHECK: return %[[RET]]#0 : i64
@ -74,7 +74,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
val, _ = tup val, _ = tup
return val return val
# CHECK-LABEL: func @prim_TupleIndex( # CHECK-LABEL: func @__torch__.prim_TupleIndex(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 { # CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !basicpy.TupleType, i64 -> i64 # CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !basicpy.TupleType, i64 -> i64
# CHECK: return %[[RET]] : i64 # CHECK: return %[[RET]] : i64
@ -83,7 +83,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
def prim_TupleIndex(tup: typing.Tuple[int, int]): def prim_TupleIndex(tup: typing.Tuple[int, int]):
return tup[0] return tup[0]
# CHECK-LABEL: func @prim_ListUnpack( # CHECK-LABEL: func @__torch__.prim_ListUnpack(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 { # CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64 # CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
# CHECK: return %[[RET]]#1 : i64 # CHECK: return %[[RET]]#1 : i64

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @f( # CHECK-LABEL: func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>, # CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType { # CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType {
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType # CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# CHECK: @returns_bool # CHECK: @__torch__.returns_bool
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script
def returns_bool(): def returns_bool():

View File

@ -9,7 +9,7 @@ import torch_mlir
mb = torch_mlir.ModuleBuilder() mb = torch_mlir.ModuleBuilder()
# CHECK: @returns_none # CHECK: @__torch__.returns_none
@mb.import_function @mb.import_function
@torch.jit.script @torch.jit.script
def returns_none(): def returns_none():

View File

@ -14,6 +14,7 @@
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h" #include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h" #include "npcomp/Dialect/Torch/IR/TorchTypes.h"

View File

@ -13,6 +13,7 @@ include "npcomp/Dialect/Torch/IR/TorchTypes.td"
include "npcomp/Dialect/Torch/IR/OpInterfaces.td" include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
class Torch_Op<string mnemonic, list<OpTrait> traits = []> class Torch_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Torch_Dialect, mnemonic, traits> { : Op<Torch_Dialect, mnemonic, traits> {
@ -457,10 +458,59 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", []> {
}]; }];
} }
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", []> { def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", [
NoSideEffect
]> {
let summary = "TorchScript prim::unchecked_cast op"; let summary = "TorchScript prim::unchecked_cast op";
// TODO: This seems to mostly be used for casting "optional" to the contained let description = [{
// type. Verify that and tighten the verifier. Refine a type to one of its subtypes.
For example, refine a type that was only statically known to be
Optional[T] to a T when we obtain static information that guarantees it.
The key observation here is that Optional[T] does not have a corresponding
runtime type (i.e. `c10::IValue` subclass). It represents a set of possible
concrete types which for `Optional[T]` is either `None` or a concrete
subtype of `T` (which in the simplest case is just `T`). In particular,
at runtime there is no way to distinguish `Optional[int]` from
`Optional[Optional[int]]`, because both are either `None` or `int`.
This differs from C++ std::optional.
The best documentation of this op is inspection of the code in
`torch/csrc/jit/frontend/ir_emitter.cpp`.
}];
// TODO: When we model PyTorch's notion of subtyping, verify the types here.
let arguments = (ins AnyTorchType:$operand);
let results = (outs AnyTorchType:$result);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
//===----------------------------------------------------------------------===//
// Additional ops used to model TorchScript's Graph's / Node's.
//===----------------------------------------------------------------------===//
def Torch_DerefineOp : Torch_Op<"derefine", [
NoSideEffect
]> {
let summary = "De-refine a type";
let description = [{
In terms of IR structure, TorchScript allows types to vary in many
circumstances where MLIR requires pointer-identical types. In particular,
it is valid to pass any subtype in place of a type. For example, if an
`Optional[int]` is required somewhere in the IR, it is legal to pass a
value of just `int` (but not the other way around; see
`torch.prim.unchecked_cast`). In effect, every *use* can have a different
type.
This op bridges that impedance mismatch. This op allows casting a value
from one type to a type that it is a subtype of to model this behavior.
}];
// TODO: When we model PyTorch's notion of subtyping, verify the types here.
let arguments = (ins AnyTorchType:$operand); let arguments = (ins AnyTorchType:$operand);
let results = (outs AnyTorchType:$result); let results = (outs AnyTorchType:$result);

View File

@ -17,6 +17,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
MLIRIR MLIRIR
MLIRSupport MLIRSupport
MLIRControlFlowInterfaces MLIRControlFlowInterfaces
MLIRSideEffectInterfaces
NPCOMPBasicpyDialect NPCOMPBasicpyDialect
NPCOMPNumpyDialect NPCOMPNumpyDialect
) )