mirror of https://github.com/llvm/torch-mlir
Add TorchList type and prim::ListConstruct #218
parent
370e3270ab
commit
e0ff5248fb
|
@ -468,7 +468,8 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
|
|||
for (IValue element : list) {
|
||||
elements.push_back(mapIValueToMlirValue(loc, element));
|
||||
}
|
||||
return funcBuilder->buildList(loc, elements);
|
||||
return funcBuilder->buildList(loc,
|
||||
typeMapper.mapFromTorchType(loc, list.elementType()), elements);
|
||||
}
|
||||
if (ival.isNone()) {
|
||||
return funcBuilder->getNoneConstant(loc);
|
||||
|
@ -511,7 +512,9 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
|
|||
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
|
||||
}
|
||||
if (ival.isList()) {
|
||||
return npcompListTypeGet(funcBuilder->getContext());
|
||||
return npcompListTypeGet(
|
||||
typeMapper.mapFromTorchType(
|
||||
loc, ival.toList().elementType()));
|
||||
}
|
||||
if (ival.isNone()) {
|
||||
return npcompNoneTypeGet(funcBuilder->getContext());
|
||||
|
|
|
@ -130,10 +130,10 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
|
|||
return insertConstantOp(OpBuilder(context).createStdConstant(loc, value));
|
||||
}
|
||||
|
||||
MlirValue FuncBuilder::buildList(MlirLocation loc,
|
||||
MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
|
||||
std::vector<MlirValue> &elements) {
|
||||
MlirType resultType = npcompListTypeGet(context);
|
||||
OperationStateHolder state{"basicpy.build_list", loc};
|
||||
MlirType resultType = npcompListTypeGet(elementType);
|
||||
OperationStateHolder state{"torch.prim.ListConstruct", loc};
|
||||
mlirOperationStateAddResults(state, 1, &resultType);
|
||||
mlirOperationStateAddOperands(state, elements.size(), elements.data());
|
||||
MlirOperation op = state.createOperation();
|
||||
|
|
|
@ -117,7 +117,8 @@ public:
|
|||
MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
|
||||
|
||||
/// Builds a list with the given elements
|
||||
MlirValue buildList(MlirLocation loc, std::vector<MlirValue> &elements);
|
||||
MlirValue buildList(MlirLocation loc, MlirType elementType,
|
||||
std::vector<MlirValue> &elements);
|
||||
|
||||
private:
|
||||
FuncBuilder(MlirContext context, MlirOperation funcOp,
|
||||
|
|
|
@ -270,8 +270,10 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
elems.push_back(importIValue(elem));
|
||||
}
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(importBlock, "basicpy.build_list", loc,
|
||||
npcompListTypeGet(context), elems);
|
||||
createMlirOperationAtEnd(importBlock, "torch.prim.ListConstruct", loc,
|
||||
npcompListTypeGet(
|
||||
typeMapper.mapFromTorchType(
|
||||
loc, list.elementType())), elems);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isTuple()) {
|
||||
|
|
|
@ -82,6 +82,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
// Builtin interpreter ops with no operator/schema.
|
||||
switch (kind) {
|
||||
case c10::prim::ListUnpack:
|
||||
case c10::prim::ListConstruct:
|
||||
createAndMapTrivialNode(node,
|
||||
"torch.prim." + std::string(kind.toUnqualString()));
|
||||
return;
|
||||
|
@ -96,10 +97,6 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
|
||||
// Ops trivially lowered through `basicpy` dialect.
|
||||
switch (kind) {
|
||||
case c10::prim::ListConstruct: {
|
||||
createAndMapTrivialNode(node, "basicpy.build_list");
|
||||
return;
|
||||
}
|
||||
case c10::prim::TupleConstruct: {
|
||||
createAndMapTrivialNode(node, "basicpy.build_tuple");
|
||||
return;
|
||||
|
|
|
@ -181,8 +181,9 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
return npcompBoolTypeGet(context);
|
||||
}
|
||||
case TypeKind::ListType: {
|
||||
// TODO: Don't lose the element type information.
|
||||
return npcompListTypeGet(context);
|
||||
return npcompListTypeGet(
|
||||
mapFromTorchType(
|
||||
loc, torchType->cast<c10::ListType>()->getElementType()));
|
||||
}
|
||||
case TypeKind::TupleType: {
|
||||
// TODO: Don't lose the element type information.
|
||||
|
|
|
@ -224,7 +224,7 @@ TORCH_TYPE_TO_ODS_TYPE = {
|
|||
"bool": "AnyTorchBoolType",
|
||||
"bool[]": "AnyTorchBoolListType",
|
||||
"float": "AnyFloat",
|
||||
"t[]": "Basicpy_ListType",
|
||||
"t[]": "AnyTorchListType",
|
||||
"t": "AnyTorchType",
|
||||
"t1": "AnyTorchType",
|
||||
"t2": "AnyTorchType",
|
||||
|
@ -455,6 +455,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::len.t : (t[]) -> (int)",
|
||||
has_folder=True,
|
||||
has_canonicalizer=True)
|
||||
emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
|
||||
emit("aten::_set_item.t : (t[], int, t) -> (t[])")
|
||||
|
||||
|
||||
def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
|
||||
|
|
|
@ -40,7 +40,7 @@ with mb.capture_function("conv_cat", [inputs, target]) as f:
|
|||
|
||||
# CHECK: "aten.convolution"
|
||||
# CHECK: "aten.convolution"
|
||||
# CHECK: basicpy.build_list
|
||||
# CHECK: torch.prim.ListConstruct
|
||||
# CHECK: "aten._cat"
|
||||
# CHECK: "aten._log_softmax"
|
||||
# CHECK: "aten.nll_loss2d_forward"
|
||||
|
|
|
@ -45,11 +45,11 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
|
|||
# CHECK: %[[VAL_10:.*]] = constant 1 : i64
|
||||
# CHECK: %[[VAL_11:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
|
||||
# CHECK: %[[VAL_12:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
|
||||
# CHECK: %[[VAL_13:.*]] = basicpy.build_list %[[VAL_1]], %[[VAL_2]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_14:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_16:.*]] = basicpy.build_list %[[VAL_8]], %[[VAL_9]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[VAL_17:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_7]], %[[VAL_16]], %[[VAL_10]]) : (!torch.tensor<[3,16,10,10],f32>, !torch.tensor<[4,16,3,3],f32>, !torch.tensor<[4],f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType, !basicpy.ListType, i64) -> !torch.tensor<[3,4,8,8],f32>
|
||||
# CHECK: %[[VAL_13:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_2]] : (i64, i64) -> !torch.list<i64>
|
||||
# CHECK: %[[VAL_14:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !torch.list<i64>
|
||||
# CHECK: %[[VAL_15:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !torch.list<i64>
|
||||
# CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_9]] : (i64, i64) -> !torch.list<i64>
|
||||
# CHECK: %[[VAL_17:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_7]], %[[VAL_16]], %[[VAL_10]]) : (!torch.tensor<[3,16,10,10],f32>, !torch.tensor<[4,16,3,3],f32>, !torch.tensor<[4],f32>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, !basicpy.BoolType, !torch.list<i64>, i64) -> !torch.tensor<[3,4,8,8],f32>
|
||||
# CHECK: return %[[VAL_17]] : !torch.tensor<[3,4,8,8],f32>
|
||||
# CHECK: }
|
||||
|
||||
|
|
|
@ -16,14 +16,13 @@ class TestModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
self.l = [1, 2]
|
||||
# CHECK: torch.class_type @[[CLASSTYPE:.*]] {
|
||||
# TODO: Don't lose element type.
|
||||
# CHECK: torch.attr "l" : !basicpy.ListType
|
||||
# CHECK: torch.attr "l" : !torch.list<i64>
|
||||
# CHECK: }
|
||||
# CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
|
||||
# CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
|
||||
# CHECK: %[[LIST:.*]] = basicpy.build_list %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[N1]], %[[N2]] : (i64, i64) -> !torch.list<i64>
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.slot "l", %[[LIST]] : !basicpy.ListType
|
||||
# CHECK: torch.slot "l", %[[LIST]] : !torch.list<i64>
|
||||
# CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
|
||||
|
||||
|
||||
|
|
|
@ -21,8 +21,8 @@ mb = torch_mlir.ModuleBuilder()
|
|||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# CHECK: %[[L2:.*]] = basicpy.build_list
|
||||
# CHECK: %[[L1:.*]] = basicpy.build_list
|
||||
# CHECK: %[[L2:.*]] = torch.prim.ListConstruct
|
||||
# CHECK: %[[L1:.*]] = torch.prim.ListConstruct
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.slot "l2", %[[L2]]
|
||||
# CHECK: torch.slot "l1", %[[L1]]
|
||||
|
|
|
@ -11,9 +11,9 @@ mb = torch_mlir.ModuleBuilder()
|
|||
|
||||
# CHECK-LABEL: func @__torch__.f(
|
||||
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !basicpy.ListType {
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !basicpy.ListType
|
||||
# CHECK: return %[[RET]] : !basicpy.ListType
|
||||
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.list<!torch.tensor> {
|
||||
# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list<!torch.tensor>
|
||||
# CHECK: return %[[RET]] : !torch.list<!torch.tensor>
|
||||
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
|
|
|
@ -84,8 +84,8 @@ def prim_TupleIndex(tup: typing.Tuple[int, int]):
|
|||
return tup[0]
|
||||
|
||||
# CHECK-LABEL: func @__torch__.prim_ListUnpack(
|
||||
# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
|
||||
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
|
||||
# CHECK-SAME: %[[ARG:.*]]: !torch.list<i64>) -> i64 {
|
||||
# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !torch.list<i64> -> i64, i64
|
||||
# CHECK: return %[[RET]]#1 : i64
|
||||
@mb.import_function
|
||||
@torch.jit.script
|
||||
|
@ -122,11 +122,11 @@ def prim_device(x):
|
|||
|
||||
# CHECK-LABEL: func @__torch__.prim_min(
|
||||
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
|
||||
# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
|
||||
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !basicpy.ListType -> i64
|
||||
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
|
||||
# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
|
||||
# CHECK: %[[MIN2:.*]] = torch.prim.min.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
|
||||
# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
|
||||
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
|
||||
# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MIN1]], %[[MIN2]], %[[MIN3]] : (i64, i64, i64) -> !basicpy.TupleType
|
||||
# CHECK: return %[[RET]] : !basicpy.TupleType
|
||||
@mb.import_function
|
||||
|
@ -136,11 +136,11 @@ def prim_min(x: int):
|
|||
|
||||
# CHECK-LABEL: func @__torch__.prim_max(
|
||||
# CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
|
||||
# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
|
||||
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !basicpy.ListType -> i64
|
||||
# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
|
||||
# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
|
||||
# CHECK: %[[MAX2:.*]] = torch.prim.max.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
|
||||
# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
|
||||
# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
|
||||
# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
|
||||
# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
|
||||
# CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MAX1]], %[[MAX2]], %[[MAX3]] : (i64, i64, i64) -> !basicpy.TupleType
|
||||
# CHECK: return %[[RET]] : !basicpy.TupleType
|
||||
@mb.import_function
|
||||
|
|
|
@ -62,10 +62,10 @@ MlirType npcompDictTypeGet(MlirContext context);
|
|||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "list" type. */
|
||||
int npcompTypeIsAList(MlirType t);
|
||||
int npcompTypeIsABasicpyList(MlirType t);
|
||||
|
||||
/** Gets the generic Python "list" type. */
|
||||
MlirType npcompListTypeGet(MlirContext context);
|
||||
MlirType npcompBasicpyListTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* NDArray type. */
|
||||
|
@ -137,6 +137,16 @@ int npcompTypeIsAOptional(MlirType t);
|
|||
/** Gets the !torch.optional<T> type with subtype T. */
|
||||
MlirType npcompOptionalTypeGet(MlirType containedType);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.list type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.list<T> type */
|
||||
int npcompTypeIsAList(MlirType t);
|
||||
|
||||
/** Gets the !torch.list<T> type with contained T. */
|
||||
MlirType npcompListTypeGet(MlirType containedType);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.Device type. */
|
||||
/*============================================================================*/
|
||||
|
|
|
@ -382,7 +382,7 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
|
|||
]> {
|
||||
let summary = "Generated op for `aten::len.t : (t[]) -> (int)`";
|
||||
let arguments = (ins
|
||||
Basicpy_ListType:$a
|
||||
AnyTorchListType:$a
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchIntType:$result
|
||||
|
@ -392,3 +392,33 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten__Getitem__TOp : Torch_Op<"aten.__getitem__.t", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::__getitem__.t : (t[], int) -> (t)`";
|
||||
let arguments = (ins
|
||||
AnyTorchListType:$list,
|
||||
AnyTorchIntType:$idx
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchType:$result
|
||||
);
|
||||
let assemblyFormat = "$list `,` $idx attr-dict `:` type($list) `,` type($idx) `->` type($result)";
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten_SetItemTOp : Torch_Op<"aten._set_item.t", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_set_item.t : (t[], int, t) -> (t[])`";
|
||||
let arguments = (ins
|
||||
AnyTorchListType:$l,
|
||||
AnyTorchIntType:$idx,
|
||||
AnyTorchType:$el
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchListType:$result
|
||||
);
|
||||
let assemblyFormat = "$l `,` $idx `,` $el attr-dict `:` type($l) `,` type($idx) `,` type($el) `->` type($result)";
|
||||
}
|
||||
|
||||
|
|
|
@ -322,6 +322,26 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
|
||||
NoSideEffect,
|
||||
AllowsTypeRefinement,
|
||||
]> {
|
||||
let summary = "TorchScript prim::ListConstruct op";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyTorchType>:$elements
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchListType:$result
|
||||
);
|
||||
|
||||
let verifier = "return ::verify(*this);";
|
||||
|
||||
let assemblyFormat = [{
|
||||
$elements attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
|
||||
let summary = "TorchScript prim::GetAttr op";
|
||||
|
||||
|
|
|
@ -21,6 +21,31 @@ class Torch_Type<string name, string typeMnemonic,
|
|||
let mnemonic = typeMnemonic;
|
||||
}
|
||||
|
||||
class Torch_TypeWithContainedType<string name, string typeMnemonic> : Torch_Type<name, typeMnemonic> {
|
||||
let parameters = (ins "::mlir::Type":$containedType);
|
||||
|
||||
let printer = [{
|
||||
$_printer << getMnemonic() << "<" << getImpl()->containedType << ">";
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
Type containedType;
|
||||
if ($_parser.parseType(containedType))
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, containedType);
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
|
||||
return Base::get(containedType.getContext(), containedType);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
|
||||
let summary = "torch.nn.Module";
|
||||
let description = [{
|
||||
|
@ -195,32 +220,16 @@ def AnyTorchTensorType : Type<
|
|||
// For now, we only need it as a stand-in type to allow importing
|
||||
// the `_is_full_backward_hook` optional bool type that Torch puts on
|
||||
// all classes.
|
||||
def Torch_OptionalType : Torch_Type<"Optional", "optional"> {
|
||||
def Torch_OptionalType : Torch_TypeWithContainedType<"Optional", "optional"> {
|
||||
let summary = "!torch.optional<T>";
|
||||
let description = [{
|
||||
}];
|
||||
let parameters = (ins "::mlir::Type":$containedType);
|
||||
}
|
||||
|
||||
let printer = [{
|
||||
$_printer << "optional<" << getImpl()->containedType << ">";
|
||||
def Torch_ListType : Torch_TypeWithContainedType<"List", "list"> {
|
||||
let summary = "!torch.list<T>";
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
Type containedType;
|
||||
if ($_parser.parseType(containedType))
|
||||
return Type();
|
||||
if ($_parser.parseGreater())
|
||||
return Type();
|
||||
return get($_ctxt, containedType);
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
|
||||
return Base::get(containedType.getContext(), containedType);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
def Torch_DeviceType : Torch_Type<"Device", "Device"> {
|
||||
|
@ -279,6 +288,13 @@ def AnyTorchScalarType : AnyTypeOf<[
|
|||
AnySignlessInteger,
|
||||
], "Any primitive type suitable to be passed as a Torch Scalar">;
|
||||
|
||||
def IsListTypePred : CPred<"$_self.isa<::mlir::NPCOMP::Torch::ListType>()">;
|
||||
|
||||
class ListOf<list<Type> allowedTypes, string descr> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>, IsListTypePred,
|
||||
"$_self.cast<::mlir::NPCOMP::Torch::ListType>().getContainedType()",
|
||||
descr, "::mlir::NPCOMP::Torch::ListType">;
|
||||
|
||||
def AnyTorchNumberType : AnyTypeOf<[
|
||||
AnySignedInteger,
|
||||
AnyFloat,
|
||||
|
@ -293,33 +309,29 @@ def AnyTorchBoolType : AnyTypeOf<[
|
|||
Basicpy_BoolType,
|
||||
], "Any permissible bool type">;
|
||||
|
||||
def AnyTorchBoolListType : AnyTypeOf<[
|
||||
Basicpy_ListType,
|
||||
// TODO: Support typed list when available.
|
||||
], "Any bool list type (bool[])">;
|
||||
def AnyTorchBoolListType : ListOf<[AnyTorchBoolType], "Any bool list type (bool[])">;
|
||||
|
||||
def AnyTorchIntType : AnyTypeOf<[
|
||||
AnySignedInteger,
|
||||
AnySignlessInteger,
|
||||
], "Any primitive integer type suitable to be passed as a Torch 'int'">;
|
||||
|
||||
def AnyTorchIntListType : AnyTypeOf<[
|
||||
Basicpy_ListType,
|
||||
// TODO: Support typed list when available.
|
||||
], "Any int list type (int[])">;
|
||||
def AnyTorchIntListType : ListOf<[AnyTorchIntType], "Any int list type (int[])">;
|
||||
|
||||
def AnyTorchType : AnyTypeOf<[
|
||||
AnyTorchBoolType,
|
||||
AnyTorchScalarType,
|
||||
AnyTorchTensorType,
|
||||
Basicpy_ListType,
|
||||
Basicpy_TupleType,
|
||||
Basicpy_NoneType,
|
||||
Basicpy_BytesType,
|
||||
Torch_NnModuleType,
|
||||
Torch_OptionalType,
|
||||
Torch_ListType,
|
||||
Torch_DeviceType,
|
||||
Torch_LinearParamsType,
|
||||
], "Any type that is legal to pass to a Torch kernel">;
|
||||
|
||||
def AnyTorchListType : ListOf<[AnyType], "Any Torch list Type">;
|
||||
|
||||
#endif // TORCH_TYPES
|
||||
|
|
|
@ -69,10 +69,10 @@ MlirType npcompDictTypeGet(MlirContext context) {
|
|||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "list" type. */
|
||||
int npcompTypeIsAList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
|
||||
int npcompTypeIsABasicpyList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
|
||||
|
||||
/** Gets the generic Python "dict" type. */
|
||||
MlirType npcompListTypeGet(MlirContext context) {
|
||||
MlirType npcompBasicpyListTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::ListType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
|
@ -175,6 +175,19 @@ MlirType npcompOptionalTypeGet(MlirType containedType) {
|
|||
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.list type. */
|
||||
/*============================================================================*/
|
||||
/** Checks whether the given type is a !torch.list<T> type */
|
||||
int npcompTypeIsAList(MlirType t) {
|
||||
return unwrap(t).isa<Torch::ListType>();
|
||||
}
|
||||
|
||||
/** Gets the !torch.List<T> type with contained T. */
|
||||
MlirType npcompListTypeGet(MlirType containedType) {
|
||||
return wrap(Torch::ListType::get(unwrap(containedType)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.Device type. */
|
||||
/*============================================================================*/
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
|
@ -131,6 +132,21 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrimListConstructOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(PrimListConstructOp op) {
|
||||
auto resultType = op.getResult().getType();
|
||||
auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
|
||||
auto matchResultElementType = [&](Type type) {
|
||||
return type.getTypeID() == resultElementType.getTypeID();
|
||||
};
|
||||
if (llvm::all_of(op->getOperandTypes(), matchResultElementType))
|
||||
return success();
|
||||
else return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ClassTypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -246,9 +262,9 @@ OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
|
|||
|
||||
OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
|
||||
// `len([1,1,1])` -> `3`
|
||||
if (auto buildList = getOperand().getDefiningOp<Basicpy::BuildListOp>()) {
|
||||
if (auto listConstruct = getOperand().getDefiningOp<Torch::PrimListConstructOp>()) {
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 64),
|
||||
buildList.getNumOperands());
|
||||
listConstruct.getNumOperands());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -288,8 +304,8 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
listElements.push_back(rewriter.create<::mlir::ConstantOp>(
|
||||
op->getLoc(), rewriter.getI64IntegerAttr(size)));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<Basicpy::BuildListOp>(
|
||||
op, Basicpy::ListType::get(rewriter.getContext()), listElements);
|
||||
rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>(
|
||||
op, Torch::ListType::get(rewriter.getI64Type()), listElements);
|
||||
return success();
|
||||
});
|
||||
// One-off pattern to erase if dead.
|
||||
|
@ -425,5 +441,30 @@ LogicalResult FromBuiltinTensorOp::inferReturnTypes(
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Aten__Getitem__TOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void Aten__Getitem__TOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
|
||||
auto torchList = op.getOperand(0);
|
||||
if(!torchList.hasOneUse())
|
||||
return failure();
|
||||
|
||||
auto listConstruct = torchList.getDefiningOp<Torch::PrimListConstructOp>();
|
||||
if (!listConstruct)
|
||||
return failure();
|
||||
|
||||
APInt indexAP;
|
||||
if (!matchPattern(op.getOperand(1), m_ConstantInt(&indexAP)))
|
||||
return failure();
|
||||
|
||||
auto index = indexAP.getSExtValue();
|
||||
rewriter.replaceOp(op, {listConstruct.getOperand(index)});
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
|
||||
|
|
|
@ -2,20 +2,20 @@
|
|||
|
||||
// CHECK that multiple nested initialization ops are properly handled.
|
||||
|
||||
// CHECK-LABEL: torch.global_slot @l : !basicpy.ListType {
|
||||
// CHECK: %[[L0:.*]] = basicpy.build_list : () -> !basicpy.ListType
|
||||
// CHECK: %[[L1:.*]] = basicpy.build_list %[[L0]], %[[L0]] : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
|
||||
// CHECK: %[[L2:.*]] = basicpy.build_list %[[L1]], %[[L1]] : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
|
||||
// CHECK: torch.global_slot.init %[[L2]] : !basicpy.ListType
|
||||
// CHECK-LABEL: torch.global_slot @l : !torch.list<!torch.list<!torch.list<!torch.tensor>>> {
|
||||
// CHECK: %[[L0:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.tensor>
|
||||
// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[L0]], %[[L0]] : (!torch.list<!torch.tensor>, !torch.list<!torch.tensor>) -> !torch.list<!torch.list<!torch.tensor>>
|
||||
// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[L1]], %[[L1]] : (!torch.list<!torch.list<!torch.tensor>>, !torch.list<!torch.list<!torch.tensor>>) -> !torch.list<!torch.list<!torch.list<!torch.tensor>>>
|
||||
// CHECK: torch.global_slot.init %[[L2]] : !torch.list<!torch.list<!torch.list<!torch.tensor>>>
|
||||
// CHECK: }
|
||||
|
||||
torch.class_type @c {
|
||||
torch.attr "l" : !basicpy.ListType
|
||||
torch.attr "l" : !torch.list<!torch.list<!torch.list<!torch.tensor>>>
|
||||
}
|
||||
|
||||
%l0 = basicpy.build_list : () -> !basicpy.ListType
|
||||
%l1 = basicpy.build_list %l0, %l0 : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
|
||||
%l2 = basicpy.build_list %l1, %l1 : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
|
||||
%l0 = torch.prim.ListConstruct : () -> !torch.list<!torch.tensor>
|
||||
%l1 = torch.prim.ListConstruct %l0, %l0 : (!torch.list<!torch.tensor>, !torch.list<!torch.tensor>) -> !torch.list<!torch.list<!torch.tensor>>
|
||||
%l2 = torch.prim.ListConstruct %l1, %l1 : (!torch.list<!torch.list<!torch.tensor>>, !torch.list<!torch.list<!torch.tensor>>) -> !torch.list<!torch.list<!torch.list<!torch.tensor>>>
|
||||
torch.nn_module {
|
||||
torch.slot "l", %l2 : !basicpy.ListType
|
||||
torch.slot "l", %l2 : !torch.list<!torch.list<!torch.list<!torch.tensor>>>
|
||||
} : !torch.nn.Module<"c">
|
||||
|
|
|
@ -6,7 +6,7 @@ torch.class_type @parent {
|
|||
|
||||
func private @module_type_return(%arg0: !torch.nn.Module<"parent">) {
|
||||
// expected-error @+1 {{unsupported use of a torch.nn.Module. Expected only method calls or attribute get/set}}
|
||||
basicpy.build_list %arg0 : (!torch.nn.Module<"parent">) -> !basicpy.ListType
|
||||
torch.prim.ListConstruct %arg0 : (!torch.nn.Module<"parent">) -> !torch.list<!torch.nn.Module<"parent">>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -3,31 +3,31 @@
|
|||
// CHECK-LABEL: func @torch.aten.__is__
|
||||
// CHECK: %[[FALSE:.*]] = basicpy.bool_constant false
|
||||
// CHECK: return %[[FALSE]] : !basicpy.BoolType
|
||||
func @torch.aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicpy.BoolType{
|
||||
%0 = torch.aten.__is__ %arg0, %arg1 : !basicpy.ListType, !basicpy.NoneType -> !basicpy.BoolType
|
||||
func @torch.aten.__is__(%arg0: !torch.list<i64>, %arg1: !basicpy.NoneType) -> !basicpy.BoolType{
|
||||
%0 = torch.aten.__is__ %arg0, %arg1 : !torch.list<i64>, !basicpy.NoneType -> !basicpy.BoolType
|
||||
return %0 : !basicpy.BoolType
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.size$canonicalize_to_list(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !basicpy.ListType {
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list<i64> {
|
||||
// CHECK: %[[C2:.*]] = constant 2 : i64
|
||||
// CHECK: %[[C3:.*]] = constant 3 : i64
|
||||
// CHECK: %[[LIST:.*]] = basicpy.build_list %[[C2]], %[[C3]] : (i64, i64) -> !basicpy.ListType
|
||||
// CHECK: return %[[LIST]] : !basicpy.ListType
|
||||
func @torch.aten.size$canonicalize_to_list(%arg0: !torch.vtensor<[2,3],f32>) -> !basicpy.ListType {
|
||||
%0 = torch.aten.size %arg0 : !torch.vtensor<[2,3],f32> -> !basicpy.ListType
|
||||
return %0 : !basicpy.ListType
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C2]], %[[C3]] : (i64, i64) -> !torch.list<i64>
|
||||
// CHECK: return %[[LIST]] : !torch.list<i64>
|
||||
func @torch.aten.size$canonicalize_to_list(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.list<i64> {
|
||||
%0 = torch.aten.size %arg0 : !torch.vtensor<[2,3],f32> -> !torch.list<i64>
|
||||
return %0 : !torch.list<i64>
|
||||
}
|
||||
|
||||
// One size unknown, so cannot canonicalize.
|
||||
// TODO: For unknown sizes, insert the equivalent of a "dim" op.
|
||||
// Then this will only require static rank.
|
||||
// CHECK-LABEL: func @torch.aten.size$unknown_size(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3],f32>) -> !basicpy.ListType {
|
||||
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor<[?,3],f32> -> !basicpy.ListType
|
||||
func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !basicpy.ListType {
|
||||
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !basicpy.ListType
|
||||
return %0 : !basicpy.ListType
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<i64> {
|
||||
// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor<[?,3],f32> -> !torch.list<i64>
|
||||
func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<i64> {
|
||||
%0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<i64>
|
||||
return %0 : !torch.list<i64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.len.t$of_size(
|
||||
|
@ -35,8 +35,8 @@ func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !basicpy
|
|||
// CHECK: %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> i64
|
||||
// CHECK: return %[[DIM]] : i64
|
||||
func @torch.aten.len.t$of_size(%arg0: !torch.vtensor<*,f32>) -> i64 {
|
||||
%0 = torch.aten.size %arg0 : !torch.vtensor<*,f32> -> !basicpy.ListType
|
||||
%1 = torch.aten.len.t %0 : !basicpy.ListType -> i64
|
||||
%0 = torch.aten.size %arg0 : !torch.vtensor<*,f32> -> !torch.list<i64>
|
||||
%1 = torch.aten.len.t %0 : !torch.list<i64> -> i64
|
||||
return %1 : i64
|
||||
}
|
||||
|
||||
|
@ -54,8 +54,8 @@ func @torch.aten.dim$with_shape(%arg0: !torch.vtensor<[?,?,?],f32>) -> i64 {
|
|||
// CHECK: %[[LEN:.*]] = constant 4 : i64
|
||||
// CHECK: return %[[LEN]] : i64
|
||||
func @torch.aten.len.t$of_build_list(%arg0: i64) -> i64 {
|
||||
%0 = basicpy.build_list %arg0, %arg0, %arg0, %arg0 : (i64, i64, i64, i64) -> !basicpy.ListType
|
||||
%1 = torch.aten.len.t %0 : !basicpy.ListType -> i64
|
||||
%0 = torch.prim.ListConstruct %arg0, %arg0, %arg0, %arg0 : (i64, i64, i64, i64) -> !torch.list<i64>
|
||||
%1 = torch.aten.len.t %0 : !torch.list<i64> -> i64
|
||||
return %1 : i64
|
||||
}
|
||||
|
||||
|
@ -75,3 +75,41 @@ func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(%arg0: !torch.vte
|
|||
%1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.__getitem__.t(
|
||||
// CHECK: %[[C5:.*]] = constant 5 : i64
|
||||
// CHECK: return %[[C5]] : i64
|
||||
func @torch.aten.__getitem__.t() -> i64 {
|
||||
%c4_i64 = constant 4 : i64
|
||||
%c5_i64 = constant 5 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = torch.prim.ListConstruct %c4_i64, %c5_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%1 = torch.aten.__getitem__.t %0, %c1_i64 : !torch.list<i64>, i64 -> i64
|
||||
return %1 : i64
|
||||
}
|
||||
|
||||
// Not canonicalized because of passed in index
|
||||
// CHECK-LABEL: func @torch.aten.__getitem__.t$no_change_test0(
|
||||
// CHECK: %[[C4:.*]] = constant 4 : i64
|
||||
// CHECK: %[[C5:.*]] = constant 5 : i64
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C4]], %[[C5]] : (i64, i64) -> !torch.list<i64>
|
||||
// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %[[LIST]], %arg0 : !torch.list<i64>, i64 -> i64
|
||||
// CHECK: return %[[ITEM]] : i64
|
||||
func @torch.aten.__getitem__.t$no_change_test0(%arg0: i64) -> i64 {
|
||||
%c4_i64 = constant 4 : i64
|
||||
%c5_i64 = constant 5 : i64
|
||||
%0 = torch.prim.ListConstruct %c4_i64, %c5_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%1 = torch.aten.__getitem__.t %0, %arg0 : !torch.list<i64>, i64 -> i64
|
||||
return %1 : i64
|
||||
}
|
||||
|
||||
// Not canonicalized because of passed in list
|
||||
// CHECK-LABEL: func @torch.aten.__getitem__.t$no_change_test1(
|
||||
// CHECK: %[[C5:.*]] = constant 5 : i64
|
||||
// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %arg0, %[[C5]] : !torch.list<i64>, i64 -> i64
|
||||
// CHECK: return %[[ITEM]] : i64
|
||||
func @torch.aten.__getitem__.t$no_change_test1(%arg0: !torch.list<i64>) -> i64 {
|
||||
%c5_i64 = constant 5 : i64
|
||||
%0 = torch.aten.__getitem__.t %arg0, %c5_i64 : !torch.list<i64>, i64 -> i64
|
||||
return %0 : i64
|
||||
}
|
||||
|
|
|
@ -72,10 +72,10 @@ func @f(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg
|
|||
func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!torch.vtensor {
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->!torch.vtensor
|
||||
%0 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%1 = torch.prim.ListConstruct %c0_i64, %c0_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%2 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, i64 ->!torch.vtensor
|
||||
return %3 :!torch.vtensor
|
||||
}
|
||||
|
||||
|
@ -86,10 +86,10 @@ func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!tor
|
|||
func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:!torch.vtensor<*,f32>) ->!torch.vtensor {
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->!torch.vtensor
|
||||
%0 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%1 = torch.prim.ListConstruct %c0_i64, %c0_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%2 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, i64 ->!torch.vtensor
|
||||
return %3 :!torch.vtensor
|
||||
}
|
||||
|
||||
|
@ -101,12 +101,12 @@ func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
|
|||
%c3_i64 = constant 3 : i64
|
||||
%c2_i64 = constant 2 : i64
|
||||
%bool_false = basicpy.bool_constant false
|
||||
%21 = basicpy.build_list %c3_i64, %c3_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%22 = basicpy.build_list %c2_i64, %c2_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%23 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%24 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%21 = torch.prim.ListConstruct %c3_i64, %c3_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%22 = torch.prim.ListConstruct %c2_i64, %c2_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%23 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
|
||||
%24 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
|
||||
// CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
|
||||
%27 = torch.aten.max_pool2d %arg0, %21, %22, %23, %24, %bool_false : !torch.vtensor<[?,?,?,?],f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType -> !torch.vtensor
|
||||
%27 = torch.aten.max_pool2d %arg0, %21, %22, %23, %24, %bool_false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, !basicpy.BoolType -> !torch.vtensor
|
||||
return %27 : !torch.vtensor
|
||||
}
|
||||
|
||||
|
@ -115,9 +115,9 @@ func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
|
|||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
|
||||
%c1_i64 = constant 1 : i64
|
||||
%0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
|
||||
%0 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
|
||||
// CHECK: torch.aten.adaptive_avg_pool2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
|
||||
%1 = torch.aten.adaptive_avg_pool2d %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !basicpy.ListType -> !torch.vtensor
|
||||
%1 = torch.aten.adaptive_avg_pool2d %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<i64> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue