Generate MLIR with shape information via LTC frontend (#742)

pull/1125/head
Antonio Kim 2022-05-26 12:53:15 -07:00 committed by Henry Tu
parent a605fe279c
commit 615ff1d31c
13 changed files with 286 additions and 123 deletions

View File

@ -158,7 +158,7 @@ class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
size_t i = 0;
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments);
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
return {schema.aten_name}_out;

View File

@ -42,7 +42,7 @@ def main(device):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(5, 5)
self.fc1 = torch.nn.Linear(5, 10)
def forward(self, x):
out = self.fc1(x)

View File

@ -36,9 +36,6 @@ TorchMlirComputation::TorchMlirComputation(
const std::shared_ptr<torch::jit::Graph>& graph)
: func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
graph_(graph), num_results_(graph_->outputs().size()) {
// TODO(henrytu): Save parameter shape information.
for (torch::jit::Value* input : graph_->inputs()) {
parameter_names_.push_back(input->debugName());
}
@ -144,22 +141,18 @@ void TorchMlirLoweringContext::AddParameter(
ComputationPtr TorchMlirLoweringContext::Build() {
PRINT_FUNCTION();
// Insert return values into graph.
for (torch::jit::Value* output : root_tuple_) {
graph_->block()->registerOutput(output);
}
// Create jit::Function from jit::Graph.
c10::QualifiedName name("graph");
auto cu = std::make_shared<torch::jit::CompilationUnit>();
// IMPORTANT: We pass in a COPY of the graph into create_function, since it
// may get mutated in the process.
auto jit_fn = cu->create_function(std::move(name), std::move(graph_->copy()));
// Generate MLIR.
MlirOperation func_op =
torch_mlir::importJitFunctionAsFuncOp(mlir_context_, jit_fn);
MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp(
/*context=*/mlir_context_,
/*function=*/generate_jit_fn().get(),
/*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
/*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
// TODO(henrytu): Inject tensor shapes into func_op
return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
}
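For reference, a minimal standalone sketch of the new import call (assuming a ready MlirContext and jit::Function; `import_for_ltc` is a hypothetical name):

// Hypothetical wrapper around the call in Build() above. Passing
// assumeTensorsHaveValueSemantics=true makes the importer emit
// value-semantic !torch.vtensor types, since LTC graphs have already
// been converted to value semantics.
MlirOperation import_for_ltc(MlirContext context, torch::jit::Function* fn) {
  return torch_mlir::importJitFunctionAsFuncOp(
      /*context=*/context,
      /*function=*/fn,
      /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
      /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
}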
@ -224,6 +217,14 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
TORCH_CHECK(
false, "Unhandled scalar type: ", c10::toString(scalar.type()));
}
} else {
// Save parameter shape information.
param->setType(torch::jit::TensorType::create(
/*scalar_type=*/data->shape().scalar_type(),
/*device=*/c10::nullopt,
/*sizes=*/c10::VaryingShape<int64_t>(data->shape().sizes()),
/*strides=*/c10::VaryingShape<int64_t>(),
/*requires_grad=*/c10::nullopt));
}
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
@ -245,6 +246,46 @@ size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
return root_tuple_.size() - 1;
}
// Sync a vector of c10::Argument with the types specified by a parallel list
// of jit::Values. There must be a 1:1 mapping between the elements of args
// and values.
std::vector<c10::Argument> sync_argument_types(
const std::vector<c10::Argument>& args,
c10::ArrayRef<torch::jit::Value*> values) {
TORCH_CHECK(
args.size() == values.size(),
"Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ",
args.size(), ":", values.size(), " instead!");
std::vector<c10::Argument> updated_args;
for (unsigned i = 0; i < args.size(); i++) {
updated_args.push_back(args[i].cloneWithType(values[i]->type()));
}
return updated_args;
}
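For intuition, a hedged usage sketch of sync_argument_types (the graph argument and its Float(2, 5) input type are assumptions):

// Illustrative: refine one unshaped schema argument from its graph input.
void sync_example(const std::shared_ptr<torch::jit::Graph>& graph) {
  // Suppose graph->inputs()[0] carries the shaped type Float(2, 5).
  std::vector<c10::Argument> unshaped{
      c10::Argument("x", c10::TensorType::get())};
  std::vector<c10::Argument> shaped =
      sync_argument_types(unshaped, graph->inputs().slice(0, 1));
  // shaped[0] keeps the name "x" but now carries the Float(2, 5) type.
}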
std::unique_ptr<torch::jit::Function>
TorchMlirLoweringContext::generate_jit_fn() const {
// IMPORTANT: We pass in a COPY of the graph into create_function, since it
// may get mutated in the process.
auto fn = std::make_unique<torch::jit::GraphFunction>(
c10::QualifiedName("graph"), graph_->copy(), nullptr);
c10::FunctionSchema schema = fn->getSchema();
// When constructing the default schema of a jit::GraphFunction, input and
// output shapes are stripped (via a call to unshapedType(...)); however,
// since we want shape information in our MLIR, we add it back here.
std::vector<c10::Argument> arguments =
sync_argument_types(schema.arguments(), graph_->inputs());
std::vector<c10::Argument> returns =
sync_argument_types(schema.returns(), graph_->outputs());
fn->setSchema(schema.cloneWithArguments(arguments).cloneWithReturns(returns));
return fn;
}
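The effect can be observed by printing the schema; a hedged sketch follows (the exact textual rendering of shaped types is illustrative):

#include <iostream>

// Sketch: before setSchema(...) the default schema prints with unshaped
// types, roughly "graph(Tensor _0) -> Tensor"; after the patch above it
// carries shapes, roughly "graph(Float(2, 5) _0) -> Float(2, 10)".
// getFunctionTypeFromSchema() consumes this schema, which is how the
// shapes reach the MLIR function signature.
void dump_schema(const torch::jit::Function& fn) {
  std::cout << fn.getSchema() << std::endl;
}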
void TorchMlirLoweringContext::RegisterMlirDialects() {
// https://reviews.llvm.org/D88162
mlirRegisterAllDialects(mlir_context_);

View File

@ -122,6 +122,10 @@ private:
size_t AddResult(torch::jit::Value* op);
// Creates a jit::Function from the current jit::Graph. Input and output
// type information is patched to include shapes.
std::unique_ptr<torch::jit::Function> generate_jit_fn() const;
void RegisterMlirDialects();
std::shared_ptr<torch::jit::Graph> graph_;

View File

@ -44,6 +44,76 @@
namespace torch {
namespace lazy {
TorchMlirOpVector LowerTorchMlirBuiltin(
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
const std::vector<c10::TypePtr> tensor_types,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments) {
auto builtin =
std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
CHECK(sv);
TorchMlirOpVector results;
if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
// Op returns multiple values.
const auto tuple_call_result = sv->asTuple({}, *function);
for (const auto& tuple_component : tuple_call_result) {
auto tuple_component_sv =
dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
results.push_back(tuple_component_sv->getValue());
}
} else {
// Op returns single value.
results.push_back(sv->getValue());
}
// Insert known tensor type information.
unsigned tensor_type_idx = 0;
for (jit::Value* value : results) {
if (value->type()->kind() == c10::TypeKind::TensorType) {
TORCH_CHECK(
tensor_type_idx < tensor_types.size(),
"Tensor corresponding to JIT SSA value %", value->debugName(),
" corresponds to result #", tensor_type_idx, ", but we only have ",
tensor_types.size(), " known types!");
value->setType(tensor_types[tensor_type_idx++]);
}
}
// Ensure that all of the provided tensor type information was consumed.
TORCH_CHECK(
tensor_type_idx == tensor_types.size(), tensor_type_idx,
" known types were injected into jit::Value, but ", tensor_types.size(),
" were provided from lazy::Node!");
return results;
}
TorchMlirOpVector LowerTorchMlirBuiltin(
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
const c10::ArrayRef<Shape> result_shapes,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments) {
std::vector<c10::TypePtr> tensor_types;
// Generate types with fixed tensor shape information.
for (const Shape& shape : result_shapes) {
tensor_types.push_back(torch::jit::TensorType::create(
/*scalar_type=*/shape.scalar_type(),
/*device=*/c10::nullopt,
/*sizes=*/c10::VaryingShape<int64_t>(shape.sizes()),
/*strides=*/c10::VaryingShape<int64_t>(),
/*requires_grad=*/c10::nullopt));
}
return LowerTorchMlirBuiltin(
function, sym, tensor_types, arguments, kwarguments);
}
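A hedged usage sketch of the shape-based overload (the op, shape, and helper name are illustrative):

// Sketch: lower a unary builtin whose single result shape is known.
torch::lazy::TorchMlirOpVector lower_relu_with_shape(
    torch::lazy::TorchMlirFunction function,
    const std::vector<torch::jit::NamedValue>& arguments) {
  // One Float result of size [2, 3]; the overload above converts it into
  // a fixed-size TensorType and injects it into the produced jit::Value.
  std::vector<torch::lazy::Shape> result_shapes{
      torch::lazy::Shape(c10::ScalarType::Float, {2, 3})};
  return torch::lazy::LowerTorchMlirBuiltin(
      function, at::aten::relu, result_shapes, arguments);
}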
class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
public:
TorchMlirNodeLowering(
@ -189,12 +259,20 @@ public:
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(
function_, node->op().op, arguments, kwarguments);
function_, node->op().op, node->shapes(), arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const std::vector<torch::jit::NamedValue>& arguments,
c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(function_, sym, arguments, kwarguments);
return LowerTorchMlirBuiltin(
function_, sym, result_shapes, arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const std::vector<c10::TypePtr> types,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(function_, sym, types, arguments, kwarguments);
}
TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
@ -222,7 +300,7 @@ public:
dest_arguments.emplace_back(node->stride());
dest_arguments.emplace_back(node->storage_offset());
TorchMlirOpVector as_strided_out =
LowerBuiltin(at::aten::as_strided, dest_arguments);
LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments);
CHECK_EQ(as_strided_out.size(), 1);
torch::jit::Value* as_strided = as_strided_out.front();
GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
@ -266,7 +344,7 @@ public:
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->dtype());
return LowerBuiltin(at::aten::to, arguments);
return LowerBuiltin(at::aten::to, node->shapes(), arguments);
}
TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
@ -383,13 +461,16 @@ public:
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->output_size());
return LowerBuiltin(at::aten::reshape, arguments);
return LowerBuiltin(at::aten::reshape, node->shapes(), arguments);
}
torch::jit::Value* GenerateClone(torch::jit::Value* val) {
std::vector<torch::jit::NamedValue> clone_arguments;
clone_arguments.emplace_back(val);
TorchMlirOpVector cloned = LowerBuiltin(at::aten::clone, clone_arguments);
// The type of the cloned value should be identical to that of the original.
TorchMlirOpVector cloned =
LowerBuiltin(at::aten::clone, {val->type()}, clone_arguments);
CHECK_EQ(cloned.size(), 1);
return cloned.front();
}
@ -398,7 +479,9 @@ public:
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(destination);
arguments.emplace_back(source);
LowerBuiltin(at::aten::copy_, arguments);
LowerBuiltin(
at::aten::copy_, c10::ArrayRef<Shape>({/*shape goes here*/}),
arguments);
}
torch::jit::Value* GenerateSlice(
@ -410,7 +493,9 @@ public:
arguments.emplace_back(start);
arguments.emplace_back(end);
arguments.emplace_back(step);
TorchMlirOpVector selected = LowerBuiltin(at::aten::slice, arguments);
TorchMlirOpVector selected = LowerBuiltin(
at::aten::slice, c10::ArrayRef<Shape>({/*shape goes here*/}),
arguments);
CHECK_EQ(selected.size(), 1);
return selected.front();
}
@ -424,29 +509,5 @@ TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
"TorchMlirNodeLowering",
static_cast<torch::lazy::TorchMlirLoweringContext*>(loctx));
}
TorchMlirOpVector LowerTorchMlirBuiltin(
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments) {
auto builtin =
std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
CHECK(sv);
if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
const auto tuple_call_result = sv->asTuple({}, *function);
TorchMlirOpVector tuple_result;
for (const auto& tuple_component : tuple_call_result) {
auto tuple_component_sv =
dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
tuple_result.push_back(tuple_component_sv->getValue());
}
return tuple_result;
}
return {sv->getValue()};
}
} // namespace lazy
} // namespace torch

View File

@ -23,6 +23,7 @@ typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin(
TorchMlirFunction function, c10::Symbol sym,
const c10::ArrayRef<Shape> result_shapes,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {});

View File

@ -22,12 +22,13 @@ using namespace torch_mlir;
MlirOperation torch_mlir::importJitFunctionAsFuncOp(
MlirContext context, torch::jit::Function *function,
std::function<MlirAttribute(int)> getArgAttribute) {
std::function<MlirAttribute(int)> getArgAttribute,
const ImportOptions &importOptions) {
// Useful for debugging:
// graph->dump();
MlirLocation loc = mlirLocationUnknownGet(context);
MlirType functionType =
getFunctionTypeFromSchema(context, function->getSchema());
getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
// Use the function's qualified name from the compilation unit.
// This is a stable linkage name that matches Python module lookup
// conventions (see compilation unit import in IValueImporter for more details
@ -69,7 +70,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
};
MlirBlock block = importBlock(
context, torch::jit::toGraphFunction(*function).graph()->block(),
createTerminator, inputTypes);
createTerminator, inputTypes, importOptions);
mlirRegionAppendOwnedBlock(bodyRegion, block);
return func;
}

View File

@ -12,6 +12,7 @@
#include <memory>
#include "import_options.h"
#include "node_importer.h"
#include "mlir-c/IR.h"
@ -41,7 +42,8 @@ namespace torch_mlir {
TORCH_API MlirOperation importJitFunctionAsFuncOp(
MlirContext context, torch::jit::Function *function,
std::function<MlirAttribute(int)> getArgAttribute =
[](int) -> MlirAttribute { return {nullptr}; });
[](int) -> MlirAttribute { return {nullptr}; },
const ImportOptions &importOptions = {});
} // namespace torch_mlir

View File

@ -0,0 +1,29 @@
//===- import_options.h -----------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
namespace torch_mlir {
// Common import options across importers. We define this as a struct to avoid
// an unstructured proliferation of different kinds of ways to control different
// parts of the import process.
struct ImportOptions {
// If this is set to true, then all tensors in the program can be assumed to
// have value semantics. This can happen, for example, when coming from
// LazyTensorCore since conversion to value semantics has already happened at
// a higher level there before we see the program. For
// calling-convention-impacting decisions, this flag should be interpreted as
// a requirement to use a value-semantic tensor type (!torch.vtensor) in
// signatures.
bool assumeTensorsHaveValueSemantics = false;
};
} // namespace torch_mlir
#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
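A brief usage sketch (the helper name is hypothetical) showing how a caller opts in:

#include "import_options.h"

// Sketch: by default the importer emits !torch.tensor types; the LTC
// frontend opts in to !torch.vtensor by setting this flag.
torch_mlir::ImportOptions makeLtcImportOptions() {
  torch_mlir::ImportOptions options;
  options.assumeTensorsHaveValueSemantics = true;
  return options;
}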

View File

@ -33,15 +33,18 @@ class NodeImporter {
public:
NodeImporter(MlirContext context) : context(context) {}
void importNode(Node *node, MlirBlock appendToBlock);
void importNode(Node *node, MlirBlock appendToBlock,
const ImportOptions &importOptions = {});
MlirBlock importBlock(
Block *jitBlock, CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
const ImportOptions &importOptions = {});
private:
MlirBlock
createBlockFor(Block *jitBlock,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes);
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
const ImportOptions &importOptions = {});
void mapValue(Value *jitValue, MlirValue value);
void mapResults(Node *node, MlirOperation operation);
MlirValue lookupMappedValue(Value *jitValue);
@ -76,27 +79,27 @@ rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
return rearranged;
}
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
const ImportOptions &importOptions) {
MlirLocation loc = getMlirLocationFromNode(context, node);
auto kind = node->kind();
auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
InputsTransformFn t) {
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, opName, loc,
getMlirTypesFromValues(loc, node->outputs()),
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, opName, loc,
getMlirTypesFromValues(loc, node->outputs(), importOptions),
t ? t(mappedInputs) : mappedInputs);
mapResults(node, operation);
};
auto createAndMapNodeWithAttribute = [&](Node *node,
const std::string &opName,
const std::string &attrName,
auto createAndMapNodeWithAttribute =
[&](Node *node, const std::string &opName, const std::string &attrName,
MlirAttribute attr) {
MlirOperation operation =
createMlirOperationAtEnd(appendToBlock, opName, loc,
getMlirTypesFromValues(loc, node->outputs()),
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, opName, loc,
getMlirTypesFromValues(loc, node->outputs(), importOptions),
lookupMappedValues(node->inputs()),
toMlirNamedAttribute(attrName.c_str(), attr));
mapResults(node, operation);
@ -105,9 +108,9 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
// Trivial ops with schema.
auto maybeSchema = node->maybeSchema();
if (maybeSchema) {
MlirOperation operation =
createOperationFromSchema(appendToBlock, loc, node->schema(),
getMlirTypesFromValues(loc, node->outputs()),
MlirOperation operation = createOperationFromSchema(
appendToBlock, loc, node->schema(),
getMlirTypesFromValues(loc, node->outputs(), importOptions),
lookupMappedValues(node->inputs()));
mapResults(node, operation);
return;
@ -178,13 +181,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
} else if (output->type()->cast<c10::IntType>()) {
op = createMlirOperation(
"torch.constant.int", loc,
getMlirTypeFromTorchType(loc, output->type()),
getMlirTypeFromTorchType(loc, output->type(), importOptions),
toMlirNamedAttribute("value",
importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::FloatType>()) {
op = createMlirOperation(
"torch.constant.float", loc,
getMlirTypeFromTorchType(loc, output->type()),
getMlirTypeFromTorchType(loc, output->type(), importOptions),
toMlirNamedAttribute("value",
importAttribute(loc, node, c10::attr::value)));
} else if (output->type()->cast<c10::StringType>()) {
@ -202,7 +205,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
} else if (output->type()->cast<c10::DeviceObjType>()) {
op = createMlirOperation(
"torch.constant.device", loc,
getMlirTypeFromTorchType(loc, output->type()),
getMlirTypeFromTorchType(loc, output->type(), importOptions),
toMlirNamedAttribute(
"value", mlirStringAttrGet(context, toMlirStringRef(node->s(
c10::attr::value)))));
@ -211,16 +214,15 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
const std::string &symName = function->qualname().qualifiedName();
op = createMlirOperation(
"func.constant", loc,
getFunctionTypeFromSchema(context, function->getSchema()),
getFunctionTypeFromSchema(context, function->getSchema(),
importOptions),
toMlirNamedAttribute(
"value",
mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
} else if (output->type()->cast<c10::ListType>()) {
ClassAnnotator dummyAnnotator;
MlirValue listValue = importIValue(node->ival(c10::attr::value),
appendToBlock,
context,
dummyAnnotator);
MlirValue listValue = importIValue(
node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator);
mapResults(node, mlirOpResultGetOwner(listValue));
return; // Early return, since `importIValue` already added op to block.
} else {
@ -237,7 +239,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
if (kind == c10::prim::Loop) {
std::vector<MlirType> resultTypes =
getMlirTypesFromValues(loc, node->outputs());
getMlirTypesFromValues(loc, node->outputs(), importOptions);
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.Loop", loc, resultTypes,
lookupMappedValues(node->inputs().slice(0, 2)),
@ -260,13 +262,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
};
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], createTerminator));
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
return;
}
if (kind == c10::prim::If) {
std::vector<MlirType> resultTypes =
getMlirTypesFromValues(loc, node->outputs());
getMlirTypesFromValues(loc, node->outputs(), importOptions);
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()),
resultTypes, mlirRegionCreate(), mlirRegionCreate());
@ -281,10 +283,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
};
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 0),
importBlock(node->blocks()[0], createTerminator));
importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
mlirRegionAppendOwnedBlock(
mlirOperationGetRegion(operation, 1),
importBlock(node->blocks()[1], createTerminator));
importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions));
return;
}
@ -293,14 +295,14 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
auto methodName = node->s(c10::attr::name);
torch::jit::Function *function = classType->findMethod(methodName);
MlirType calleeType =
getFunctionTypeFromSchema(context, function->getSchema());
getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
std::vector<MlirType> expectedTypes;
for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
}
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "torch.prim.CallMethod", loc,
getMlirTypesFromValues(loc, node->outputs()),
getMlirTypesFromValues(loc, node->outputs(), importOptions),
adjustStaticInformationForValues(
appendToBlock, loc, lookupMappedValues(node->inputs()),
expectedTypes, /*userAllowsRefinement=*/false),
@ -315,11 +317,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
torch::jit::Block *calleeEntryBlock =
torch::jit::toGraphFunction(*functionType->function()).graph()->block();
auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
return getMlirTypeFromTorchType(loc, v->type());
return getMlirTypeFromTorchType(loc, v->type(), importOptions);
});
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, "func.call_indirect", loc,
getMlirTypesFromValues(loc, node->outputs()),
getMlirTypesFromValues(loc, node->outputs(), importOptions),
lookupMappedValue(node->input(0)),
adjustStaticInformationForValues(
appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)),
@ -339,10 +341,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
MlirBlock NodeImporter::importBlock(
Block *jitBlock, CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
MlirBlock block = createBlockFor(jitBlock, blockArgTypes);
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
const ImportOptions &importOptions) {
MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
for (Node *node : jitBlock->nodes()) {
importNode(node, block);
importNode(node, block, importOptions);
}
Node *returnNode = jitBlock->return_node();
createTerminator(lookupMappedValues(returnNode->inputs()), block);
@ -350,11 +353,12 @@ MlirBlock NodeImporter::importBlock(
}
MlirBlock NodeImporter::createBlockFor(
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
const ImportOptions &importOptions) {
Node *paramNode = jitBlock->param_node();
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
std::vector<MlirType> paramNodeTypes =
getMlirTypesFromValues(loc, paramNode->outputs());
getMlirTypesFromValues(loc, paramNode->outputs(), importOptions);
if (!blockArgTypes)
blockArgTypes = paramNodeTypes;
else
@ -405,7 +409,8 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
MlirBlock
torch_mlir::importBlock(MlirContext context, Block *jitBlock,
CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
const ImportOptions &importOptions) {
NodeImporter importer(context);
return importer.importBlock(jitBlock, createTerminator, blockArgTypes);
return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions);
}

View File

@ -10,6 +10,8 @@
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_NODE_IMPORTER_H
#define TORCHMLIRJITIRIMPORTER_CSRC_NODE_IMPORTER_H
#include "import_options.h"
#include <memory>
#include "mlir-c/IR.h"
@ -37,7 +39,8 @@ using CreateTerminatorFn =
MlirBlock importBlock(
MlirContext context, torch::jit::Block *jitBlock,
CreateTerminatorFn createTerminator,
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
const ImportOptions &importOptions = {});
} // namespace torch_mlir

View File

@ -117,14 +117,20 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
throw mlir_diagnostic_emitted();
}
MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
const c10::TypePtr &torchType) {
MlirType
torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
const c10::TypePtr &torchType,
const ImportOptions &importOptions) {
MlirContext context = mlirLocationGetContext(loc);
using c10::TypeKind;
auto kind = torchType->kind();
switch (kind) {
case TypeKind::TensorType: {
auto tensorType = torchType->cast<c10::TensorType>();
auto getMlirTensorType = importOptions.assumeTensorsHaveValueSemantics
? torchMlirTorchValueTensorTypeGet
: torchMlirTorchNonValueTensorTypeGet;
// Element type.
MlirType elementType = {nullptr};
if (tensorType->scalarType()) {
@ -137,7 +143,7 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
auto &sizes = tensorType->symbolic_sizes();
if (!sizes.rank()) {
// Unranked.
return torchMlirTorchNonValueTensorTypeGet(context,
return getMlirTensorType(context,
/*numSizes=*/0,
/*optionalSizes=*/nullptr,
/*optionalDtype=*/
@ -158,7 +164,7 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
// case. So use a dummy data pointer.
int64_t dummy;
int64_t *dimsData = dims.size() == 0 ? &dummy : dims.data();
return torchMlirTorchNonValueTensorTypeGet(context, dims.size(),
return getMlirTensorType(context, dims.size(),
/*optionalSizes=*/dimsData,
/*optionalDtype=*/
elementType);
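For intuition, a hedged helper mirroring the dispatch above (the helper and the fixed sizes are illustrative; the printed type forms follow torch-mlir's conventions):

// Hypothetical helper: choose the ranked tensor constructor by semantics.
// With an f32 element type and sizes {2, 3}, this yields
// !torch.tensor<[2,3],f32> or !torch.vtensor<[2,3],f32> respectively.
static MlirType makeRankedTensorType(MlirContext context, MlirType elementType,
                                     bool assumeValueSemantics) {
  auto getMlirTensorType = assumeValueSemantics
                               ? torchMlirTorchValueTensorTypeGet
                               : torchMlirTorchNonValueTensorTypeGet;
  int64_t sizes[] = {2, 3};
  return getMlirTensorType(context, /*numSizes=*/2, /*optionalSizes=*/sizes,
                           /*optionalDtype=*/elementType);
}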
@ -180,13 +186,15 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
}
case TypeKind::OptionalType: {
return torchMlirTorchOptionalTypeGet(getMlirTypeFromTorchType(
loc, torchType->cast<c10::OptionalType>()->getElementType()));
loc, torchType->cast<c10::OptionalType>()->getElementType(),
importOptions));
}
case TypeKind::TupleType: {
std::vector<MlirType> containedTypes;
for (const c10::TypePtr &type :
torchType->cast<c10::TupleType>()->containedTypes()) {
containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
containedTypes.push_back(
getMlirTypeFromTorchType(loc, type, importOptions));
}
return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
containedTypes.data());
@ -202,13 +210,14 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
}
case TypeKind::ListType: {
return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType()));
loc, torchType->cast<c10::ListType>()->getElementType(),
importOptions));
}
case TypeKind::DictType: {
auto dictType = torchType->cast<c10::DictType>();
return torchMlirTorchDictTypeGet(
getMlirTypeFromTorchType(loc, dictType->getKeyType()),
getMlirTypeFromTorchType(loc, dictType->getValueType()));
getMlirTypeFromTorchType(loc, dictType->getKeyType(), importOptions),
getMlirTypeFromTorchType(loc, dictType->getValueType(), importOptions));
}
case TypeKind::NoneType: {
return torchMlirTorchNoneTypeGet(context);
@ -243,10 +252,11 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
MlirType
torch_mlir::getFunctionTypeFromSchema(MlirContext context,
const c10::FunctionSchema &schema) {
const c10::FunctionSchema &schema,
const ImportOptions &importOptions) {
MlirLocation loc = mlirLocationUnknownGet(context);
auto mapType = [&](const c10::TypePtr &torchType) {
MlirType type = getMlirTypeFromTorchType(loc, torchType);
MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions);
if (mlirTypeIsNull(type)) {
std::stringstream msg;
msg << "unsupported type in function schema: '"
@ -383,10 +393,11 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
std::vector<MlirType>
torch_mlir::getMlirTypesFromValues(MlirLocation loc,
c10::ArrayRef<torch::jit::Value *> values) {
c10::ArrayRef<torch::jit::Value *> values,
const ImportOptions &importOptions) {
std::vector<MlirType> ret;
for (auto value : values) {
MlirType t = getMlirTypeFromTorchType(loc, value->type());
MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions);
if (mlirTypeIsNull(t))
throw mlir_diagnostic_emitted("unsupported type");
ret.push_back(t);

View File

@ -10,6 +10,8 @@
#ifndef TORCHMLIRJITIRIMPORTER_CSRC_TORCH_TO_MLIR_UTILS_H
#define TORCHMLIRJITIRIMPORTER_CSRC_TORCH_TO_MLIR_UTILS_H
#include "import_options.h"
#include <memory>
#include "mlir-c/IR.h"
@ -42,14 +44,16 @@ MlirType getMlirTypeForTorchScalarType(MlirLocation loc,
/// Maps a torch type to a corresponding MlirType. Returns a null type
/// on failure and emits a diagnostic.
MlirType getMlirTypeFromTorchType(MlirLocation loc,
const c10::TypePtr &torchType);
const c10::TypePtr &torchType,
const ImportOptions &importOptions = {});
/// Creates a FunctionType suitable for expressing the signature of `schema`.
///
/// This can differ from the type inferred from the block of a
/// torch::jit::Function due to derefinement and refinement of tensor types.
MlirType getFunctionTypeFromSchema(MlirContext context,
const c10::FunctionSchema &schema);
const c10::FunctionSchema &schema,
const ImportOptions &importOptions = {});
/// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor,
@ -63,7 +67,8 @@ MlirLocation getMlirLocationFromNode(MlirContext context,
std::vector<MlirType>
getMlirTypesFromValues(MlirLocation loc,
c10::ArrayRef<torch::jit::Value *> values);
c10::ArrayRef<torch::jit::Value *> values,
const ImportOptions &importOptions = {});
std::vector<MlirValue> adjustStaticInformationForValues(
MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,