mirror of https://github.com/llvm/torch-mlir
Generate MLIR with shape information via LTC frontend (#742)
parent
a605fe279c
commit
615ff1d31c
|
@ -158,7 +158,7 @@ class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
|
|||
size_t i = 0;
|
||||
{emplace_arguments_str}
|
||||
{emplace_kwarguments}
|
||||
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments);
|
||||
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
|
||||
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
|
||||
|
||||
return {schema.aten_name}_out;
|
||||
|
|
|
@ -42,7 +42,7 @@ def main(device):
|
|||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = torch.nn.Linear(5, 5)
|
||||
self.fc1 = torch.nn.Linear(5, 10)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.fc1(x)
|
||||
|
|
|
@ -36,9 +36,6 @@ TorchMlirComputation::TorchMlirComputation(
|
|||
const std::shared_ptr<torch::jit::Graph>& graph)
|
||||
: func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
|
||||
graph_(graph), num_results_(graph_->outputs().size()) {
|
||||
|
||||
// TODO(henrytu): Save parameter shape information.
|
||||
|
||||
for (torch::jit::Value* input : graph_->inputs()) {
|
||||
parameter_names_.push_back(input->debugName());
|
||||
}
|
||||
|
@ -144,22 +141,18 @@ void TorchMlirLoweringContext::AddParameter(
|
|||
ComputationPtr TorchMlirLoweringContext::Build() {
|
||||
PRINT_FUNCTION();
|
||||
|
||||
// Insert return values into graph.
|
||||
for (torch::jit::Value* output : root_tuple_) {
|
||||
graph_->block()->registerOutput(output);
|
||||
}
|
||||
|
||||
// Create jit::Function from jit::Graph.
|
||||
c10::QualifiedName name("graph");
|
||||
auto cu = std::make_shared<torch::jit::CompilationUnit>();
|
||||
// IMPORTANT: We pass in a COPY of the graph into create_function, since it
|
||||
// may get mutated in the process.
|
||||
auto jit_fn = cu->create_function(std::move(name), std::move(graph_->copy()));
|
||||
|
||||
// Generate MLIR.
|
||||
MlirOperation func_op =
|
||||
torch_mlir::importJitFunctionAsFuncOp(mlir_context_, jit_fn);
|
||||
MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp(
|
||||
/*context=*/mlir_context_,
|
||||
/*function=*/generate_jit_fn().get(),
|
||||
/*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
|
||||
/*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
|
||||
|
||||
// TODO(henrytu): Inject tensor shapes into func_op
|
||||
return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
|
||||
}
|
||||
|
||||
|
@ -224,6 +217,14 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
|
|||
TORCH_CHECK(
|
||||
false, "Unhandled scalar type: ", c10::toString(scalar.type()));
|
||||
}
|
||||
} else {
|
||||
// Save parameter shape information.
|
||||
param->setType(torch::jit::TensorType::create(
|
||||
/*scalar_type=*/data->shape().scalar_type(),
|
||||
/*device=*/c10::nullopt,
|
||||
/*sizes=*/c10::VaryingShape<int64_t>(data->shape().sizes()),
|
||||
/*strides=*/c10::VaryingShape<int64_t>(),
|
||||
/*requires_grad=*/c10::nullopt));
|
||||
}
|
||||
|
||||
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
|
||||
|
@ -245,6 +246,46 @@ size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
|
|||
return root_tuple_.size() - 1;
|
||||
}
|
||||
|
||||
// Sync vector of c10::Argument with type specified from parallel list of
|
||||
// jit::Value. There must be a 1:1 map between elements of args and values.
|
||||
std::vector<c10::Argument> sync_argument_types(
|
||||
const std::vector<c10::Argument>& args,
|
||||
c10::ArrayRef<torch::jit::Value*> values) {
|
||||
TORCH_CHECK(
|
||||
args.size() == values.size(),
|
||||
"Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ",
|
||||
args.size(), ":", values.size(), " instead!");
|
||||
|
||||
std::vector<c10::Argument> updated_args;
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
updated_args.push_back(args[i].cloneWithType(values[i]->type()));
|
||||
}
|
||||
|
||||
return updated_args;
|
||||
}
|
||||
|
||||
std::unique_ptr<torch::jit::Function>
|
||||
TorchMlirLoweringContext::generate_jit_fn() const {
|
||||
// IMPORTANT: We pass in a COPY of the graph into create_function, since it
|
||||
// may get mutated in the process.
|
||||
auto fn = std::make_unique<torch::jit::GraphFunction>(
|
||||
c10::QualifiedName("graph"), graph_->copy(), nullptr);
|
||||
|
||||
c10::FunctionSchema schema = fn->getSchema();
|
||||
|
||||
// When constructing the default schema of a jit::GraphFunction, input and
|
||||
// output shapes are stripped (via call to unshapedType(...)); however,
|
||||
// since we want to have shape information in our MLIR, we'll add it back.
|
||||
std::vector<c10::Argument> arguments =
|
||||
sync_argument_types(schema.arguments(), graph_->inputs());
|
||||
std::vector<c10::Argument> returns =
|
||||
sync_argument_types(schema.returns(), graph_->outputs());
|
||||
|
||||
fn->setSchema(schema.cloneWithArguments(arguments).cloneWithReturns(returns));
|
||||
|
||||
return fn;
|
||||
}
|
||||
|
||||
void TorchMlirLoweringContext::RegisterMlirDialects() {
|
||||
// https://reviews.llvm.org/D88162
|
||||
mlirRegisterAllDialects(mlir_context_);
|
||||
|
|
|
@ -122,6 +122,10 @@ private:
|
|||
|
||||
size_t AddResult(torch::jit::Value* op);
|
||||
|
||||
// Creates a jit::Function from the current jit::Graph. Input and output
|
||||
// type information is patched to include shape.
|
||||
std::unique_ptr<torch::jit::Function> generate_jit_fn() const;
|
||||
|
||||
void RegisterMlirDialects();
|
||||
|
||||
std::shared_ptr<torch::jit::Graph> graph_;
|
||||
|
|
|
@ -44,6 +44,76 @@
|
|||
namespace torch {
|
||||
namespace lazy {
|
||||
|
||||
TorchMlirOpVector LowerTorchMlirBuiltin(
|
||||
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
|
||||
const std::vector<c10::TypePtr> tensor_types,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments) {
|
||||
auto builtin =
|
||||
std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
|
||||
auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
|
||||
auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
|
||||
auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
|
||||
CHECK(sv);
|
||||
|
||||
TorchMlirOpVector results;
|
||||
if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
|
||||
// Op returns multiple values.
|
||||
const auto tuple_call_result = sv->asTuple({}, *function);
|
||||
for (const auto& tuple_component : tuple_call_result) {
|
||||
auto tuple_component_sv =
|
||||
dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
|
||||
results.push_back(tuple_component_sv->getValue());
|
||||
}
|
||||
} else {
|
||||
// Op returns single value.
|
||||
results.push_back(sv->getValue());
|
||||
}
|
||||
|
||||
// Insert known tensor type information.
|
||||
unsigned tensor_type_idx = 0;
|
||||
for (jit::Value* value : results) {
|
||||
if (value->type()->kind() == c10::TypeKind::TensorType) {
|
||||
TORCH_CHECK(
|
||||
tensor_type_idx < tensor_types.size(),
|
||||
"Tensor corresponding to JIT SSA value %", value->debugName(),
|
||||
" corresponds to result #", tensor_type_idx, ", but we only have ",
|
||||
tensor_types.size(), " known types!");
|
||||
|
||||
value->setType(tensor_types[tensor_type_idx++]);
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that we use up all the known tensor type information available.
|
||||
TORCH_CHECK(
|
||||
tensor_type_idx == tensor_types.size(), tensor_type_idx,
|
||||
" known types were injected into jit::Value, but ", tensor_types.size(),
|
||||
" were provided from lazy::Node!");
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerTorchMlirBuiltin(
|
||||
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
|
||||
const c10::ArrayRef<Shape> result_shapes,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments) {
|
||||
std::vector<c10::TypePtr> tensor_types;
|
||||
|
||||
// Generate types with fixed tensor shape information.
|
||||
for (const Shape& shape : result_shapes) {
|
||||
tensor_types.push_back(torch::jit::TensorType::create(
|
||||
/*scalar_type=*/shape.scalar_type(),
|
||||
/*device=*/c10::nullopt,
|
||||
/*sizes=*/c10::VaryingShape<int64_t>(shape.sizes()),
|
||||
/*strides=*/c10::VaryingShape<int64_t>(),
|
||||
/*requires_grad=*/c10::nullopt));
|
||||
}
|
||||
|
||||
return LowerTorchMlirBuiltin(
|
||||
function, sym, tensor_types, arguments, kwarguments);
|
||||
}
|
||||
|
||||
class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
|
||||
public:
|
||||
TorchMlirNodeLowering(
|
||||
|
@ -189,12 +259,20 @@ public:
|
|||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
|
||||
return LowerTorchMlirBuiltin(
|
||||
function_, node->op().op, arguments, kwarguments);
|
||||
function_, node->op().op, node->shapes(), arguments, kwarguments);
|
||||
}
|
||||
TorchMlirOpVector LowerBuiltin(
|
||||
c10::Symbol sym, const std::vector<torch::jit::NamedValue>& arguments,
|
||||
c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
|
||||
return LowerTorchMlirBuiltin(function_, sym, arguments, kwarguments);
|
||||
return LowerTorchMlirBuiltin(
|
||||
function_, sym, result_shapes, arguments, kwarguments);
|
||||
}
|
||||
TorchMlirOpVector LowerBuiltin(
|
||||
c10::Symbol sym, const std::vector<c10::TypePtr> types,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
|
||||
return LowerTorchMlirBuiltin(function_, sym, types, arguments, kwarguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
|
||||
|
@ -222,7 +300,7 @@ public:
|
|||
dest_arguments.emplace_back(node->stride());
|
||||
dest_arguments.emplace_back(node->storage_offset());
|
||||
TorchMlirOpVector as_strided_out =
|
||||
LowerBuiltin(at::aten::as_strided, dest_arguments);
|
||||
LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments);
|
||||
CHECK_EQ(as_strided_out.size(), 1);
|
||||
torch::jit::Value* as_strided = as_strided_out.front();
|
||||
GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
|
||||
|
@ -266,7 +344,7 @@ public:
|
|||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
arguments.emplace_back(node->dtype());
|
||||
return LowerBuiltin(at::aten::to, arguments);
|
||||
return LowerBuiltin(at::aten::to, node->shapes(), arguments);
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
|
||||
|
@ -383,13 +461,16 @@ public:
|
|||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
|
||||
arguments.push_back(node->output_size());
|
||||
return LowerBuiltin(at::aten::reshape, arguments);
|
||||
return LowerBuiltin(at::aten::reshape, node->shapes(), arguments);
|
||||
}
|
||||
|
||||
torch::jit::Value* GenerateClone(torch::jit::Value* val) {
|
||||
std::vector<torch::jit::NamedValue> clone_arguments;
|
||||
clone_arguments.emplace_back(val);
|
||||
TorchMlirOpVector cloned = LowerBuiltin(at::aten::clone, clone_arguments);
|
||||
|
||||
// Type of cloned value should be identical to the original one.
|
||||
TorchMlirOpVector cloned =
|
||||
LowerBuiltin(at::aten::clone, {val->type()}, clone_arguments);
|
||||
CHECK_EQ(cloned.size(), 1);
|
||||
return cloned.front();
|
||||
}
|
||||
|
@ -398,7 +479,9 @@ public:
|
|||
std::vector<torch::jit::NamedValue> arguments;
|
||||
arguments.emplace_back(destination);
|
||||
arguments.emplace_back(source);
|
||||
LowerBuiltin(at::aten::copy_, arguments);
|
||||
LowerBuiltin(
|
||||
at::aten::copy_, c10::ArrayRef<Shape>({/*shape goes here*/}),
|
||||
arguments);
|
||||
}
|
||||
|
||||
torch::jit::Value* GenerateSlice(
|
||||
|
@ -410,7 +493,9 @@ public:
|
|||
arguments.emplace_back(start);
|
||||
arguments.emplace_back(end);
|
||||
arguments.emplace_back(step);
|
||||
TorchMlirOpVector selected = LowerBuiltin(at::aten::slice, arguments);
|
||||
TorchMlirOpVector selected = LowerBuiltin(
|
||||
at::aten::slice, c10::ArrayRef<Shape>({/*shape goes here*/}),
|
||||
arguments);
|
||||
CHECK_EQ(selected.size(), 1);
|
||||
return selected.front();
|
||||
}
|
||||
|
@ -424,29 +509,5 @@ TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
|
|||
"TorchMlirNodeLowering",
|
||||
static_cast<torch::lazy::TorchMlirLoweringContext*>(loctx));
|
||||
}
|
||||
|
||||
TorchMlirOpVector LowerTorchMlirBuiltin(
|
||||
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments) {
|
||||
auto builtin =
|
||||
std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
|
||||
auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
|
||||
auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
|
||||
auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
|
||||
CHECK(sv);
|
||||
if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
|
||||
const auto tuple_call_result = sv->asTuple({}, *function);
|
||||
TorchMlirOpVector tuple_result;
|
||||
for (const auto& tuple_component : tuple_call_result) {
|
||||
auto tuple_component_sv =
|
||||
dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
|
||||
tuple_result.push_back(tuple_component_sv->getValue());
|
||||
}
|
||||
return tuple_result;
|
||||
}
|
||||
return {sv->getValue()};
|
||||
}
|
||||
|
||||
} // namespace lazy
|
||||
} // namespace torch
|
||||
|
|
|
@ -23,6 +23,7 @@ typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
|
|||
|
||||
TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin(
|
||||
TorchMlirFunction function, c10::Symbol sym,
|
||||
const c10::ArrayRef<Shape> result_shapes,
|
||||
const std::vector<torch::jit::NamedValue>& arguments,
|
||||
const std::vector<torch::jit::NamedValue>& kwarguments = {});
|
||||
|
||||
|
|
|
@ -22,12 +22,13 @@ using namespace torch_mlir;
|
|||
|
||||
MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
||||
MlirContext context, torch::jit::Function *function,
|
||||
std::function<MlirAttribute(int)> getArgAttribute) {
|
||||
std::function<MlirAttribute(int)> getArgAttribute,
|
||||
const ImportOptions &importOptions) {
|
||||
// Useful for debugging:
|
||||
// graph->dump();
|
||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
MlirType functionType =
|
||||
getFunctionTypeFromSchema(context, function->getSchema());
|
||||
getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
|
||||
// Use the function's qualified name from the compilation unit.
|
||||
// This is a stable linkage name that matches Python module lookup
|
||||
// conventions (see compilation unit import in IValueImporter for more details
|
||||
|
@ -69,7 +70,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
|
|||
};
|
||||
MlirBlock block = importBlock(
|
||||
context, torch::jit::toGraphFunction(*function).graph()->block(),
|
||||
createTerminator, inputTypes);
|
||||
createTerminator, inputTypes, importOptions);
|
||||
mlirRegionAppendOwnedBlock(bodyRegion, block);
|
||||
return func;
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "import_options.h"
|
||||
#include "node_importer.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
@ -41,7 +42,8 @@ namespace torch_mlir {
|
|||
TORCH_API MlirOperation importJitFunctionAsFuncOp(
|
||||
MlirContext context, torch::jit::Function *function,
|
||||
std::function<MlirAttribute(int)> getArgAttribute =
|
||||
[](int) -> MlirAttribute { return {nullptr}; });
|
||||
[](int) -> MlirAttribute { return {nullptr}; },
|
||||
const ImportOptions &importOptions = {});
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
//===- import_options.h -----------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
|
||||
#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
|
||||
|
||||
namespace torch_mlir {
|
||||
// Common import options across importers. We define this as a struct to avoid
|
||||
// an unstructured proliferation of different kinds of ways to control different
|
||||
// parts of the import process.
|
||||
struct ImportOptions {
|
||||
// If this is set to true, then all tensors in the program can be assumed to
|
||||
// have value semantics. This can happen, for example, when coming from
|
||||
// LazyTensorCore since conversion to value semantics has already happened at
|
||||
// a higher level there before we see the program. For
|
||||
// calling-convention-impacting decisions, this flag should be interpreted as
|
||||
// a requirement to use a value-semantic tensor type (!torch.vtensor) in
|
||||
// signatures.
|
||||
bool assumeTensorsHaveValueSemantics = false;
|
||||
};
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
|
|
@ -33,15 +33,18 @@ class NodeImporter {
|
|||
public:
|
||||
NodeImporter(MlirContext context) : context(context) {}
|
||||
|
||||
void importNode(Node *node, MlirBlock appendToBlock);
|
||||
void importNode(Node *node, MlirBlock appendToBlock,
|
||||
const ImportOptions &importOptions = {});
|
||||
MlirBlock importBlock(
|
||||
Block *jitBlock, CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
||||
const ImportOptions &importOptions = {});
|
||||
|
||||
private:
|
||||
MlirBlock
|
||||
createBlockFor(Block *jitBlock,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes);
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||
const ImportOptions &importOptions = {});
|
||||
void mapValue(Value *jitValue, MlirValue value);
|
||||
void mapResults(Node *node, MlirOperation operation);
|
||||
MlirValue lookupMappedValue(Value *jitValue);
|
||||
|
@ -76,27 +79,27 @@ rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
|
|||
return rearranged;
|
||||
}
|
||||
|
||||
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
||||
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
|
||||
const ImportOptions &importOptions) {
|
||||
MlirLocation loc = getMlirLocationFromNode(context, node);
|
||||
auto kind = node->kind();
|
||||
|
||||
auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
|
||||
InputsTransformFn t) {
|
||||
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(appendToBlock, opName, loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, opName, loc,
|
||||
getMlirTypesFromValues(loc, node->outputs(), importOptions),
|
||||
t ? t(mappedInputs) : mappedInputs);
|
||||
mapResults(node, operation);
|
||||
};
|
||||
|
||||
auto createAndMapNodeWithAttribute = [&](Node *node,
|
||||
const std::string &opName,
|
||||
const std::string &attrName,
|
||||
auto createAndMapNodeWithAttribute =
|
||||
[&](Node *node, const std::string &opName, const std::string &attrName,
|
||||
MlirAttribute attr) {
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(appendToBlock, opName, loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, opName, loc,
|
||||
getMlirTypesFromValues(loc, node->outputs(), importOptions),
|
||||
lookupMappedValues(node->inputs()),
|
||||
toMlirNamedAttribute(attrName.c_str(), attr));
|
||||
mapResults(node, operation);
|
||||
|
@ -105,9 +108,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
// Trivial ops with schema.
|
||||
auto maybeSchema = node->maybeSchema();
|
||||
if (maybeSchema) {
|
||||
MlirOperation operation =
|
||||
createOperationFromSchema(appendToBlock, loc, node->schema(),
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
MlirOperation operation = createOperationFromSchema(
|
||||
appendToBlock, loc, node->schema(),
|
||||
getMlirTypesFromValues(loc, node->outputs(), importOptions),
|
||||
lookupMappedValues(node->inputs()));
|
||||
mapResults(node, operation);
|
||||
return;
|
||||
|
@ -178,13 +181,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
} else if (output->type()->cast<c10::IntType>()) {
|
||||
op = createMlirOperation(
|
||||
"torch.constant.int", loc,
|
||||
getMlirTypeFromTorchType(loc, output->type()),
|
||||
getMlirTypeFromTorchType(loc, output->type(), importOptions),
|
||||
toMlirNamedAttribute("value",
|
||||
importAttribute(loc, node, c10::attr::value)));
|
||||
} else if (output->type()->cast<c10::FloatType>()) {
|
||||
op = createMlirOperation(
|
||||
"torch.constant.float", loc,
|
||||
getMlirTypeFromTorchType(loc, output->type()),
|
||||
getMlirTypeFromTorchType(loc, output->type(), importOptions),
|
||||
toMlirNamedAttribute("value",
|
||||
importAttribute(loc, node, c10::attr::value)));
|
||||
} else if (output->type()->cast<c10::StringType>()) {
|
||||
|
@ -202,7 +205,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
} else if (output->type()->cast<c10::DeviceObjType>()) {
|
||||
op = createMlirOperation(
|
||||
"torch.constant.device", loc,
|
||||
getMlirTypeFromTorchType(loc, output->type()),
|
||||
getMlirTypeFromTorchType(loc, output->type(), importOptions),
|
||||
toMlirNamedAttribute(
|
||||
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
|
||||
c10::attr::value)))));
|
||||
|
@ -211,16 +214,15 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
const std::string &symName = function->qualname().qualifiedName();
|
||||
op = createMlirOperation(
|
||||
"func.constant", loc,
|
||||
getFunctionTypeFromSchema(context, function->getSchema()),
|
||||
getFunctionTypeFromSchema(context, function->getSchema(),
|
||||
importOptions),
|
||||
toMlirNamedAttribute(
|
||||
"value",
|
||||
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
|
||||
} else if (output->type()->cast<c10::ListType>()) {
|
||||
ClassAnnotator dummyAnnotator;
|
||||
MlirValue listValue = importIValue(node->ival(c10::attr::value),
|
||||
appendToBlock,
|
||||
context,
|
||||
dummyAnnotator);
|
||||
MlirValue listValue = importIValue(
|
||||
node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator);
|
||||
mapResults(node, mlirOpResultGetOwner(listValue));
|
||||
return; // Early return, since `importIValue` already added op to block.
|
||||
} else {
|
||||
|
@ -237,7 +239,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
|
||||
if (kind == c10::prim::Loop) {
|
||||
std::vector<MlirType> resultTypes =
|
||||
getMlirTypesFromValues(loc, node->outputs());
|
||||
getMlirTypesFromValues(loc, node->outputs(), importOptions);
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.Loop", loc, resultTypes,
|
||||
lookupMappedValues(node->inputs().slice(0, 2)),
|
||||
|
@ -260,13 +262,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], createTerminator));
|
||||
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
||||
return;
|
||||
}
|
||||
|
||||
if (kind == c10::prim::If) {
|
||||
std::vector<MlirType> resultTypes =
|
||||
getMlirTypesFromValues(loc, node->outputs());
|
||||
getMlirTypesFromValues(loc, node->outputs(), importOptions);
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()),
|
||||
resultTypes, mlirRegionCreate(), mlirRegionCreate());
|
||||
|
@ -281,10 +283,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
};
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 0),
|
||||
importBlock(node->blocks()[0], createTerminator));
|
||||
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
|
||||
mlirRegionAppendOwnedBlock(
|
||||
mlirOperationGetRegion(operation, 1),
|
||||
importBlock(node->blocks()[1], createTerminator));
|
||||
importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -293,14 +295,14 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
auto methodName = node->s(c10::attr::name);
|
||||
torch::jit::Function *function = classType->findMethod(methodName);
|
||||
MlirType calleeType =
|
||||
getFunctionTypeFromSchema(context, function->getSchema());
|
||||
getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
|
||||
std::vector<MlirType> expectedTypes;
|
||||
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
|
||||
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
|
||||
}
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "torch.prim.CallMethod", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
getMlirTypesFromValues(loc, node->outputs(), importOptions),
|
||||
adjustStaticInformationForValues(
|
||||
appendToBlock, loc, lookupMappedValues(node->inputs()),
|
||||
expectedTypes, /*userAllowsRefinement=*/false),
|
||||
|
@ -315,11 +317,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
torch::jit::Block *calleeEntryBlock =
|
||||
torch::jit::toGraphFunction(*functionType->function()).graph()->block();
|
||||
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
|
||||
return getMlirTypeFromTorchType(loc, v->type());
|
||||
return getMlirTypeFromTorchType(loc, v->type(), importOptions);
|
||||
});
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
appendToBlock, "func.call_indirect", loc,
|
||||
getMlirTypesFromValues(loc, node->outputs()),
|
||||
getMlirTypesFromValues(loc, node->outputs(), importOptions),
|
||||
lookupMappedValue(node->input(0)),
|
||||
adjustStaticInformationForValues(
|
||||
appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)),
|
||||
|
@ -339,10 +341,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
|
||||
MlirBlock NodeImporter::importBlock(
|
||||
Block *jitBlock, CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
|
||||
MlirBlock block = createBlockFor(jitBlock, blockArgTypes);
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||
const ImportOptions &importOptions) {
|
||||
MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
|
||||
for (Node *node : jitBlock->nodes()) {
|
||||
importNode(node, block);
|
||||
importNode(node, block, importOptions);
|
||||
}
|
||||
Node *returnNode = jitBlock->return_node();
|
||||
createTerminator(lookupMappedValues(returnNode->inputs()), block);
|
||||
|
@ -350,11 +353,12 @@ MlirBlock NodeImporter::importBlock(
|
|||
}
|
||||
|
||||
MlirBlock NodeImporter::createBlockFor(
|
||||
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
|
||||
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||
const ImportOptions &importOptions) {
|
||||
Node *paramNode = jitBlock->param_node();
|
||||
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
|
||||
std::vector<MlirType> paramNodeTypes =
|
||||
getMlirTypesFromValues(loc, paramNode->outputs());
|
||||
getMlirTypesFromValues(loc, paramNode->outputs(), importOptions);
|
||||
if (!blockArgTypes)
|
||||
blockArgTypes = paramNodeTypes;
|
||||
else
|
||||
|
@ -405,7 +409,8 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
|
|||
MlirBlock
|
||||
torch_mlir::importBlock(MlirContext context, Block *jitBlock,
|
||||
CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
|
||||
const ImportOptions &importOptions) {
|
||||
NodeImporter importer(context);
|
||||
return importer.importBlock(jitBlock, createTerminator, blockArgTypes);
|
||||
return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions);
|
||||
}
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_NODE_IMPORTER_H
|
||||
#define TORCHMLIRJITIRIMPORTER_CSRC_NODE_IMPORTER_H
|
||||
|
||||
#include "import_options.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
@ -37,7 +39,8 @@ using CreateTerminatorFn =
|
|||
MlirBlock importBlock(
|
||||
MlirContext context, torch::jit::Block *jitBlock,
|
||||
CreateTerminatorFn createTerminator,
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
|
||||
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
|
||||
const ImportOptions &importOptions = {});
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
|
|
|
@ -117,14 +117,20 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
|
|||
throw mlir_diagnostic_emitted();
|
||||
}
|
||||
|
||||
MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
||||
const c10::TypePtr &torchType) {
|
||||
MlirType
|
||||
torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
||||
const c10::TypePtr &torchType,
|
||||
const ImportOptions &importOptions) {
|
||||
MlirContext context = mlirLocationGetContext(loc);
|
||||
using c10::TypeKind;
|
||||
auto kind = torchType->kind();
|
||||
switch (kind) {
|
||||
case TypeKind::TensorType: {
|
||||
auto tensorType = torchType->cast<c10::TensorType>();
|
||||
auto getMlirTensorType = importOptions.assumeTensorsHaveValueSemantics
|
||||
? torchMlirTorchValueTensorTypeGet
|
||||
: torchMlirTorchNonValueTensorTypeGet;
|
||||
|
||||
// Element type.
|
||||
MlirType elementType = {nullptr};
|
||||
if (tensorType->scalarType()) {
|
||||
|
@ -137,7 +143,7 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
|||
auto &sizes = tensorType->symbolic_sizes();
|
||||
if (!sizes.rank()) {
|
||||
// Unranked.
|
||||
return torchMlirTorchNonValueTensorTypeGet(context,
|
||||
return getMlirTensorType(context,
|
||||
/*numSizes=*/0,
|
||||
/*optionalSizes=*/nullptr,
|
||||
/*optionalDtype=*/
|
||||
|
@ -158,7 +164,7 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
|||
// case. So use a dummy data pointer.
|
||||
int64_t dummy;
|
||||
int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data();
|
||||
return torchMlirTorchNonValueTensorTypeGet(context, dims.size(),
|
||||
return getMlirTensorType(context, dims.size(),
|
||||
/*optionalSizes=*/dimsData,
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
|
@ -180,13 +186,15 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
|||
}
|
||||
case TypeKind::OptionalType: {
|
||||
return torchMlirTorchOptionalTypeGet(getMlirTypeFromTorchType(
|
||||
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
||||
loc, torchType->cast<c10::OptionalType>()->getElementType(),
|
||||
importOptions));
|
||||
}
|
||||
case TypeKind::TupleType: {
|
||||
std::vector<MlirType> containedTypes;
|
||||
for (const c10::TypePtr &type :
|
||||
torchType->cast<c10::TupleType>()->containedTypes()) {
|
||||
containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
|
||||
containedTypes.push_back(
|
||||
getMlirTypeFromTorchType(loc, type, importOptions));
|
||||
}
|
||||
return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
|
||||
containedTypes.data());
|
||||
|
@ -202,13 +210,14 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
|||
}
|
||||
case TypeKind::ListType: {
|
||||
return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
|
||||
loc, torchType->cast<c10::ListType>()->getElementType()));
|
||||
loc, torchType->cast<c10::ListType>()->getElementType(),
|
||||
importOptions));
|
||||
}
|
||||
case TypeKind::DictType: {
|
||||
auto dictType = torchType->cast<c10::DictType>();
|
||||
return torchMlirTorchDictTypeGet(
|
||||
getMlirTypeFromTorchType(loc, dictType->getKeyType()),
|
||||
getMlirTypeFromTorchType(loc, dictType->getValueType()));
|
||||
getMlirTypeFromTorchType(loc, dictType->getKeyType(), importOptions),
|
||||
getMlirTypeFromTorchType(loc, dictType->getValueType(), importOptions));
|
||||
}
|
||||
case TypeKind::NoneType: {
|
||||
return torchMlirTorchNoneTypeGet(context);
|
||||
|
@ -243,10 +252,11 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
|||
|
||||
MlirType
|
||||
torch_mlir::getFunctionTypeFromSchema(MlirContext context,
|
||||
const c10::FunctionSchema &schema) {
|
||||
const c10::FunctionSchema &schema,
|
||||
const ImportOptions &importOptions) {
|
||||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
auto mapType = [&](const c10::TypePtr &torchType) {
|
||||
MlirType type = getMlirTypeFromTorchType(loc, torchType);
|
||||
MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions);
|
||||
if (mlirTypeIsNull(type)) {
|
||||
std::stringstream msg;
|
||||
msg << "unsupported type in function schema: '"
|
||||
|
@ -383,10 +393,11 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
|
|||
|
||||
std::vector<MlirType>
|
||||
torch_mlir::getMlirTypesFromValues(MlirLocation loc,
|
||||
c10::ArrayRef<torch::jit::Value *> values) {
|
||||
c10::ArrayRef<torch::jit::Value *> values,
|
||||
const ImportOptions &importOptions) {
|
||||
std::vector<MlirType> ret;
|
||||
for (auto value : values) {
|
||||
MlirType t = getMlirTypeFromTorchType(loc, value->type());
|
||||
MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions);
|
||||
if (mlirTypeIsNull(t))
|
||||
throw mlir_diagnostic_emitted("unsupported type");
|
||||
ret.push_back(t);
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_TORCH_TO_MLIR_UTILS_H
|
||||
#define TORCHMLIRJITIRIMPORTER_CSRC_TORCH_TO_MLIR_UTILS_H
|
||||
|
||||
#include "import_options.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
@ -42,14 +44,16 @@ MlirType getMlirTypeForTorchScalarType(MlirLocation loc,
|
|||
/// Maps a torch type to a corresponding MlirType. Returns a null type
|
||||
/// on failure and emits a diagnostic.
|
||||
MlirType getMlirTypeFromTorchType(MlirLocation loc,
|
||||
const c10::TypePtr &torchType);
|
||||
const c10::TypePtr &torchType,
|
||||
const ImportOptions &importOptions = {});
|
||||
|
||||
/// Creates a FunctionType suitable for expressing the signature of `schema`.
|
||||
///
|
||||
/// This can differ from the type inferred from the block of a
|
||||
/// torch::jit::Function due to derefinement and refinement of tensor types.
|
||||
MlirType getFunctionTypeFromSchema(MlirContext context,
|
||||
const c10::FunctionSchema &schema);
|
||||
const c10::FunctionSchema &schema,
|
||||
const ImportOptions &importOptions = {});
|
||||
|
||||
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
|
||||
MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor,
|
||||
|
@ -63,7 +67,8 @@ MlirLocation getMlirLocationFromNode(MlirContext context,
|
|||
|
||||
std::vector<MlirType>
|
||||
getMlirTypesFromValues(MlirLocation loc,
|
||||
c10::ArrayRef<torch::jit::Value *> values);
|
||||
c10::ArrayRef<torch::jit::Value *> values,
|
||||
const ImportOptions &importOptions = {});
|
||||
|
||||
std::vector<MlirValue> adjustStaticInformationForValues(
|
||||
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,
|
||||
|
|
Loading…
Reference in New Issue