Add `!torch.tuple<T1, T2>` type.
This further eliminates the need for the `basicpy` dependency. This required adding `torch.prim.TupleConstruct` to replace `basicpy.build_tuple`.
parent ea1dd1cd90
commit 92ee0fa98f
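For reference, here is the new op and type as they appear in IR, adapted from the FileCheck tests updated in this commit (value names shortened for readability):

    func @f(%t0: !torch.tensor, %t1: !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor> {
      // TupleConstruct only spells out the operand types; the result is
      // the corresponding !torch.tuple of those types.
      %ret = torch.prim.TupleConstruct %t0, %t1 : !torch.tensor, !torch.tensor
      return %ret : !torch.tuple<!torch.tensor, !torch.tensor>
    }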
@@ -1 +1 @@
-Subproject commit cbd0054b9eb17ec48f0702e3828209646c8f5ebd
+Subproject commit 853a614864754cd4b000f03a7ab8fbba103d6177
@@ -279,13 +279,16 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
   }
   if (ivalue.isTuple()) {
     auto list = ivalue.toTuple()->elements();
-    std::vector<MlirValue> elems;
+    std::vector<MlirValue> operands;
+    std::vector<MlirType> types;
     for (const c10::IValue &elem : list) {
-      elems.push_back(importIValue(elem));
+      MlirValue operand = importIValue(elem);
+      operands.push_back(operand);
+      types.push_back(mlirValueGetType(operand));
     }
-    MlirOperation operation =
-        createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
-                                 npcompBasicpyTupleTypeGet(context), elems);
+    MlirOperation operation = createMlirOperationAtEnd(
+        importBlock, "torch.prim.TupleConstruct", loc,
+        npcompTorchTupleTypeGet(context, types.size(), types.data()), operands);
     return mlirOperationGetResult(operation, 0);
   }
   if (ivalue.isTensor()) {
@@ -84,9 +84,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
   switch (kind) {
   case c10::prim::ListUnpack:
-  case c10::prim::ListConstruct: {
+  case c10::prim::ListConstruct:
+  case c10::prim::TupleConstruct: {
     createAndMapTrivialNode(node,
                             "torch.prim." + std::string(kind.toUnqualString()));
     return;
   }
   case c10::prim::GetAttr:
   case c10::prim::SetAttr: {
     createAndMapNodeWithAttribute(
@@ -96,14 +98,6 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     }
   }

-  // Ops trivially lowered through `basicpy` dialect.
-  switch (kind) {
-  case c10::prim::TupleConstruct: {
-    createAndMapTrivialNode(node, "basicpy.build_tuple");
-    return;
-  }
-  }
-
   if (kind == c10::prim::Constant) {
     auto output = node->output();
     MlirOperation op;
@@ -185,8 +185,13 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
         loc, torchType->cast<c10::ListType>()->getElementType()));
   }
   case TypeKind::TupleType: {
-    // TODO: Don't lose the element type information.
-    return npcompBasicpyTupleTypeGet(context);
+    std::vector<MlirType> containedTypes;
+    for (const c10::TypePtr &type :
+         torchType->cast<c10::TupleType>()->containedTypes()) {
+      containedTypes.push_back(mapFromTorchType(loc, type));
+    }
+    return npcompTorchTupleTypeGet(context, containedTypes.size(),
+                                   containedTypes.data());
  }
  case TypeKind::StringType: {
    return npcompBasicpyBytesTypeGet(context);
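With this mapping in place, a TorchScript annotation such as `typing.Tuple[int, int]` now imports as `!torch.tuple<i64, i64>` rather than the type-erased `!basicpy.TupleType`. For example, a function signature from the updated tests below now reads (argument name illustrative):

    func @__torch__.prim_TupleUnpack(%arg0: !torch.tuple<i64, i64>) -> i64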
@@ -20,9 +20,9 @@ class TestModule(torch.nn.Module):
 # CHECK: }
 # CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
 # CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
-# CHECK: %[[TUPLE:.*]] = basicpy.build_tuple %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.TupleType
+# CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[N1]], %[[N2]] : i64, i64
 # CHECK: torch.nn_module {
-# CHECK:   torch.slot "t", %[[TUPLE]] : !basicpy.TupleType
+# CHECK:   torch.slot "t", %[[TUPLE]] : !torch.tuple<i64, i64>
 # CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
@@ -65,8 +65,8 @@ def prim_unchecked_cast(i: typing.Optional[int]):
     return i

 # CHECK-LABEL: func @__torch__.prim_TupleUnpack(
-# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
-# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !basicpy.TupleType -> i64, i64
+# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<i64, i64>) -> i64 {
+# CHECK: %[[RET:.*]]:2 = torch.prim.TupleUnpack %[[ARG]] : !torch.tuple<i64, i64> -> i64, i64
 # CHECK: return %[[RET]]#0 : i64
 @mb.import_function
 @torch.jit.script
@@ -75,12 +75,12 @@ def prim_TupleUnpack(tup: typing.Tuple[int, int]):
     return val

 # CHECK-LABEL: func @__torch__.prim_TupleIndex(
-# CHECK-SAME: %[[ARG:.*]]: !basicpy.TupleType) -> i64 {
-# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !basicpy.TupleType, i64 -> i64
-# CHECK: return %[[RET]] : i64
+# CHECK-SAME: %[[ARG:.*]]: !torch.tuple<!torch.tensor, !torch.tensor>) -> !torch.tensor {
+# CHECK: %[[RET:.*]] = torch.prim.TupleIndex %[[ARG]], %[[IDX:.*]] : !torch.tuple<!torch.tensor, !torch.tensor>, i64 -> !torch.tensor
+# CHECK: return %[[RET]] : !torch.tensor
 @mb.import_function
 @torch.jit.script
-def prim_TupleIndex(tup: typing.Tuple[int, int]):
+def prim_TupleIndex(tup: typing.Tuple[torch.Tensor, torch.Tensor]):
     return tup[0]

 # CHECK-LABEL: func @__torch__.prim_ListUnpack(
@@ -121,28 +121,28 @@ def prim_device(x):
     return x.device

 # CHECK-LABEL: func @__torch__.prim_min(
-# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
+# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.tuple<i64, i64, i64> {
 # CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
 # CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
 # CHECK: %[[MIN2:.*]] = torch.prim.min.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
 # CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
 # CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
-# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MIN1]], %[[MIN2]], %[[MIN3]] : (i64, i64, i64) -> !basicpy.TupleType
-# CHECK: return %[[RET]] : !basicpy.TupleType
+# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[MIN1]], %[[MIN2]], %[[MIN3]] : i64, i64, i64
+# CHECK: return %[[RET]] : !torch.tuple<i64, i64, i64>
 @mb.import_function
 @torch.jit.script
 def prim_min(x: int):
     return min(x), min(x,x), min(x, x, x)

 # CHECK-LABEL: func @__torch__.prim_max(
-# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
+# CHECK-SAME: %[[ARG:.*]]: i64) -> !torch.tuple<i64, i64, i64> {
 # CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
 # CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
 # CHECK: %[[MAX2:.*]] = torch.prim.max.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
 # CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
 # CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
-# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MAX1]], %[[MAX2]], %[[MAX3]] : (i64, i64, i64) -> !basicpy.TupleType
-# CHECK: return %[[RET]] : !basicpy.TupleType
+# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[MAX1]], %[[MAX2]], %[[MAX3]] : i64, i64, i64
+# CHECK: return %[[RET]] : !torch.tuple<i64, i64, i64>
 @mb.import_function
 @torch.jit.script
 def prim_max(x: int):
@@ -11,9 +11,9 @@ mb = torch_mlir.ModuleBuilder()

 # CHECK-LABEL: func @__torch__.f(
 # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
-# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !basicpy.TupleType {
-# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !basicpy.TupleType
-# CHECK: return %[[RET]] : !basicpy.TupleType
+# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.tuple<!torch.tensor, !torch.tensor> {
+# CHECK: %[[RET:.*]] = torch.prim.TupleConstruct %[[T0]], %[[T1]] : !torch.tensor, !torch.tensor
+# CHECK: return %[[RET]] : !torch.tuple<!torch.tensor, !torch.tensor>

 @mb.import_function
 @torch.jit.script
@@ -37,6 +37,18 @@ bool npcompTypeIsATorchOptional(MlirType t);
 /// Gets the !torch.optional<T> type with subtype T.
 MlirType npcompTorchOptionalTypeGet(MlirType containedType);

+//===----------------------------------------------------------------------===//
+// torch.tuple<T1, T2, T3> type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.tuple type
+bool npcompTypeIsATorchTuple(MlirType t);
+
+/// Gets the !torch.tuple type with contained types `containedTypes`.
+MlirType npcompTorchTupleTypeGet(MlirContext context,
+                                 intptr_t numContainedTypes,
+                                 MlirType const *containedTypes);
+
 //===----------------------------------------------------------------------===//
 // torch.list<T> type.
 //===----------------------------------------------------------------------===//
@@ -322,6 +322,29 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
   }];
 }

+def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
+    NoSideEffect,
+    TypesMatchWith<"contained types correspond to operand types",
+                   "elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))">
+  ]> {
+  let summary = "TorchScript prim::TupleConstruct op";
+  let description = [{
+    Note: This op does not allow trivial type refinement, because the
+    operand types and the result types must be in correspondence.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyTorchType>:$elements
+  );
+  let results = (outs
+    Torch_TupleType:$result
+  );
+
+  let assemblyFormat = [{
+    $elements attr-dict `:` type($elements)
+  }];
+}
+
 def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
     NoSideEffect,
     AllowsTypeRefinement,
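Because the `TypesMatchWith` constraint derives the result type from the operand types, the assembly format only prints the operand types; the `!torch.tuple<...>` result type is implied. A minimal sketch, with hypothetical i64 operands %a and %b:

    // Prints without a result type; the verifier enforces that the result
    // is exactly !torch.tuple<i64, i64>.
    %tup = torch.prim.TupleConstruct %a, %b : i64, i64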
@@ -232,6 +232,16 @@ def Torch_ListType : Torch_TypeWithContainedType<"List", "list"> {
   }];
 }

+def Torch_TupleType : Torch_Type<"Tuple", "tuple"> {
+  let summary = "!torch.tuple<T1, T2, T3>";
+  let description = [{
+    Tuple type with 0-N ordered contained types.
+  }];
+  let parameters = (ins
+    ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
+  );
+}
+
 def Torch_DeviceType : Torch_Type<"Device", "Device"> {
   let summary = "Torch device";
 }
@@ -329,7 +339,7 @@ def AnyTorchType : AnyTypeOf<[
   AnyTorchBoolType,
   AnyTorchScalarType,
   AnyTorchTensorType,
-  Basicpy_TupleType,
+  Torch_TupleType,
   Basicpy_BytesType,
   Torch_NnModuleType,
   Torch_NoneType,
@@ -41,6 +41,24 @@ MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
   return wrap(Torch::OptionalType::get(unwrap(containedType)));
 }

+//===----------------------------------------------------------------------===//
+// torch.tuple<T1, T2, T3> type.
+//===----------------------------------------------------------------------===//
+
+bool npcompTypeIsATorchTuple(MlirType t) {
+  return unwrap(t).isa<Torch::TupleType>();
+}
+
+MlirType npcompTorchTupleTypeGet(MlirContext context,
+                                 intptr_t numContainedTypes,
+                                 MlirType const *containedTypes) {
+  return wrap(Torch::TupleType::get(
+      unwrap(context),
+      llvm::to_vector<6>(
+          llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
+                          [](MlirType t) { return unwrap(t); }))));
+}
+
 //===----------------------------------------------------------------------===//
 // torch.list<T> type.
 //===----------------------------------------------------------------------===//
@@ -10,11 +10,40 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "npcomp/Dialect/Torch/IR/TorchDialect.h"
 #include "npcomp/Dialect/Torch/IR/TorchOps.h"
+#include "llvm/ADT/STLExtras.h"

 using namespace mlir;
 using namespace mlir::NPCOMP;
 using namespace mlir::NPCOMP::Torch;

+//===----------------------------------------------------------------------===//
+// TupleType
+//===----------------------------------------------------------------------===//
+
+Type Torch::TupleType::parse(MLIRContext *context, DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+  if (!parser.parseOptionalGreater())
+    return Torch::TupleType::get(context, {});
+
+  SmallVector<Type> containedTypes;
+  do {
+    Type containedType;
+    if (parser.parseType(containedType))
+      return Type();
+    containedTypes.push_back(containedType);
+  } while (!parser.parseOptionalComma());
+  if (parser.parseGreater())
+    return Type();
+  return Torch::TupleType::get(context, containedTypes);
+}
+
+void Torch::TupleType::print(::mlir::DialectAsmPrinter &printer) const {
+  printer << "tuple<";
+  llvm::interleaveComma(getContainedTypes(), printer);
+  printer << ">";
+}
+
 //===----------------------------------------------------------------------===//
 // BaseTensorType
 //===----------------------------------------------------------------------===//
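The up-front `parseOptionalGreater` check lets the zero-element tuple round-trip, and the comma loop handles one or more contained types. All of the following forms parse and print back unchanged, as the new ops.mlir tests below exercise:

    !torch.tuple<>
    !torch.tuple<!torch.tensor>
    !torch.tuple<!torch.tensor, !torch.tensor>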
@@ -41,6 +41,13 @@ func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk>
 // CHECK: @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>
 func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32>

+// CHECK: @tuple.empty() -> !torch.tuple<>
+func private @tuple.empty() -> !torch.tuple<>
+// CHECK: @tuple.one_element() -> !torch.tuple<!torch.tensor>
+func private @tuple.one_element() -> !torch.tuple<!torch.tensor>
+// CHECK: @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
+func private @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
+
 // CHECK-LABEL: func @torch.tensor() {
 func @torch.tensor() {
   // CHECK: torch.tensor(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>