Add `!torch.tuple<T1, T2>` type.

This further eliminates the need for the `basicpy` dependency.

This required adding `torch.prim.TupleConstruct` to replace
`basicpy.build_tuple`.
pull/223/head
Sean Silva 2021-06-14 18:06:38 -07:00
parent ea1dd1cd90
commit 92ee0fa98f
13 changed files with 135 additions and 34 deletions

@ -1 +1 @@
Subproject commit cbd0054b9eb17ec48f0702e3828209646c8f5ebd
Subproject commit 853a614864754cd4b000f03a7ab8fbba103d6177

View File

@ -279,13 +279,16 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
if (ivalue.isTuple()) {
auto list = ivalue.toTuple()->elements();
std::vector<MlirValue> elems;
std::vector<MlirValue> operands;
std::vector<MlirType> types;
for (const c10::IValue &elem : list) {
elems.push_back(importIValue(elem));
MlirValue operand = importIValue(elem);
operands.push_back(operand);
types.push_back(mlirValueGetType(operand));
}
MlirOperation operation =
createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
npcompBasicpyTupleTypeGet(context), elems);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.TupleConstruct", loc,
npcompTorchTupleTypeGet(context, types.size(), types.data()), operands);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTensor()) {

View File

@ -84,9 +84,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
switch (kind) {
case c10::prim::ListUnpack:
case c10::prim::ListConstruct:
case c10::prim::TupleConstruct: {
createAndMapTrivialNode(node,
"torch.prim." + std::string(kind.toUnqualString()));
return;
}
case c10::prim::GetAttr:
case c10::prim::SetAttr: {
createAndMapNodeWithAttribute(
@ -96,14 +98,6 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
}
}
// Ops trivially lowered through `basicpy` dialect.
switch (kind) {
case c10::prim::TupleConstruct: {
createAndMapTrivialNode(node, "basicpy.build_tuple");
return;
}
}
if (kind == c10::prim::Constant) {
auto output = node->output();
MlirOperation op;

View File

@ -185,8 +185,13 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
loc, torchType->cast<c10::ListType>()->getElementType()));
}
case TypeKind::TupleType: {
// TODO: Don't lose the element type information.
return npcompBasicpyTupleTypeGet(context);
std::vector<MlirType> containedTypes;
for (const c10::TypePtr &type :
torchType->cast<c10::TupleType>()->containedTypes()) {
containedTypes.push_back(mapFromTorchType(loc, type));
}
return npcompTorchTupleTypeGet(context, containedTypes.size(),
containedTypes.data());
}
case TypeKind::StringType: {
return npcompBasicpyBytesTypeGet(context);

View File

@ -20,9 +20,9 @@ class TestModule(torch.nn.Module):
# CHECK: }
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
# CHECK: %[[TUPLE:.*]] = basicpy.build_tuple %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.TupleType
# CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[N1]], %[[N2]] : i64, i64
# CHECK: torch.nn_module {
# CHECK: torch.slot "t", %[[TUPLE]] : !basicpy.TupleType
# CHECK: torch.slot "t", %[[TUPLE]] : !torch.tuple<i64, i64>
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">

View File

@ -65,8 +65,8 @@ def prim_unchecked_cast(i: typing.Optional[int]):
return i
# CHECK-LABEL: func @__torch__.prim_TupleUnpack(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<i64, i64>) -> i64 {
# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !torch.tuple<i64, i64> -> i64, i64
# CHECK: return %[[RET]]#0 : i64
@mb.import_function
@torch.jit.script
@ -75,12 +75,12 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
return val
# CHECK-LABEL: func @__torch__.prim_TupleIndex(
# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !basicpy.TupleType, i64 -> i64
# CHECK: return %[[RET]] : i64
# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<!torch.tensor, !torch.tensor>) -> !torch.tensor {
# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !torch.tuple<!torch.tensor, !torch.tensor>, i64 -> !torch.tensor
# CHECK: return %[[RET]] : !torch.tensor
@mb.import_function
@torch.jit.script
def prim_TupleIndex(tup: typing.Tuple[int, int]):
def prim_TupleIndex(tup: typing.Tuple[torch.Tensor, torch.Tensor]):
return tup[0]
# CHECK-LABEL: func @__torch__.prim_ListUnpack(
@ -121,28 +121,28 @@ def prim_device(x):
return x.device
# CHECK-LABEL: func @__torch__.prim_min(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.tuple<i64, i64, i64> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
# CHECK: %[[MIN2:.*]] = torch.prim.min.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MIN1]], %[[MIN2]], %[[MIN3]] : (i64, i64, i64) -> !basicpy.TupleType
# CHECK: return %[[RET]] : !basicpy.TupleType
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[MIN1]], %[[MIN2]], %[[MIN3]] : i64, i64, i64
# CHECK: return %[[RET]] : !torch.tuple<i64, i64, i64>
@mb.import_function
@torch.jit.script
def prim_min(x: int):
return min(x), min(x,x), min(x, x, x)
# CHECK-LABEL: func @__torch__.prim_max(
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.tuple<i64, i64, i64> {
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
# CHECK: %[[MAX2:.*]] = torch.prim.max.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MAX1]], %[[MAX2]], %[[MAX3]] : (i64, i64, i64) -> !basicpy.TupleType
# CHECK: return %[[RET]] : !basicpy.TupleType
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[MAX1]], %[[MAX2]], %[[MAX3]] : i64, i64, i64
# CHECK: return %[[RET]] : !torch.tuple<i64, i64, i64>
@mb.import_function
@torch.jit.script
def prim_max(x: int):

View File

@ -11,9 +11,9 @@ mb = torch_mlir.ModuleBuilder()
# CHECK-LABEL: func @__torch__.f(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !basicpy.TupleType {
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !basicpy.TupleType
# CHECK: return %[[RET]] : !basicpy.TupleType
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor> {
# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] : !torch.tensor, !torch.tensor
# CHECK: return %[[RET]] : !torch.tuple<!torch.tensor, !torch.tensor>
@mb.import_function
@torch.jit.script

View File

@ -37,6 +37,18 @@ bool npcompTypeIsATorchOptional(MlirType t);
/// Gets the !torch.optional<T> type with subtype T.
MlirType npcompTorchOptionalTypeGet(MlirType containedType);
//===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.tuple type
bool npcompTypeIsATorchTuple(MlirType t);
/// Gets the !torch.tuple type with contained types `containedTypes`.
MlirType npcompTorchTupleTypeGet(MlirContext context,
intptr_t numContainedTypes,
MlirType const *containedTypes);
//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//

View File

@ -322,6 +322,29 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
}];
}
def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
NoSideEffect,
TypesMatchWith<"contained types correspond to operand types",
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))">
]> {
let summary = "TorchScript prim::TupleConstruct op";
let description = [{
Note: This op does not allow trivial type refinement, because the
operand types and the result types must be in correspondence.
}];
let arguments = (ins
Variadic<AnyTorchType>:$elements
);
let results = (outs
Torch_TupleType:$result
);
let assemblyFormat = [{
$elements attr-dict `:` type($elements)
}];
}
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
NoSideEffect,
AllowsTypeRefinement,

View File

@ -232,6 +232,16 @@ def Torch_ListType : Torch_TypeWithContainedType<"List", "list"> {
}];
}
def Torch_TupleType : Torch_Type<"Tuple", "tuple"> {
let summary = "!torch.tuple<T1, T2, T3>";
let description = [{
Tuple type with 0-N ordered contained types.
}];
let parameters = (ins
ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
);
}
def Torch_DeviceType : Torch_Type<"Device", "Device"> {
let summary = "Torch device";
}
@ -329,7 +339,7 @@ def AnyTorchType : AnyTypeOf<[
AnyTorchBoolType,
AnyTorchScalarType,
AnyTorchTensorType,
Basicpy_TupleType,
Torch_TupleType,
Basicpy_BytesType,
Torch_NnModuleType,
Torch_NoneType,

View File

@ -41,6 +41,24 @@ MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
return wrap(Torch::OptionalType::get(unwrap(containedType)));
}
//===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchTuple(MlirType t) {
return unwrap(t).isa<Torch::TupleType>();
}
MlirType npcompTorchTupleTypeGet(MlirContext context,
intptr_t numContainedTypes,
MlirType const *containedTypes) {
return wrap(Torch::TupleType::get(
unwrap(context),
llvm::to_vector<6>(
llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); }))));
}
//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//

View File

@ -10,11 +10,40 @@
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//
Type Torch::TupleType::parse(MLIRContext *context, DialectAsmParser &parser) {
if (parser.parseLess())
return Type();
if (!parser.parseOptionalGreater())
return Torch::TupleType::get(context, {});
SmallVector<Type> containedTypes;
do {
Type containedType;
if (parser.parseType(containedType))
return Type();
containedTypes.push_back(containedType);
} while (!parser.parseOptionalComma());
if (parser.parseGreater())
return Type();
return Torch::TupleType::get(context, containedTypes);
}
void Torch::TupleType::print(::mlir::DialectAsmPrinter &printer) const {
printer << "tuple<";
llvm::interleaveComma(getContainedTypes(), printer);
printer << ">";
}
//===----------------------------------------------------------------------===//
// BaseTensorType
//===----------------------------------------------------------------------===//

View File

@ -41,6 +41,13 @@ func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
// CHECK: @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: @tuple.empty() -> !torch.tuple<>
func private @tuple.empty() -> !torch.tuple<>
// CHECK: @tuple.one_element() -> !torch.tuple<!torch.tensor>
func private @tuple.one_element() -> !torch.tuple<!torch.tensor>
// CHECK: @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
func private @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
// CHECK-LABEL: func @torch.tensor() {
func @torch.tensor() {
// CHECK: torch.tensor(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>