mirror of https://github.com/llvm/torch-mlir
Properly model "derefinement".
In terms of IR structure, TorchScript allows types to vary in many circumstances where MLIR requires pointer-identical types. In particular, it is valid to pass any subtype in place of a type. For example, if an `Optional[int]` is required somewhere in the IR, it is legal to pass a value of just `int` (but not the other way around; see `torch.prim.unchecked_cast`). In effect, every *use* can have a different type. We introduce a new op `torch.derefine` that models that impedance mismatch. This op allows casting a value from one type to a type that it is a subtype of to model this behavior. Recommended review order: - TorchOps.td for new torch.derefine (and updated docs for `torch.prim.unchecked_cast`) - new test code in if.py, loop.py, function-derefine.py - new code in node_importer.cpp for handling derefinement insertion - function_importer.cpp and utils changes in torch_to_mlir_utils.cpp Properly handling derefinement on function boundaries required relayering the code so that graph_importer.cpp/.h is now function_importer.cpp/.h because only the `torch::jit::Function` (actually the `c10::FunctionSchema` it holds) knows the derefined types that are actually needed at the boundary (see `function-derefine.py` for a test). Annoyingly, this churns all the functions which are now prefixed with `__torch__.` but that is more correct anyway (that is their linkage name in the `torch::jit::CompilationUnit`; the previous `mb.import_function` was actually buggy in the case of functions calling each other as it would reference their unqualified name). With this change, we can import `resnet18` from `torchvision` :) IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704pull/178/head
parent
1736ff0253
commit
43dba03afd
|
@ -15,7 +15,7 @@ add_library(NPCOMPTorchMLIRExt SHARED
|
|||
builder/class_annotator.cpp
|
||||
builder/debug.cpp
|
||||
builder/func_builder.cpp
|
||||
builder/graph_importer.cpp
|
||||
builder/function_importer.cpp
|
||||
builder/module_builder.cpp
|
||||
builder/node_importer.cpp
|
||||
builder/op_builder.cpp
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
//===- function_importer.cpp ----------------------------------------------===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "function_importer.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlir_utils.h"
|
||||
#include "torch_to_mlir_utils.h"
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace torch_mlir;
|
||||
|
||||
MlirOperation
|
||||
torch_mlir::importJitFunctionAsFuncOp(MlirContext context,
|
||||
torch::jit::Function *function) {
|
||||
// Useful for debugging:
|
||||
// graph->dump();
|
||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
MlirType functionType =
|
||||
getFunctionTypeFromSchema(context, function->getSchema());
|
||||
// Use the function's qualified name from the compilation unit.
|
||||
// This is a stable linkage name that matches Python module lookup
|
||||
// conventions (see compilation unit import in IValueImporter for more details
|
||||
// on qualified names).
|
||||
MlirAttribute symNameAttr = mlirStringAttrGet(
|
||||
context, toMlirStringRef(function->qualname().qualifiedName()));
|
||||
MlirOperation func = createMlirOperation(
|
||||
"func", loc, mlirRegionCreate(),
|
||||
toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)),
|
||||
toMlirNamedAttribute("sym_name", symNameAttr));
|
||||
MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);
|
||||
std::vector<MlirType> resultTypes;
|
||||
for (int i = 0, e = mlirFunctionTypeGetNumResults(functionType); i != e;
|
||||
i++) {
|
||||
resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i));
|
||||
}
|
||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||
MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(
|
||||
appendToBlock, "std.return", loc,
|
||||
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||
};
|
||||
MlirBlock block =
|
||||
importBlock(context, function->graph()->block(), createTerminator);
|
||||
mlirRegionAppendOwnedBlock(bodyRegion, block);
|
||||
return func;
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
//===- function_importer.h --------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
|
||||
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "func_builder.h"
|
||||
#include "node_importer.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Main entry-point for importing torch::jit::Function instances.
|
||||
///
|
||||
/// This code doesn't handle importing of torch::jit::Module's. See
|
||||
/// IValueImporter for that.
|
||||
///
|
||||
/// A torch::jit::Function holds a c10::FunctionSchema along with a
|
||||
/// c10::QualifiedName and a torch::jit::Graph.
|
||||
///
|
||||
/// The torch::jit::Graph is a combination of an MLIR context, function, and
|
||||
/// builder. See NodeImporter for importing of the core IR Node/Block
|
||||
/// structure that is analogous to MLIR's Operation/Region/Block core structure.
|
||||
MlirOperation importJitFunctionAsFuncOp(MlirContext context,
|
||||
torch::jit::Function *function);
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
|
|
@ -1,38 +0,0 @@
|
|||
//===- graph_importer.cpp -------------------------------------------------===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "graph_importer.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlir_utils.h"
|
||||
#include "torch_to_mlir_utils.h"
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace torch_mlir;
|
||||
|
||||
MlirOperation torch_mlir::importGraphAsFuncOp(MlirContext context,
|
||||
torch::jit::Graph *graph,
|
||||
const std::string &name) {
|
||||
// Useful for debugging:
|
||||
// graph->dump();
|
||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
MlirAttribute typeAttr =
|
||||
mlirTypeAttrGet(getFunctionTypeFromBlock(context, graph->block()));
|
||||
MlirAttribute symNameAttr = mlirStringAttrGet(context, toMlirStringRef(name));
|
||||
MlirOperation func = createMlirOperation(
|
||||
"func", loc, mlirRegionCreate(), toMlirNamedAttribute("type", typeAttr),
|
||||
toMlirNamedAttribute("sym_name", symNameAttr));
|
||||
MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);
|
||||
MlirBlock block = importBlock(context, graph->block(), "std.return");
|
||||
mlirRegionAppendOwnedBlock(bodyRegion, block);
|
||||
return func;
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
//===- graph_importer.h -----------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
|
||||
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "../pybind.h"
|
||||
#include "func_builder.h"
|
||||
#include "node_importer.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Main entry-point for importing torch::jit::Graph instances.
|
||||
///
|
||||
/// This code doesn't handle importing of torch::jit::Module's. See
|
||||
/// IValueImporter for that.
|
||||
///
|
||||
/// A Graph is a combination of an MLIR context, function, and builder.
|
||||
/// See NodeImporter for importing of the core IR Node/Block
|
||||
/// structure that is analogous to MLIR's Operation/Region/Block core structure.
|
||||
MlirOperation importGraphAsFuncOp(MlirContext context, torch::jit::Graph *graph,
|
||||
const std::string &name);
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
#include "ivalue_importer.h"
|
||||
#include "class_annotator.h"
|
||||
#include "graph_importer.h"
|
||||
#include "function_importer.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
|
@ -393,8 +393,7 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
|||
}
|
||||
|
||||
for (torch::jit::Function *function : cu->get_functions()) {
|
||||
MlirOperation func = importGraphAsFuncOp(
|
||||
context, function->graph().get(), function->qualname().qualifiedName());
|
||||
MlirOperation func = importJitFunctionAsFuncOp(context, function);
|
||||
// For IValue importing, the logical linkage structure of the module
|
||||
// is determined by the object graph.
|
||||
//
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
#include "module_builder.h"
|
||||
|
||||
#include "graph_importer.h"
|
||||
#include "function_importer.h"
|
||||
#include "ivalue_importer.h"
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
|
@ -128,8 +128,7 @@ torch::jit::StrongFunctionPtr
|
|||
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
||||
MlirBlock block = getBodyBlock();
|
||||
MlirOperation terminator = this->terminator;
|
||||
MlirOperation func = importGraphAsFuncOp(
|
||||
context, function.function_->graph().get(), function.function_->name());
|
||||
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_);
|
||||
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
||||
return function;
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ public:
|
|||
NodeImporter(MlirContext context) : context(context) {}
|
||||
|
||||
void importNode(Node *node, MlirBlock appendToBlock);
|
||||
MlirBlock importBlock(Block *jitBlock, const std::string &terminatorOpName);
|
||||
MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);
|
||||
|
||||
private:
|
||||
void importPrimNode(Node *node, MlirBlock appendToBlock);
|
||||
|
@ -70,7 +70,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
const std::string &symName = function->qualname().qualifiedName();
|
||||
op = createMlirOperation(
|
||||
"std.constant", loc,
|
||||
getFunctionTypeFromBlock(context, function->graph()->block()),
|
||||
getFunctionTypeFromSchema(context, function->getSchema()),
|
||||
toMlirNamedAttribute(
|
||||
"value",
|
||||
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
|
||||
|
@ -141,14 +141,28 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
}
|
||||
|
||||
if (kind == c10::prim::Loop) {
|
||||
std::vector<MlirType> resultTypes =
|
||||
getMlirTypesFromValues(loc, node->outputs());
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.Loop", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()), mlirRegionCreate());
|
||||
appendToBlock, "torch.prim.Loop", loc, resultTypes,
|
||||
lookupMappedValues(node->inputs().slice(0, 2)),
|
||||
derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
|
||||
loc, appendToBlock),
|
||||
mlirRegionCreate());
|
||||
mapResults(node, operation);
|
||||
std::vector<MlirType> terminatorOperandTypes = {npcompBoolTypeGet(context)};
|
||||
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
|
||||
resultTypes.begin(), resultTypes.end());
|
||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||
MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
|
||||
derefineValues(yieldedValues,
|
||||
terminatorOperandTypes, loc,
|
||||
appendToBlock));
|
||||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], "torch.prim.Loop.condition"));
|
||||
importBlock(node->blocks()[0], createTerminator));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -158,15 +172,24 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
MlirOperation pred = createMlirOperationAtEnd(
|
||||
appendToBlock, "basicpy.bool_cast", loc, mlirIntegerTypeGet(context, 1),
|
||||
lookupMappedValue(node->input()));
|
||||
std::vector<MlirType> resultTypes =
|
||||
getMlirTypesFromValues(loc, node->outputs());
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
|
||||
getMlirTypesFromValues(loc, node->outputs()), mlirRegionCreate(),
|
||||
mlirRegionCreate());
|
||||
resultTypes, mlirRegionCreate(), mlirRegionCreate());
|
||||
mapResults(node, operation);
|
||||
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], "scf.yield"));
|
||||
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1),
|
||||
importBlock(node->blocks()[1], "scf.yield"));
|
||||
auto createTerminator =
|
||||
[&](c10::ArrayRef<MlirValue> yieldedValues, MlirBlock appendToBlock) {
|
||||
createMlirOperationAtEnd(
|
||||
appendToBlock, "scf.yield", loc,
|
||||
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], createTerminator));
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 1),
|
||||
importBlock(node->blocks()[1], createTerminator));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -180,10 +203,18 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
|||
}
|
||||
|
||||
if (kind == c10::prim::CallFunction) {
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(appendToBlock, "std.call_indirect", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValues(node->inputs()));
|
||||
auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
|
||||
torch::jit::Block *calleeEntryBlock =
|
||||
functionType->function()->graph()->block();
|
||||
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
||||
return typeMapper.mapFromTorchType(loc, v->type());
|
||||
});
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "std.call_indirect", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
lookupMappedValue(node->input(0)),
|
||||
derefineValues(lookupMappedValues(node->inputs().slice(1)),
|
||||
expectedTypes, loc, appendToBlock));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
}
|
||||
|
@ -288,15 +319,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
}
|
||||
|
||||
MlirBlock NodeImporter::importBlock(Block *jitBlock,
|
||||
const std::string &terminatorOpName) {
|
||||
CreateTerminatorFn createTerminator) {
|
||||
MlirBlock block = createBlockFor(jitBlock);
|
||||
for (Node *node : jitBlock->nodes()) {
|
||||
importNode(node, block);
|
||||
}
|
||||
Node *returnNode = jitBlock->return_node();
|
||||
createMlirOperationAtEnd(block, terminatorOpName,
|
||||
getMlirLocationFromNode(context, returnNode),
|
||||
lookupMappedValues(returnNode->inputs()));
|
||||
createTerminator(lookupMappedValues(returnNode->inputs()), block);
|
||||
return block;
|
||||
}
|
||||
|
||||
|
@ -343,7 +372,7 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
|
|||
}
|
||||
|
||||
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
||||
const std::string &terminatorOpName) {
|
||||
CreateTerminatorFn createTerminator) {
|
||||
NodeImporter importer(context);
|
||||
return importer.importBlock(jitBlock, terminatorOpName);
|
||||
return importer.importBlock(jitBlock, createTerminator);
|
||||
}
|
||||
|
|
|
@ -20,8 +20,11 @@
|
|||
|
||||
namespace torch_mlir {
|
||||
|
||||
using CreateTerminatorFn =
|
||||
std::function<void(c10::ArrayRef<MlirValue>, MlirBlock)>;
|
||||
|
||||
MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock,
|
||||
const std::string &terminatorOpName);
|
||||
CreateTerminatorFn createTerminator);
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "graph_importer.h"
|
||||
#include "function_importer.h"
|
||||
#include "ivalue_importer.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
@ -163,17 +163,28 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
|||
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
|
||||
}
|
||||
|
||||
MlirType torch_mlir::getFunctionTypeFromBlock(MlirContext context,
|
||||
torch::jit::Block *block) {
|
||||
MlirLocation inputLoc = getMlirLocationFromNode(context, block->param_node());
|
||||
MlirType
|
||||
torch_mlir::getFunctionTypeFromSchema(MlirContext context,
|
||||
const c10::FunctionSchema &schema) {
|
||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
TypeMapper typeMapper(context);
|
||||
auto mapType = [&](const c10::TypePtr &torchType) {
|
||||
MlirType type = typeMapper.mapFromTorchType(loc, torchType);
|
||||
if (mlirTypeIsNull(type)) {
|
||||
std::stringstream msg;
|
||||
msg << "unsupported type in function schema: '"
|
||||
<< c10::toString(torchType) << "'";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return type;
|
||||
};
|
||||
|
||||
std::vector<MlirType> inputTypes =
|
||||
getMlirTypesFromValues(inputLoc, block->param_node()->outputs());
|
||||
|
||||
MlirLocation outputLoc =
|
||||
getMlirLocationFromNode(context, block->return_node());
|
||||
c10::fmap(schema.arguments(),
|
||||
[&](const c10::Argument &arg) { return mapType(arg.type()); });
|
||||
std::vector<MlirType> outputTypes =
|
||||
getMlirTypesFromValues(outputLoc, block->return_node()->inputs());
|
||||
|
||||
c10::fmap(schema.returns(),
|
||||
[&](const c10::Argument &arg) { return mapType(arg.type()); });
|
||||
return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
|
||||
outputTypes.size(), outputTypes.data());
|
||||
}
|
||||
|
@ -303,3 +314,25 @@ torch_mlir::getMlirTypesFromValues(MlirLocation loc,
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<MlirValue>
|
||||
torch_mlir::derefineValues(c10::ArrayRef<MlirValue> values,
|
||||
c10::ArrayRef<MlirType> expectedTypes,
|
||||
MlirLocation loc, MlirBlock appendToBlock) {
|
||||
std::vector<MlirValue> ret;
|
||||
assert(values.size() == expectedTypes.size());
|
||||
for (int i = 0, e = values.size(); i != e; i++) {
|
||||
MlirValue value = values[i];
|
||||
MlirType expectedType = expectedTypes[i];
|
||||
MlirType type = mlirValueGetType(value);
|
||||
if (mlirTypeEqual(expectedType, type)) {
|
||||
// No need to derefine.
|
||||
ret.push_back(value);
|
||||
} else {
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.derefine", loc, expectedType, value);
|
||||
ret.push_back(mlirOperationGetResult(operation, 0));
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -51,18 +51,12 @@ private:
|
|||
MlirContext context;
|
||||
};
|
||||
|
||||
/// Creates a FunctionType suitable for expressing the signature of `block`.
|
||||
/// Creates a FunctionType suitable for expressing the signature of `schema`.
|
||||
///
|
||||
/// `mlir::Block` only has a formalized notion of argument types (bb args), but
|
||||
/// the exact nature of the block's terminator is left opaque (for example, you
|
||||
/// can have a weird terminator that "returns all but the first operand").
|
||||
/// `torch::jit::Block` on the other hand has a formalized notion of a
|
||||
/// `param_node` and `return_node`, which are effectively dummy operations at
|
||||
/// the start and end of the block, which establish a formal signature for the
|
||||
/// block and can be generically reasoned about -- that is what we anchor on
|
||||
/// here.
|
||||
MlirType getFunctionTypeFromBlock(MlirContext context,
|
||||
torch::jit::Block *block);
|
||||
/// This can differ from the type inferred from the block of a
|
||||
/// torch::jit::Function due to derefinement.
|
||||
MlirType getFunctionTypeFromSchema(MlirContext context,
|
||||
const c10::FunctionSchema &schema);
|
||||
|
||||
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
|
||||
MlirAttribute converTensorToMlirElementsAttr(at::Tensor tensor,
|
||||
|
@ -78,6 +72,11 @@ std::vector<MlirType>
|
|||
getMlirTypesFromValues(MlirLocation loc,
|
||||
c10::ArrayRef<torch::jit::Value *> values);
|
||||
|
||||
std::vector<MlirValue> derefineValues(c10::ArrayRef<MlirValue> values,
|
||||
c10::ArrayRef<MlirType> expectedTypes,
|
||||
MlirLocation loc,
|
||||
MlirBlock appendToBlock);
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch_mlir
|
|||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# Verify without debug info.
|
||||
# CHECK-LABEL: func @add3
|
||||
# CHECK-LABEL: func @__torch__.add3
|
||||
# CHECK-SAME: (%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
# CHECK: %[[C1:.*]] = constant 1 : i64
|
||||
# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
|||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @add3
|
||||
# CHECK-LABEL: func @__torch__.add3
|
||||
# Note that line-level debug information for parts unannotated in the Torch
|
||||
# graph are ascribed to the first op that carries source information. Presently
|
||||
# this includes naked constants, return and the function itself. This heuristic
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
|||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: @f
|
||||
# CHECK-LABEL: @__torch__.f
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def f(b: bool, i: int):
|
||||
|
|
|
@ -17,8 +17,8 @@ try:
|
|||
@torch.jit.script
|
||||
def import_class(x: typing.Any):
|
||||
return x
|
||||
except RuntimeError as e:
|
||||
except Exception as e:
|
||||
# TODO: Once diagnostics are enabled, verify the actual error emitted.
|
||||
assert str(e) == "unsupported type"
|
||||
assert str(e) == "unsupported type in function schema: 'Any'"
|
||||
else:
|
||||
assert False, "Expected exception"
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# -*- Python -*-
|
||||
# This file is licensed under a pytorch-style license
|
||||
# See frontends/pytorch/LICENSE for license information.
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
|
||||
import typing
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @__torch__.optional_return(
|
||||
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.optional<i64> {
|
||||
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
|
||||
# CHECK: return %[[RET]] : !torch.optional<i64>
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def optional_return(i: int) -> typing.Optional[int]:
|
||||
return i
|
||||
|
||||
# CHECK-LABEL: func @__torch__.optional_arg(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<i64>) -> !basicpy.NoneType {
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def optional_arg(i: typing.Optional[int]) -> None:
|
||||
return
|
||||
|
||||
# CHECK-LABEL: func @__torch__.calls_optional_arg(
|
||||
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.NoneType {
|
||||
# CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<i64>) -> !basicpy.NoneType
|
||||
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
|
||||
# CHECK: %{{.*}} = call_indirect %[[CALLEE]](%[[DEREFINED]]) : (!torch.optional<i64>) -> !basicpy.NoneType
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def calls_optional_arg(i: int):
|
||||
optional_arg(i)
|
||||
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
|
@ -9,12 +9,12 @@ import torch_mlir
|
|||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: @f(
|
||||
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
|
||||
# CHECK-SAME: %[[I:.*]]: i64) -> i64 {
|
||||
# CHECK-LABEL: @__torch__.prim_If(
|
||||
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
|
||||
# CHECK-SAME: %[[I:.*]]: i64) -> i64 {
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def f(b: bool, i: int):
|
||||
def prim_If(b: bool, i: int):
|
||||
# CHECK: %[[I1:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1
|
||||
# CHECK: %[[RES:.*]] = scf.if %[[I1]] -> (i64) {
|
||||
# CHECK: %[[ADD:.*]] = torch.kernel_call "aten::add" %[[I]], %[[I]]
|
||||
|
@ -30,6 +30,25 @@ def f(b: bool, i: int):
|
|||
return i * i
|
||||
# elif is modeled as a nested if, so no need to specially test it here.
|
||||
|
||||
assert isinstance(f, torch.jit.ScriptFunction)
|
||||
# CHECK-LABEL: func @__torch__.prim_If_derefine(
|
||||
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
|
||||
# CHECK-SAME: %[[I:.*]]: i64) -> !torch.optional<i64> {
|
||||
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[PRED:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1
|
||||
# CHECK: %[[RES:.*]] = scf.if %[[PRED]] -> (!torch.optional<i64>) {
|
||||
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
|
||||
# CHECK: scf.yield %[[NONE_DEREFINED]] : !torch.optional<i64>
|
||||
# CHECK: } else {
|
||||
# CHECK: %[[I_DEREFINED:.*]] = torch.derefine %[[I]] : i64 -> !torch.optional<i64>
|
||||
# CHECK: scf.yield %[[I_DEREFINED]] : !torch.optional<i64>
|
||||
# CHECK: }
|
||||
# CHECK: return %[[RES:.*]] : !torch.optional<i64>
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_If_derefine(b: bool, i: int):
|
||||
if b:
|
||||
return None
|
||||
return i
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
|||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @f(
|
||||
# CHECK-LABEL: func @__torch__.f(
|
||||
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType {
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType
|
||||
|
|
|
@ -5,11 +5,13 @@
|
|||
import torch
|
||||
import torch_mlir
|
||||
|
||||
import typing
|
||||
|
||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @prim_Loop_forlike(
|
||||
# CHECK-LABEL: func @__torch__.prim_Loop_forlike(
|
||||
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: i64) -> f64 {
|
||||
# CHECK: %[[BOOL_TRUE:.*]] = basicpy.bool_constant true
|
||||
# CHECK: %[[F_INIT:.*]] = constant 0.000000e+00 : f64
|
||||
|
@ -27,18 +29,18 @@ def prim_Loop_forlike(n: int):
|
|||
f += i
|
||||
return f
|
||||
|
||||
# CHECK-LABEL: func @prim_Loop_whilelike(
|
||||
# CHECK-LABEL: func @__torch__.prim_Loop_whilelike(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: i64) -> f64 {
|
||||
# CHECK: %[[F_INIT:.*]] = constant 3.200000e+00 : f64
|
||||
# CHECK: %[[MAX_ITERATIONS:.*]] = constant 9223372036854775807 : i64
|
||||
# CHECK: %[[COND_INIT:.*]] = torch.kernel_call "aten::lt" %[[F_INIT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
||||
# CHECK: %[[IV:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
|
||||
# CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
|
||||
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
|
||||
# CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
||||
# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
|
||||
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
|
||||
# CHECK: return %[[VAL_9:.*]] : f64
|
||||
# CHECK: return %[[RET:.*]] : f64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_Loop_whilelike(n: int):
|
||||
|
@ -47,5 +49,24 @@ def prim_Loop_whilelike(n: int):
|
|||
f = f * f
|
||||
return f
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_Loop_derefine(
|
||||
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.optional<i64> {
|
||||
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
|
||||
# CHECK: %[[RET:.*]] = torch.prim.Loop %[[ARG]], %[[TRUE]], init(%[[NONE_DEREFINED]]) {
|
||||
# CHECK: ^bb0(%[[IV:.*]]: i64, %[[X_ITER:.*]]: !torch.optional<i64>):
|
||||
# CHECK: %[[X_NEXT:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
|
||||
# CHECK: torch.prim.Loop.condition %[[TRUE]] iter(%[[X_NEXT]]) : !basicpy.BoolType, (!torch.optional<i64>)
|
||||
# CHECK: } : (i64, !basicpy.BoolType, !torch.optional<i64>) -> !torch.optional<i64>
|
||||
# CHECK: return %[[RET:.*]] : !torch.optional<i64>
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def prim_Loop_derefine(n: int):
|
||||
x: typing.Optional[int] = None
|
||||
for i in range(n):
|
||||
x = n
|
||||
return x
|
||||
|
||||
mb.module.operation.print()
|
||||
print()
|
||||
|
|
|
@ -14,7 +14,7 @@ import typing
|
|||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
|
||||
# CHECK-LABEL: func @prim_NumToTensor(
|
||||
# CHECK-LABEL: func @__torch__.prim_NumToTensor(
|
||||
# CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||
|
@ -25,7 +25,7 @@ mb = torch_mlir.ModuleBuilder()
|
|||
def prim_NumToTensor(i: int):
|
||||
return _to_tensor(i)
|
||||
|
||||
# CHECK-LABEL: func @prim_Print(
|
||||
# CHECK-LABEL: func @__torch__.prim_Print(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
|
||||
# CHECK: %[[STR:.*]] = basicpy.bytes_constant "x"
|
||||
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !basicpy.BytesType, !numpy.ndarray<*:!numpy.any_dtype>
|
||||
|
@ -34,7 +34,7 @@ def prim_NumToTensor(i: int):
|
|||
def prim_Print(x):
|
||||
print("x", x)
|
||||
|
||||
# CHECK-LABEL: func @prim_RaiseException() -> !basicpy.NoneType {
|
||||
# CHECK-LABEL: func @__torch__.prim_RaiseException() -> !basicpy.NoneType {
|
||||
# CHECK: %[[ERRORSTR:.*]] = basicpy.bytes_constant "Error"
|
||||
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !basicpy.NoneType
|
||||
# CHECK: torch.prim.RaiseException %[[ERRORSTR]]
|
||||
|
@ -44,7 +44,7 @@ def prim_Print(x):
|
|||
def prim_RaiseException():
|
||||
raise Exception("Error")
|
||||
|
||||
# CHECK-LABEL: func @prim_unchecked_cast(
|
||||
# CHECK-LABEL: func @__torch__.prim_unchecked_cast(
|
||||
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
|
||||
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||
# CHECK: %[[C3:.*]] = constant 3 : i64
|
||||
|
@ -64,7 +64,7 @@ def prim_unchecked_cast(i: typing.Optional[int]):
|
|||
return 3
|
||||
return i
|
||||
|
||||
# CHECK-LABEL: func @prim_TupleUnpack(
|
||||
# CHECK-LABEL: func @__torch__.prim_TupleUnpack(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
|
||||
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
|
||||
# CHECK: return %[[RET]]#0 : i64
|
||||
|
@ -74,7 +74,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
|
|||
val, _ = tup
|
||||
return val
|
||||
|
||||
# CHECK-LABEL: func @prim_TupleIndex(
|
||||
# CHECK-LABEL: func @__torch__.prim_TupleIndex(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !basicpy.TupleType, i64 -> i64
|
||||
# CHECK: return %[[RET]] : i64
|
||||
|
@ -83,7 +83,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
|
|||
def prim_TupleIndex(tup: typing.Tuple[int, int]):
|
||||
return tup[0]
|
||||
|
||||
# CHECK-LABEL: func @prim_ListUnpack(
|
||||
# CHECK-LABEL: func @__torch__.prim_ListUnpack(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
|
||||
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
|
||||
# CHECK: return %[[RET]]#1 : i64
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
|||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK-LABEL: func @f(
|
||||
# CHECK-LABEL: func @__torch__.f(
|
||||
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType {
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
|||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK: @returns_bool
|
||||
# CHECK: @__torch__.returns_bool
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def returns_bool():
|
||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
|||
|
||||
mb = torch_mlir.ModuleBuilder()
|
||||
|
||||
# CHECK: @returns_none
|
||||
# CHECK: @__torch__.returns_none
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
def returns_none():
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
|||
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<Torch_Dialect, mnemonic, traits> {
|
||||
|
@ -457,10 +458,59 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", []> {
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", []> {
|
||||
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", [
|
||||
NoSideEffect
|
||||
]> {
|
||||
let summary = "TorchScript prim::unchecked_cast op";
|
||||
// TODO: This seems to mostly be used for casting "optional" to the contained
|
||||
// type. Verify that and tighten the verifier.
|
||||
let description = [{
|
||||
Refine a type to one of its subtypes.
|
||||
|
||||
For example, refine a type that was only statically known to be
|
||||
Optional[T] to a T when we obtain static information that guarantees it.
|
||||
|
||||
The key observation here is that Optional[T] does not have a corresponding
|
||||
runtime type (i.e. `c10::IValue` subclass). It represents a set of possible
|
||||
concrete types which for `Optional[T]` is either `None` or a concrete
|
||||
subtype of `T` (which in the simplest case is just `T`). In particular,
|
||||
at runtime there is no way to distinguish `Optional[int]` from
|
||||
`Optional[Optional[int]]`, because both are either `None` or `int`.
|
||||
This differs from C++ std::optional.
|
||||
|
||||
The best documentation of this op is inspection of the code in
|
||||
`torch/csrc/jit/frontend/ir_emitter.cpp`.
|
||||
}];
|
||||
|
||||
// TODO: When we model PyTorch's notion of subtyping, verify the types here.
|
||||
let arguments = (ins AnyTorchType:$operand);
|
||||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Additional ops used to model TorchScript's Graph's / Node's.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Torch_DerefineOp : Torch_Op<"derefine", [
|
||||
NoSideEffect
|
||||
]> {
|
||||
let summary = "De-refine a type";
|
||||
let description = [{
|
||||
In terms of IR structure, TorchScript allows types to vary in many
|
||||
circumstances where MLIR requires pointer-identical types. In particular,
|
||||
it is valid to pass any subtype in place of a type. For example, if an
|
||||
`Optional[int]` is required somewhere in the IR, it is legal to pass a
|
||||
value of just `int` (but not the other way around; see
|
||||
`torch.prim.unchecked_cast`). In effect, every *use* can have a different
|
||||
type.
|
||||
|
||||
This op bridges that impedance mismatch. This op allows casting a value
|
||||
from one type to a type that it is a subtype of to model this behavior.
|
||||
}];
|
||||
|
||||
// TODO: When we model PyTorch's notion of subtyping, verify the types here.
|
||||
let arguments = (ins AnyTorchType:$operand);
|
||||
let results = (outs AnyTorchType:$result);
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
|
|||
MLIRIR
|
||||
MLIRSupport
|
||||
MLIRControlFlowInterfaces
|
||||
MLIRSideEffectInterfaces
|
||||
NPCOMPBasicpyDialect
|
||||
NPCOMPNumpyDialect
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue