mirror of https://github.com/llvm/torch-mlir
Properly model "derefinement".
In terms of IR structure, TorchScript allows types to vary in many circumstances where MLIR requires pointer-identical types. In particular, it is valid to pass any subtype in place of a type. For example, if an `Optional[int]` is required somewhere in the IR, it is legal to pass a value of just `int` (but not the other way around; see `torch.prim.unchecked_cast`). In effect, every *use* can have a different type. We introduce a new op `torch.derefine` that models that impedance mismatch. This op allows casting a value from one type to a type that it is a subtype of to model this behavior. Recommended review order: - TorchOps.td for new torch.derefine (and updated docs for `torch.prim.unchecked_cast`) - new test code in if.py, loop.py, function-derefine.py - new code in node_importer.cpp for handling derefinement insertion - function_importer.cpp and utils changes in torch_to_mlir_utils.cpp Properly handling derefinement on function boundaries required relayering the code so that graph_importer.cpp/.h is now function_importer.cpp/.h because only the `torch::jit::Function` (actually the `c10::FunctionSchema` it holds) knows the derefined types that are actually needed at the boundary (see `function-derefine.py` for a test). Annoyingly, this churns all the functions which are now prefixed with `__torch__.` but that is more correct anyway (that is their linkage name in the `torch::jit::CompilationUnit`; the previous `mb.import_function` was actually buggy in the case of functions calling each other as it would reference their unqualified name). With this change, we can import `resnet18` from `torchvision` :) IR: https://gist.github.com/silvasean/6426a5272d8a6c7caae533fce05ab704pull/178/head
parent
1736ff0253
commit
43dba03afd
|
@ -15,7 +15,7 @@ add_library(NPCOMPTorchMLIRExt SHARED
|
||||||
builder/class_annotator.cpp
|
builder/class_annotator.cpp
|
||||||
builder/debug.cpp
|
builder/debug.cpp
|
||||||
builder/func_builder.cpp
|
builder/func_builder.cpp
|
||||||
builder/graph_importer.cpp
|
builder/function_importer.cpp
|
||||||
builder/module_builder.cpp
|
builder/module_builder.cpp
|
||||||
builder/node_importer.cpp
|
builder/node_importer.cpp
|
||||||
builder/op_builder.cpp
|
builder/op_builder.cpp
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
//===- function_importer.cpp ----------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See frontends/pytorch/LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "function_importer.h"
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "mlir_utils.h"
|
||||||
|
#include "torch_to_mlir_utils.h"
|
||||||
|
|
||||||
|
#include "mlir-c/BuiltinAttributes.h"
|
||||||
|
#include "mlir-c/BuiltinTypes.h"
|
||||||
|
#include "mlir-c/Diagnostics.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
using namespace torch_mlir;
|
||||||
|
|
||||||
|
MlirOperation
|
||||||
|
torch_mlir::importJitFunctionAsFuncOp(MlirContext context,
|
||||||
|
torch::jit::Function *function) {
|
||||||
|
// Useful for debugging:
|
||||||
|
// graph->dump();
|
||||||
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
MlirType functionType =
|
||||||
|
getFunctionTypeFromSchema(context, function->getSchema());
|
||||||
|
// Use the function's qualified name from the compilation unit.
|
||||||
|
// This is a stable linkage name that matches Python module lookup
|
||||||
|
// conventions (see compilation unit import in IValueImporter for more details
|
||||||
|
// on qualified names).
|
||||||
|
MlirAttribute symNameAttr = mlirStringAttrGet(
|
||||||
|
context, toMlirStringRef(function->qualname().qualifiedName()));
|
||||||
|
MlirOperation func = createMlirOperation(
|
||||||
|
"func", loc, mlirRegionCreate(),
|
||||||
|
toMlirNamedAttribute("type", mlirTypeAttrGet(functionType)),
|
||||||
|
toMlirNamedAttribute("sym_name", symNameAttr));
|
||||||
|
MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);
|
||||||
|
std::vector<MlirType> resultTypes;
|
||||||
|
for (int i = 0, e = mlirFunctionTypeGetNumResults(functionType); i != e;
|
||||||
|
i++) {
|
||||||
|
resultTypes.push_back(mlirFunctionTypeGetResult(functionType, i));
|
||||||
|
}
|
||||||
|
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||||
|
MlirBlock appendToBlock) {
|
||||||
|
createMlirOperationAtEnd(
|
||||||
|
appendToBlock, "std.return", loc,
|
||||||
|
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||||
|
};
|
||||||
|
MlirBlock block =
|
||||||
|
importBlock(context, function->graph()->block(), createTerminator);
|
||||||
|
mlirRegionAppendOwnedBlock(bodyRegion, block);
|
||||||
|
return func;
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
//===- function_importer.h --------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under a pytorch-style license
|
||||||
|
// See frontends/pytorch/LICENSE for license information.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
|
||||||
|
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "../pybind.h"
|
||||||
|
#include "func_builder.h"
|
||||||
|
#include "node_importer.h"
|
||||||
|
|
||||||
|
#include "mlir-c/IR.h"
|
||||||
|
|
||||||
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
|
namespace torch_mlir {
|
||||||
|
|
||||||
|
/// Main entry-point for importing torch::jit::Function instances.
|
||||||
|
///
|
||||||
|
/// This code doesn't handle importing of torch::jit::Module's. See
|
||||||
|
/// IValueImporter for that.
|
||||||
|
///
|
||||||
|
/// A torch::jit::Function holds a c10::FunctionSchema along with a
|
||||||
|
/// c10::QualifiedName and a torch::jit::Graph.
|
||||||
|
///
|
||||||
|
/// The torch::jit::Graph is a combination of an MLIR context, function, and
|
||||||
|
/// builder. See NodeImporter for importing of the core IR Node/Block
|
||||||
|
/// structure that is analogous to MLIR's Operation/Region/Block core structure.
|
||||||
|
MlirOperation importJitFunctionAsFuncOp(MlirContext context,
|
||||||
|
torch::jit::Function *function);
|
||||||
|
|
||||||
|
} // namespace torch_mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_FUNCTION_IMPORTER_H
|
|
@ -1,38 +0,0 @@
|
||||||
//===- graph_importer.cpp -------------------------------------------------===//
|
|
||||||
//
|
|
||||||
// This file is licensed under a pytorch-style license
|
|
||||||
// See frontends/pytorch/LICENSE for license information.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "graph_importer.h"
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "mlir_utils.h"
|
|
||||||
#include "torch_to_mlir_utils.h"
|
|
||||||
|
|
||||||
#include "mlir-c/BuiltinAttributes.h"
|
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
|
||||||
#include "mlir-c/Diagnostics.h"
|
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
using namespace torch_mlir;
|
|
||||||
|
|
||||||
MlirOperation torch_mlir::importGraphAsFuncOp(MlirContext context,
|
|
||||||
torch::jit::Graph *graph,
|
|
||||||
const std::string &name) {
|
|
||||||
// Useful for debugging:
|
|
||||||
// graph->dump();
|
|
||||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
|
||||||
MlirAttribute typeAttr =
|
|
||||||
mlirTypeAttrGet(getFunctionTypeFromBlock(context, graph->block()));
|
|
||||||
MlirAttribute symNameAttr = mlirStringAttrGet(context, toMlirStringRef(name));
|
|
||||||
MlirOperation func = createMlirOperation(
|
|
||||||
"func", loc, mlirRegionCreate(), toMlirNamedAttribute("type", typeAttr),
|
|
||||||
toMlirNamedAttribute("sym_name", symNameAttr));
|
|
||||||
MlirRegion bodyRegion = mlirOperationGetRegion(func, 0);
|
|
||||||
MlirBlock block = importBlock(context, graph->block(), "std.return");
|
|
||||||
mlirRegionAppendOwnedBlock(bodyRegion, block);
|
|
||||||
return func;
|
|
||||||
}
|
|
|
@ -1,37 +0,0 @@
|
||||||
//===- graph_importer.h -----------------------------------------*- C++ -*-===//
|
|
||||||
//
|
|
||||||
// This file is licensed under a pytorch-style license
|
|
||||||
// See frontends/pytorch/LICENSE for license information.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
|
|
||||||
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "../pybind.h"
|
|
||||||
#include "func_builder.h"
|
|
||||||
#include "node_importer.h"
|
|
||||||
|
|
||||||
#include "mlir-c/IR.h"
|
|
||||||
|
|
||||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
|
||||||
#include <torch/csrc/jit/ir/ir.h>
|
|
||||||
|
|
||||||
namespace torch_mlir {
|
|
||||||
|
|
||||||
/// Main entry-point for importing torch::jit::Graph instances.
|
|
||||||
///
|
|
||||||
/// This code doesn't handle importing of torch::jit::Module's. See
|
|
||||||
/// IValueImporter for that.
|
|
||||||
///
|
|
||||||
/// A Graph is a combination of an MLIR context, function, and builder.
|
|
||||||
/// See NodeImporter for importing of the core IR Node/Block
|
|
||||||
/// structure that is analogous to MLIR's Operation/Region/Block core structure.
|
|
||||||
MlirOperation importGraphAsFuncOp(MlirContext context, torch::jit::Graph *graph,
|
|
||||||
const std::string &name);
|
|
||||||
|
|
||||||
} // namespace torch_mlir
|
|
||||||
|
|
||||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_GRAPH_IMPORTER_H
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
#include "ivalue_importer.h"
|
#include "ivalue_importer.h"
|
||||||
#include "class_annotator.h"
|
#include "class_annotator.h"
|
||||||
#include "graph_importer.h"
|
#include "function_importer.h"
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
|
||||||
|
@ -393,8 +393,7 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (torch::jit::Function *function : cu->get_functions()) {
|
for (torch::jit::Function *function : cu->get_functions()) {
|
||||||
MlirOperation func = importGraphAsFuncOp(
|
MlirOperation func = importJitFunctionAsFuncOp(context, function);
|
||||||
context, function->graph().get(), function->qualname().qualifiedName());
|
|
||||||
// For IValue importing, the logical linkage structure of the module
|
// For IValue importing, the logical linkage structure of the module
|
||||||
// is determined by the object graph.
|
// is determined by the object graph.
|
||||||
//
|
//
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
|
|
||||||
#include "module_builder.h"
|
#include "module_builder.h"
|
||||||
|
|
||||||
#include "graph_importer.h"
|
#include "function_importer.h"
|
||||||
#include "ivalue_importer.h"
|
#include "ivalue_importer.h"
|
||||||
|
|
||||||
#include "mlir-c/Bindings/Python/Interop.h"
|
#include "mlir-c/Bindings/Python/Interop.h"
|
||||||
|
@ -128,8 +128,7 @@ torch::jit::StrongFunctionPtr
|
||||||
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
|
||||||
MlirBlock block = getBodyBlock();
|
MlirBlock block = getBodyBlock();
|
||||||
MlirOperation terminator = this->terminator;
|
MlirOperation terminator = this->terminator;
|
||||||
MlirOperation func = importGraphAsFuncOp(
|
MlirOperation func = importJitFunctionAsFuncOp(context, function.function_);
|
||||||
context, function.function_->graph().get(), function.function_->name());
|
|
||||||
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
mlirBlockInsertOwnedOperationBefore(block, terminator, func);
|
||||||
return function;
|
return function;
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ public:
|
||||||
NodeImporter(MlirContext context) : context(context) {}
|
NodeImporter(MlirContext context) : context(context) {}
|
||||||
|
|
||||||
void importNode(Node *node, MlirBlock appendToBlock);
|
void importNode(Node *node, MlirBlock appendToBlock);
|
||||||
MlirBlock importBlock(Block *jitBlock, const std::string &terminatorOpName);
|
MlirBlock importBlock(Block *jitBlock, CreateTerminatorFn createTerminator);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void importPrimNode(Node *node, MlirBlock appendToBlock);
|
void importPrimNode(Node *node, MlirBlock appendToBlock);
|
||||||
|
@ -70,7 +70,7 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
const std::string &symName = function->qualname().qualifiedName();
|
const std::string &symName = function->qualname().qualifiedName();
|
||||||
op = createMlirOperation(
|
op = createMlirOperation(
|
||||||
"std.constant", loc,
|
"std.constant", loc,
|
||||||
getFunctionTypeFromBlock(context, function->graph()->block()),
|
getFunctionTypeFromSchema(context, function->getSchema()),
|
||||||
toMlirNamedAttribute(
|
toMlirNamedAttribute(
|
||||||
"value",
|
"value",
|
||||||
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
|
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
|
||||||
|
@ -141,14 +141,28 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (kind == c10::prim::Loop) {
|
if (kind == c10::prim::Loop) {
|
||||||
|
std::vector<MlirType> resultTypes =
|
||||||
|
getMlirTypesFromValues(loc, node->outputs());
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
appendToBlock, "torch.prim.Loop", loc,
|
appendToBlock, "torch.prim.Loop", loc, resultTypes,
|
||||||
getMlirTypesFromValues(loc, node->outputs()),
|
lookupMappedValues(node->inputs().slice(0, 2)),
|
||||||
lookupMappedValues(node->inputs()), mlirRegionCreate());
|
derefineValues(lookupMappedValues(node->inputs().slice(2)), resultTypes,
|
||||||
|
loc, appendToBlock),
|
||||||
|
mlirRegionCreate());
|
||||||
mapResults(node, operation);
|
mapResults(node, operation);
|
||||||
|
std::vector<MlirType> terminatorOperandTypes = {npcompBoolTypeGet(context)};
|
||||||
|
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
|
||||||
|
resultTypes.begin(), resultTypes.end());
|
||||||
|
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||||
|
MlirBlock appendToBlock) {
|
||||||
|
createMlirOperationAtEnd(appendToBlock, "torch.prim.Loop.condition", loc,
|
||||||
|
derefineValues(yieldedValues,
|
||||||
|
terminatorOperandTypes, loc,
|
||||||
|
appendToBlock));
|
||||||
|
};
|
||||||
mlirRegionAppendOwnedBlock(
|
mlirRegionAppendOwnedBlock(
|
||||||
mlirOperationGetRegion(operation, 0),
|
mlirOperationGetRegion(operation, 0),
|
||||||
importBlock(node->blocks()[0], "torch.prim.Loop.condition"));
|
importBlock(node->blocks()[0], createTerminator));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,15 +172,24 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
MlirOperation pred = createMlirOperationAtEnd(
|
MlirOperation pred = createMlirOperationAtEnd(
|
||||||
appendToBlock, "basicpy.bool_cast", loc, mlirIntegerTypeGet(context, 1),
|
appendToBlock, "basicpy.bool_cast", loc, mlirIntegerTypeGet(context, 1),
|
||||||
lookupMappedValue(node->input()));
|
lookupMappedValue(node->input()));
|
||||||
|
std::vector<MlirType> resultTypes =
|
||||||
|
getMlirTypesFromValues(loc, node->outputs());
|
||||||
MlirOperation operation = createMlirOperationAtEnd(
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
|
appendToBlock, "scf.if", loc, mlirOperationGetResult(pred, 0),
|
||||||
getMlirTypesFromValues(loc, node->outputs()), mlirRegionCreate(),
|
resultTypes, mlirRegionCreate(), mlirRegionCreate());
|
||||||
mlirRegionCreate());
|
|
||||||
mapResults(node, operation);
|
mapResults(node, operation);
|
||||||
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 0),
|
auto createTerminator =
|
||||||
importBlock(node->blocks()[0], "scf.yield"));
|
[&](c10::ArrayRef<MlirValue> yieldedValues, MlirBlock appendToBlock) {
|
||||||
mlirRegionAppendOwnedBlock(mlirOperationGetRegion(operation, 1),
|
createMlirOperationAtEnd(
|
||||||
importBlock(node->blocks()[1], "scf.yield"));
|
appendToBlock, "scf.yield", loc,
|
||||||
|
derefineValues(yieldedValues, resultTypes, loc, appendToBlock));
|
||||||
|
};
|
||||||
|
mlirRegionAppendOwnedBlock(
|
||||||
|
mlirOperationGetRegion(operation, 0),
|
||||||
|
importBlock(node->blocks()[0], createTerminator));
|
||||||
|
mlirRegionAppendOwnedBlock(
|
||||||
|
mlirOperationGetRegion(operation, 1),
|
||||||
|
importBlock(node->blocks()[1], createTerminator));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,10 +203,18 @@ void NodeImporter::importPrimNode(Node *node, MlirBlock appendToBlock) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (kind == c10::prim::CallFunction) {
|
if (kind == c10::prim::CallFunction) {
|
||||||
MlirOperation operation =
|
auto functionType = node->input(0)->type()->cast<c10::FunctionType>();
|
||||||
createMlirOperationAtEnd(appendToBlock, "std.call_indirect", loc,
|
torch::jit::Block *calleeEntryBlock =
|
||||||
getMlirTypesFromValues(loc, node->outputs()),
|
functionType->function()->graph()->block();
|
||||||
lookupMappedValues(node->inputs()));
|
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
||||||
|
return typeMapper.mapFromTorchType(loc, v->type());
|
||||||
|
});
|
||||||
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
|
appendToBlock, "std.call_indirect", loc,
|
||||||
|
getMlirTypesFromValues(loc, node->outputs()),
|
||||||
|
lookupMappedValue(node->input(0)),
|
||||||
|
derefineValues(lookupMappedValues(node->inputs().slice(1)),
|
||||||
|
expectedTypes, loc, appendToBlock));
|
||||||
mapResults(node, operation);
|
mapResults(node, operation);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -288,15 +319,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirBlock NodeImporter::importBlock(Block *jitBlock,
|
MlirBlock NodeImporter::importBlock(Block *jitBlock,
|
||||||
const std::string &terminatorOpName) {
|
CreateTerminatorFn createTerminator) {
|
||||||
MlirBlock block = createBlockFor(jitBlock);
|
MlirBlock block = createBlockFor(jitBlock);
|
||||||
for (Node *node : jitBlock->nodes()) {
|
for (Node *node : jitBlock->nodes()) {
|
||||||
importNode(node, block);
|
importNode(node, block);
|
||||||
}
|
}
|
||||||
Node *returnNode = jitBlock->return_node();
|
Node *returnNode = jitBlock->return_node();
|
||||||
createMlirOperationAtEnd(block, terminatorOpName,
|
createTerminator(lookupMappedValues(returnNode->inputs()), block);
|
||||||
getMlirLocationFromNode(context, returnNode),
|
|
||||||
lookupMappedValues(returnNode->inputs()));
|
|
||||||
return block;
|
return block;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,7 +372,7 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
||||||
const std::string &terminatorOpName) {
|
CreateTerminatorFn createTerminator) {
|
||||||
NodeImporter importer(context);
|
NodeImporter importer(context);
|
||||||
return importer.importBlock(jitBlock, terminatorOpName);
|
return importer.importBlock(jitBlock, createTerminator);
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,8 +20,11 @@
|
||||||
|
|
||||||
namespace torch_mlir {
|
namespace torch_mlir {
|
||||||
|
|
||||||
|
using CreateTerminatorFn =
|
||||||
|
std::function<void(c10::ArrayRef<MlirValue>, MlirBlock)>;
|
||||||
|
|
||||||
MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock,
|
MlirBlock importBlock(MlirContext context, torch::jit::Block *jitBlock,
|
||||||
const std::string &terminatorOpName);
|
CreateTerminatorFn createTerminator);
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "graph_importer.h"
|
#include "function_importer.h"
|
||||||
#include "ivalue_importer.h"
|
#include "ivalue_importer.h"
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
@ -163,17 +163,28 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
||||||
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
|
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torch_mlir::getFunctionTypeFromBlock(MlirContext context,
|
MlirType
|
||||||
torch::jit::Block *block) {
|
torch_mlir::getFunctionTypeFromSchema(MlirContext context,
|
||||||
MlirLocation inputLoc = getMlirLocationFromNode(context, block->param_node());
|
const c10::FunctionSchema &schema) {
|
||||||
|
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||||
|
TypeMapper typeMapper(context);
|
||||||
|
auto mapType = [&](const c10::TypePtr &torchType) {
|
||||||
|
MlirType type = typeMapper.mapFromTorchType(loc, torchType);
|
||||||
|
if (mlirTypeIsNull(type)) {
|
||||||
|
std::stringstream msg;
|
||||||
|
msg << "unsupported type in function schema: '"
|
||||||
|
<< c10::toString(torchType) << "'";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
return type;
|
||||||
|
};
|
||||||
|
|
||||||
std::vector<MlirType> inputTypes =
|
std::vector<MlirType> inputTypes =
|
||||||
getMlirTypesFromValues(inputLoc, block->param_node()->outputs());
|
c10::fmap(schema.arguments(),
|
||||||
|
[&](const c10::Argument &arg) { return mapType(arg.type()); });
|
||||||
MlirLocation outputLoc =
|
|
||||||
getMlirLocationFromNode(context, block->return_node());
|
|
||||||
std::vector<MlirType> outputTypes =
|
std::vector<MlirType> outputTypes =
|
||||||
getMlirTypesFromValues(outputLoc, block->return_node()->inputs());
|
c10::fmap(schema.returns(),
|
||||||
|
[&](const c10::Argument &arg) { return mapType(arg.type()); });
|
||||||
return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
|
return mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
|
||||||
outputTypes.size(), outputTypes.data());
|
outputTypes.size(), outputTypes.data());
|
||||||
}
|
}
|
||||||
|
@ -303,3 +314,25 @@ torch_mlir::getMlirTypesFromValues(MlirLocation loc,
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<MlirValue>
|
||||||
|
torch_mlir::derefineValues(c10::ArrayRef<MlirValue> values,
|
||||||
|
c10::ArrayRef<MlirType> expectedTypes,
|
||||||
|
MlirLocation loc, MlirBlock appendToBlock) {
|
||||||
|
std::vector<MlirValue> ret;
|
||||||
|
assert(values.size() == expectedTypes.size());
|
||||||
|
for (int i = 0, e = values.size(); i != e; i++) {
|
||||||
|
MlirValue value = values[i];
|
||||||
|
MlirType expectedType = expectedTypes[i];
|
||||||
|
MlirType type = mlirValueGetType(value);
|
||||||
|
if (mlirTypeEqual(expectedType, type)) {
|
||||||
|
// No need to derefine.
|
||||||
|
ret.push_back(value);
|
||||||
|
} else {
|
||||||
|
MlirOperation operation = createMlirOperationAtEnd(
|
||||||
|
appendToBlock, "torch.derefine", loc, expectedType, value);
|
||||||
|
ret.push_back(mlirOperationGetResult(operation, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
|
@ -51,18 +51,12 @@ private:
|
||||||
MlirContext context;
|
MlirContext context;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Creates a FunctionType suitable for expressing the signature of `block`.
|
/// Creates a FunctionType suitable for expressing the signature of `schema`.
|
||||||
///
|
///
|
||||||
/// `mlir::Block` only has a formalized notion of argument types (bb args), but
|
/// This can differ from the type inferred from the block of a
|
||||||
/// the exact nature of the block's terminator is left opaque (for example, you
|
/// torch::jit::Function due to derefinement.
|
||||||
/// can have a weird terminator that "returns all but the first operand").
|
MlirType getFunctionTypeFromSchema(MlirContext context,
|
||||||
/// `torch::jit::Block` on the other hand has a formalized notion of a
|
const c10::FunctionSchema &schema);
|
||||||
/// `param_node` and `return_node`, which are effectively dummy operations at
|
|
||||||
/// the start and end of the block, which establish a formal signature for the
|
|
||||||
/// block and can be generically reasoned about -- that is what we anchor on
|
|
||||||
/// here.
|
|
||||||
MlirType getFunctionTypeFromBlock(MlirContext context,
|
|
||||||
torch::jit::Block *block);
|
|
||||||
|
|
||||||
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
|
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
|
||||||
MlirAttribute converTensorToMlirElementsAttr(at::Tensor tensor,
|
MlirAttribute converTensorToMlirElementsAttr(at::Tensor tensor,
|
||||||
|
@ -78,6 +72,11 @@ std::vector<MlirType>
|
||||||
getMlirTypesFromValues(MlirLocation loc,
|
getMlirTypesFromValues(MlirLocation loc,
|
||||||
c10::ArrayRef<torch::jit::Value *> values);
|
c10::ArrayRef<torch::jit::Value *> values);
|
||||||
|
|
||||||
|
std::vector<MlirValue> derefineValues(c10::ArrayRef<MlirValue> values,
|
||||||
|
c10::ArrayRef<MlirType> expectedTypes,
|
||||||
|
MlirLocation loc,
|
||||||
|
MlirBlock appendToBlock);
|
||||||
|
|
||||||
} // namespace torch_mlir
|
} // namespace torch_mlir
|
||||||
|
|
||||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H
|
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_TORCH_TO_MLIR_UTILS_H
|
||||||
|
|
|
@ -10,7 +10,7 @@ import torch_mlir
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# Verify without debug info.
|
# Verify without debug info.
|
||||||
# CHECK-LABEL: func @add3
|
# CHECK-LABEL: func @__torch__.add3
|
||||||
# CHECK-SAME: (%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
# CHECK-SAME: (%arg0: !numpy.ndarray<*:!numpy.any_dtype>, %arg1: !numpy.ndarray<*:!numpy.any_dtype>, %arg2: !numpy.ndarray<*:!numpy.any_dtype>) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||||
# CHECK: %[[C1:.*]] = constant 1 : i64
|
# CHECK: %[[C1:.*]] = constant 1 : i64
|
||||||
# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
# CHECK: %[[A0:.*]] = torch.kernel_call "aten::add" %arg0, %arg1, %[[C1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>, i64) -> !numpy.ndarray<*:!numpy.any_dtype> {sigArgTypes = ["Tensor", "Tensor", "Scalar"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# CHECK-LABEL: func @add3
|
# CHECK-LABEL: func @__torch__.add3
|
||||||
# Note that line-level debug information for parts unannotated in the Torch
|
# Note that line-level debug information for parts unannotated in the Torch
|
||||||
# graph are ascribed to the first op that carries source information. Presently
|
# graph are ascribed to the first op that carries source information. Presently
|
||||||
# this includes naked constants, return and the function itself. This heuristic
|
# this includes naked constants, return and the function itself. This heuristic
|
||||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# CHECK-LABEL: @f
|
# CHECK-LABEL: @__torch__.f
|
||||||
@mb.import_function
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def f(b: bool, i: int):
|
def f(b: bool, i: int):
|
||||||
|
|
|
@ -17,8 +17,8 @@ try:
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def import_class(x: typing.Any):
|
def import_class(x: typing.Any):
|
||||||
return x
|
return x
|
||||||
except RuntimeError as e:
|
except Exception as e:
|
||||||
# TODO: Once diagnostics are enabled, verify the actual error emitted.
|
# TODO: Once diagnostics are enabled, verify the actual error emitted.
|
||||||
assert str(e) == "unsupported type"
|
assert str(e) == "unsupported type in function schema: 'Any'"
|
||||||
else:
|
else:
|
||||||
assert False, "Expected exception"
|
assert False, "Expected exception"
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
# -*- Python -*-
|
||||||
|
# This file is licensed under a pytorch-style license
|
||||||
|
# See frontends/pytorch/LICENSE for license information.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @__torch__.optional_return(
|
||||||
|
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.optional<i64> {
|
||||||
|
# CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
|
||||||
|
# CHECK: return %[[RET]] : !torch.optional<i64>
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def optional_return(i: int) -> typing.Optional[int]:
|
||||||
|
return i
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @__torch__.optional_arg(
|
||||||
|
# CHECK-SAME: %[[ARG:.*]]: !torch.optional<i64>) -> !basicpy.NoneType {
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def optional_arg(i: typing.Optional[int]) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @__torch__.calls_optional_arg(
|
||||||
|
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.NoneType {
|
||||||
|
# CHECK: %[[CALLEE:.*]] = constant @__torch__.optional_arg : (!torch.optional<i64>) -> !basicpy.NoneType
|
||||||
|
# CHECK: %[[DEREFINED:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
|
||||||
|
# CHECK: %{{.*}} = call_indirect %[[CALLEE]](%[[DEREFINED]]) : (!torch.optional<i64>) -> !basicpy.NoneType
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def calls_optional_arg(i: int):
|
||||||
|
optional_arg(i)
|
||||||
|
|
||||||
|
|
||||||
|
mb.module.operation.print()
|
||||||
|
print()
|
|
@ -9,12 +9,12 @@ import torch_mlir
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# CHECK-LABEL: @f(
|
# CHECK-LABEL: @__torch__.prim_If(
|
||||||
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
|
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
|
||||||
# CHECK-SAME: %[[I:.*]]: i64) -> i64 {
|
# CHECK-SAME: %[[I:.*]]: i64) -> i64 {
|
||||||
@mb.import_function
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def f(b: bool, i: int):
|
def prim_If(b: bool, i: int):
|
||||||
# CHECK: %[[I1:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1
|
# CHECK: %[[I1:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1
|
||||||
# CHECK: %[[RES:.*]] = scf.if %[[I1]] -> (i64) {
|
# CHECK: %[[RES:.*]] = scf.if %[[I1]] -> (i64) {
|
||||||
# CHECK: %[[ADD:.*]] = torch.kernel_call "aten::add" %[[I]], %[[I]]
|
# CHECK: %[[ADD:.*]] = torch.kernel_call "aten::add" %[[I]], %[[I]]
|
||||||
|
@ -30,6 +30,25 @@ def f(b: bool, i: int):
|
||||||
return i * i
|
return i * i
|
||||||
# elif is modeled as a nested if, so no need to specially test it here.
|
# elif is modeled as a nested if, so no need to specially test it here.
|
||||||
|
|
||||||
assert isinstance(f, torch.jit.ScriptFunction)
|
# CHECK-LABEL: func @__torch__.prim_If_derefine(
|
||||||
|
# CHECK-SAME: %[[B:.*]]: !basicpy.BoolType,
|
||||||
|
# CHECK-SAME: %[[I:.*]]: i64) -> !torch.optional<i64> {
|
||||||
|
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||||
|
# CHECK: %[[PRED:.*]] = basicpy.bool_cast %[[B]] : !basicpy.BoolType -> i1
|
||||||
|
# CHECK: %[[RES:.*]] = scf.if %[[PRED]] -> (!torch.optional<i64>) {
|
||||||
|
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
|
||||||
|
# CHECK: scf.yield %[[NONE_DEREFINED]] : !torch.optional<i64>
|
||||||
|
# CHECK: } else {
|
||||||
|
# CHECK: %[[I_DEREFINED:.*]] = torch.derefine %[[I]] : i64 -> !torch.optional<i64>
|
||||||
|
# CHECK: scf.yield %[[I_DEREFINED]] : !torch.optional<i64>
|
||||||
|
# CHECK: }
|
||||||
|
# CHECK: return %[[RES:.*]] : !torch.optional<i64>
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def prim_If_derefine(b: bool, i: int):
|
||||||
|
if b:
|
||||||
|
return None
|
||||||
|
return i
|
||||||
|
|
||||||
mb.module.operation.print()
|
mb.module.operation.print()
|
||||||
print()
|
print()
|
||||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# CHECK-LABEL: func @f(
|
# CHECK-LABEL: func @__torch__.f(
|
||||||
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||||
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType {
|
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType {
|
||||||
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType
|
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.ListType
|
||||||
|
|
|
@ -5,11 +5,13 @@
|
||||||
import torch
|
import torch
|
||||||
import torch_mlir
|
import torch_mlir
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# CHECK-LABEL: func @prim_Loop_forlike(
|
# CHECK-LABEL: func @__torch__.prim_Loop_forlike(
|
||||||
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: i64) -> f64 {
|
# CHECK-SAME: %[[MAX_ITERATIONS:.*]]: i64) -> f64 {
|
||||||
# CHECK: %[[BOOL_TRUE:.*]] = basicpy.bool_constant true
|
# CHECK: %[[BOOL_TRUE:.*]] = basicpy.bool_constant true
|
||||||
# CHECK: %[[F_INIT:.*]] = constant 0.000000e+00 : f64
|
# CHECK: %[[F_INIT:.*]] = constant 0.000000e+00 : f64
|
||||||
|
@ -27,18 +29,18 @@ def prim_Loop_forlike(n: int):
|
||||||
f += i
|
f += i
|
||||||
return f
|
return f
|
||||||
|
|
||||||
# CHECK-LABEL: func @prim_Loop_whilelike(
|
# CHECK-LABEL: func @__torch__.prim_Loop_whilelike(
|
||||||
# CHECK-SAME: %[[VAL_0:.*]]: i64) -> f64 {
|
# CHECK-SAME: %[[VAL_0:.*]]: i64) -> f64 {
|
||||||
# CHECK: %[[F_INIT:.*]] = constant 3.200000e+00 : f64
|
# CHECK: %[[F_INIT:.*]] = constant 3.200000e+00 : f64
|
||||||
# CHECK: %[[MAX_ITERATIONS:.*]] = constant 9223372036854775807 : i64
|
# CHECK: %[[MAX_ITERATIONS:.*]] = constant 9223372036854775807 : i64
|
||||||
# CHECK: %[[COND_INIT:.*]] = torch.kernel_call "aten::lt" %[[F_INIT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
# CHECK: %[[COND_INIT:.*]] = torch.kernel_call "aten::lt" %[[F_INIT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
||||||
# CHECK: %[[IV:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
|
# CHECK: %[[RET:.*]] = torch.prim.Loop %[[MAX_ITERATIONS]], %[[COND_INIT]], init(%[[F_INIT]]) {
|
||||||
# CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
|
# CHECK: ^bb0(%[[F_ITER:.*]]: i64, %[[F_ITER:.*]]: f64):
|
||||||
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
|
# CHECK: %[[F_NEXT:.*]] = torch.kernel_call "aten::mul" %[[F_ITER]], %[[F_ITER]] : (f64, f64) -> f64 {sigArgTypes = ["float", "float"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["float"]}
|
||||||
# CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
# CHECK: %[[COND_ITER:.*]] = torch.kernel_call "aten::lt" %[[F_NEXT]], %[[VAL_0]] : (f64, i64) -> !basicpy.BoolType {sigArgTypes = ["float", "int"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["bool"]}
|
||||||
# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
|
# CHECK: torch.prim.Loop.condition %[[COND_ITER]] iter(%[[F_NEXT]]) : !basicpy.BoolType, (f64)
|
||||||
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
|
# CHECK: } : (i64, !basicpy.BoolType, f64) -> f64
|
||||||
# CHECK: return %[[VAL_9:.*]] : f64
|
# CHECK: return %[[RET:.*]] : f64
|
||||||
@mb.import_function
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def prim_Loop_whilelike(n: int):
|
def prim_Loop_whilelike(n: int):
|
||||||
|
@ -47,5 +49,24 @@ def prim_Loop_whilelike(n: int):
|
||||||
f = f * f
|
f = f * f
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
# CHECK-LABEL: func @__torch__.prim_Loop_derefine(
|
||||||
|
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.optional<i64> {
|
||||||
|
# CHECK: %[[TRUE:.*]] = basicpy.bool_constant true
|
||||||
|
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||||
|
# CHECK: %[[NONE_DEREFINED:.*]] = torch.derefine %[[NONE]] : !basicpy.NoneType -> !torch.optional<i64>
|
||||||
|
# CHECK: %[[RET:.*]] = torch.prim.Loop %[[ARG]], %[[TRUE]], init(%[[NONE_DEREFINED]]) {
|
||||||
|
# CHECK: ^bb0(%[[IV:.*]]: i64, %[[X_ITER:.*]]: !torch.optional<i64>):
|
||||||
|
# CHECK: %[[X_NEXT:.*]] = torch.derefine %[[ARG]] : i64 -> !torch.optional<i64>
|
||||||
|
# CHECK: torch.prim.Loop.condition %[[TRUE]] iter(%[[X_NEXT]]) : !basicpy.BoolType, (!torch.optional<i64>)
|
||||||
|
# CHECK: } : (i64, !basicpy.BoolType, !torch.optional<i64>) -> !torch.optional<i64>
|
||||||
|
# CHECK: return %[[RET:.*]] : !torch.optional<i64>
|
||||||
|
@mb.import_function
|
||||||
|
@torch.jit.script
|
||||||
|
def prim_Loop_derefine(n: int):
|
||||||
|
x: typing.Optional[int] = None
|
||||||
|
for i in range(n):
|
||||||
|
x = n
|
||||||
|
return x
|
||||||
|
|
||||||
mb.module.operation.print()
|
mb.module.operation.print()
|
||||||
print()
|
print()
|
||||||
|
|
|
@ -14,7 +14,7 @@ import typing
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
|
|
||||||
# CHECK-LABEL: func @prim_NumToTensor(
|
# CHECK-LABEL: func @__torch__.prim_NumToTensor(
|
||||||
# CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
# CHECK-SAME: %[[ARG:.*]]: i64) -> !numpy.ndarray<*:!numpy.any_dtype> {
|
||||||
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
|
# CHECK: %[[RET:.*]] = torch.prim.NumToTensor %[[ARG]] : i64 -> !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
# CHECK: return %[[RET]] : !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
@ -25,7 +25,7 @@ mb = torch_mlir.ModuleBuilder()
|
||||||
def prim_NumToTensor(i: int):
|
def prim_NumToTensor(i: int):
|
||||||
return _to_tensor(i)
|
return _to_tensor(i)
|
||||||
|
|
||||||
# CHECK-LABEL: func @prim_Print(
|
# CHECK-LABEL: func @__torch__.prim_Print(
|
||||||
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
|
# CHECK-SAME: %[[ARG:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.NoneType {
|
||||||
# CHECK: %[[STR:.*]] = basicpy.bytes_constant "x"
|
# CHECK: %[[STR:.*]] = basicpy.bytes_constant "x"
|
||||||
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !basicpy.BytesType, !numpy.ndarray<*:!numpy.any_dtype>
|
# CHECK: torch.prim.Print(%[[STR]], %[[ARG]]) : !basicpy.BytesType, !numpy.ndarray<*:!numpy.any_dtype>
|
||||||
|
@ -34,7 +34,7 @@ def prim_NumToTensor(i: int):
|
||||||
def prim_Print(x):
|
def prim_Print(x):
|
||||||
print("x", x)
|
print("x", x)
|
||||||
|
|
||||||
# CHECK-LABEL: func @prim_RaiseException() -> !basicpy.NoneType {
|
# CHECK-LABEL: func @__torch__.prim_RaiseException() -> !basicpy.NoneType {
|
||||||
# CHECK: %[[ERRORSTR:.*]] = basicpy.bytes_constant "Error"
|
# CHECK: %[[ERRORSTR:.*]] = basicpy.bytes_constant "Error"
|
||||||
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !basicpy.NoneType
|
# CHECK: %[[NONE:.*]] = torch.prim.Uninitialized : !basicpy.NoneType
|
||||||
# CHECK: torch.prim.RaiseException %[[ERRORSTR]]
|
# CHECK: torch.prim.RaiseException %[[ERRORSTR]]
|
||||||
|
@ -44,7 +44,7 @@ def prim_Print(x):
|
||||||
def prim_RaiseException():
|
def prim_RaiseException():
|
||||||
raise Exception("Error")
|
raise Exception("Error")
|
||||||
|
|
||||||
# CHECK-LABEL: func @prim_unchecked_cast(
|
# CHECK-LABEL: func @__torch__.prim_unchecked_cast(
|
||||||
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
|
# CHECK-SAME: %[[VAL_0:.*]]: !torch.optional<i64>) -> i64 {
|
||||||
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
# CHECK: %[[NONE:.*]] = basicpy.singleton : !basicpy.NoneType
|
||||||
# CHECK: %[[C3:.*]] = constant 3 : i64
|
# CHECK: %[[C3:.*]] = constant 3 : i64
|
||||||
|
@ -64,7 +64,7 @@ def prim_unchecked_cast(i: typing.Optional[int]):
|
||||||
return 3
|
return 3
|
||||||
return i
|
return i
|
||||||
|
|
||||||
# CHECK-LABEL: func @prim_TupleUnpack(
|
# CHECK-LABEL: func @__torch__.prim_TupleUnpack(
|
||||||
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
|
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
|
||||||
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
|
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
|
||||||
# CHECK: return %[[RET]]#0 : i64
|
# CHECK: return %[[RET]]#0 : i64
|
||||||
|
@ -74,7 +74,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
|
||||||
val, _ = tup
|
val, _ = tup
|
||||||
return val
|
return val
|
||||||
|
|
||||||
# CHECK-LABEL: func @prim_TupleIndex(
|
# CHECK-LABEL: func @__torch__.prim_TupleIndex(
|
||||||
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
|
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
|
||||||
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !basicpy.TupleType, i64 -> i64
|
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !basicpy.TupleType, i64 -> i64
|
||||||
# CHECK: return %[[RET]] : i64
|
# CHECK: return %[[RET]] : i64
|
||||||
|
@ -83,7 +83,7 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
|
||||||
def prim_TupleIndex(tup: typing.Tuple[int, int]):
|
def prim_TupleIndex(tup: typing.Tuple[int, int]):
|
||||||
return tup[0]
|
return tup[0]
|
||||||
|
|
||||||
# CHECK-LABEL: func @prim_ListUnpack(
|
# CHECK-LABEL: func @__torch__.prim_ListUnpack(
|
||||||
# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
|
# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
|
||||||
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
|
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
|
||||||
# CHECK: return %[[RET]]#1 : i64
|
# CHECK: return %[[RET]]#1 : i64
|
||||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# CHECK-LABEL: func @f(
|
# CHECK-LABEL: func @__torch__.f(
|
||||||
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
# CHECK-SAME: %[[T0:.*]]: !numpy.ndarray<*:!numpy.any_dtype>,
|
||||||
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType {
|
# CHECK-SAME: %[[T1:.*]]: !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType {
|
||||||
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType
|
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!numpy.ndarray<*:!numpy.any_dtype>, !numpy.ndarray<*:!numpy.any_dtype>) -> !basicpy.TupleType
|
||||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# CHECK: @returns_bool
|
# CHECK: @__torch__.returns_bool
|
||||||
@mb.import_function
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def returns_bool():
|
def returns_bool():
|
||||||
|
|
|
@ -9,7 +9,7 @@ import torch_mlir
|
||||||
|
|
||||||
mb = torch_mlir.ModuleBuilder()
|
mb = torch_mlir.ModuleBuilder()
|
||||||
|
|
||||||
# CHECK: @returns_none
|
# CHECK: @__torch__.returns_none
|
||||||
@mb.import_function
|
@mb.import_function
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def returns_none():
|
def returns_none():
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/SymbolTable.h"
|
#include "mlir/IR/SymbolTable.h"
|
||||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
#include "npcomp/Dialect/Torch/IR/OpInterfaces.h"
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ include "npcomp/Dialect/Torch/IR/TorchTypes.td"
|
||||||
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
include "npcomp/Dialect/Torch/IR/OpInterfaces.td"
|
||||||
include "mlir/IR/SymbolInterfaces.td"
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||||
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
|
||||||
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
class Torch_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
: Op<Torch_Dialect, mnemonic, traits> {
|
: Op<Torch_Dialect, mnemonic, traits> {
|
||||||
|
@ -457,10 +458,59 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", []> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", []> {
|
def Torch_Primunchecked_castOp : Torch_Op<"prim.unchecked_cast", [
|
||||||
|
NoSideEffect
|
||||||
|
]> {
|
||||||
let summary = "TorchScript prim::unchecked_cast op";
|
let summary = "TorchScript prim::unchecked_cast op";
|
||||||
// TODO: This seems to mostly be used for casting "optional" to the contained
|
let description = [{
|
||||||
// type. Verify that and tighten the verifier.
|
Refine a type to one of its subtypes.
|
||||||
|
|
||||||
|
For example, refine a type that was only statically known to be
|
||||||
|
Optional[T] to a T when we obtain static information that guarantees it.
|
||||||
|
|
||||||
|
The key observation here is that Optional[T] does not have a corresponding
|
||||||
|
runtime type (i.e. `c10::IValue` subclass). It represents a set of possible
|
||||||
|
concrete types which for `Optional[T]` is either `None` or a concrete
|
||||||
|
subtype of `T` (which in the simplest case is just `T`). In particular,
|
||||||
|
at runtime there is no way to distinguish `Optional[int]` from
|
||||||
|
`Optional[Optional[int]]`, because both are either `None` or `int`.
|
||||||
|
This differs from C++ std::optional.
|
||||||
|
|
||||||
|
The best documentation of this op is inspection of the code in
|
||||||
|
`torch/csrc/jit/frontend/ir_emitter.cpp`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
// TODO: When we model PyTorch's notion of subtyping, verify the types here.
|
||||||
|
let arguments = (ins AnyTorchType:$operand);
|
||||||
|
let results = (outs AnyTorchType:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict `:` type($operand) `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Additional ops used to model TorchScript's Graph's / Node's.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def Torch_DerefineOp : Torch_Op<"derefine", [
|
||||||
|
NoSideEffect
|
||||||
|
]> {
|
||||||
|
let summary = "De-refine a type";
|
||||||
|
let description = [{
|
||||||
|
In terms of IR structure, TorchScript allows types to vary in many
|
||||||
|
circumstances where MLIR requires pointer-identical types. In particular,
|
||||||
|
it is valid to pass any subtype in place of a type. For example, if an
|
||||||
|
`Optional[int]` is required somewhere in the IR, it is legal to pass a
|
||||||
|
value of just `int` (but not the other way around; see
|
||||||
|
`torch.prim.unchecked_cast`). In effect, every *use* can have a different
|
||||||
|
type.
|
||||||
|
|
||||||
|
This op bridges that impedance mismatch. This op allows casting a value
|
||||||
|
from one type to a type that it is a subtype of to model this behavior.
|
||||||
|
}];
|
||||||
|
|
||||||
|
// TODO: When we model PyTorch's notion of subtyping, verify the types here.
|
||||||
let arguments = (ins AnyTorchType:$operand);
|
let arguments = (ins AnyTorchType:$operand);
|
||||||
let results = (outs AnyTorchType:$result);
|
let results = (outs AnyTorchType:$result);
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRSupport
|
MLIRSupport
|
||||||
MLIRControlFlowInterfaces
|
MLIRControlFlowInterfaces
|
||||||
|
MLIRSideEffectInterfaces
|
||||||
NPCOMPBasicpyDialect
|
NPCOMPBasicpyDialect
|
||||||
NPCOMPNumpyDialect
|
NPCOMPNumpyDialect
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue