diff --git a/frontends/pytorch/csrc/builder/acap_dispatch.cpp b/frontends/pytorch/csrc/builder/acap_dispatch.cpp
index 0bfacf61c..093066b52 100644
--- a/frontends/pytorch/csrc/builder/acap_dispatch.cpp
+++ b/frontends/pytorch/csrc/builder/acap_dispatch.cpp
@@ -468,7 +468,8 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
     for (IValue element : list) {
       elements.push_back(mapIValueToMlirValue(loc, element));
     }
-    return funcBuilder->buildList(loc, elements);
+    return funcBuilder->buildList(loc,
+        typeMapper.mapFromTorchType(loc, list.elementType()), elements);
   }
   if (ival.isNone()) {
     return funcBuilder->getNoneConstant(loc);
@@ -511,7 +512,9 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
     return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
   }
   if (ival.isList()) {
-    return npcompListTypeGet(funcBuilder->getContext());
+    return npcompListTypeGet(
+        typeMapper.mapFromTorchType(
+            loc, ival.toList().elementType()));
   }
   if (ival.isNone()) {
     return npcompNoneTypeGet(funcBuilder->getContext());
diff --git a/frontends/pytorch/csrc/builder/func_builder.cpp b/frontends/pytorch/csrc/builder/func_builder.cpp
index 2c8363a17..3d175f9ec 100644
--- a/frontends/pytorch/csrc/builder/func_builder.cpp
+++ b/frontends/pytorch/csrc/builder/func_builder.cpp
@@ -130,10 +130,10 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
   return insertConstantOp(OpBuilder(context).createStdConstant(loc, value));
 }
 
-MlirValue FuncBuilder::buildList(MlirLocation loc,
+MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
                                  std::vector<MlirValue> &elements) {
-  MlirType resultType = npcompListTypeGet(context);
-  OperationStateHolder state{"basicpy.build_list", loc};
+  MlirType resultType = npcompListTypeGet(elementType);
+  OperationStateHolder state{"torch.prim.ListConstruct", loc};
   mlirOperationStateAddResults(state, 1, &resultType);
   mlirOperationStateAddOperands(state, elements.size(), elements.data());
   MlirOperation op = state.createOperation();
diff --git a/frontends/pytorch/csrc/builder/func_builder.h b/frontends/pytorch/csrc/builder/func_builder.h
index 92697391a..7319ef224 100644
--- a/frontends/pytorch/csrc/builder/func_builder.h
+++ b/frontends/pytorch/csrc/builder/func_builder.h
@@ -117,7 +117,8 @@ public:
   MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);
 
   /// Builds a list with the given elements
-  MlirValue buildList(MlirLocation loc, std::vector<MlirValue> &elements);
+  MlirValue buildList(MlirLocation loc, MlirType elementType,
+                      std::vector<MlirValue> &elements);
 
 private:
   FuncBuilder(MlirContext context, MlirOperation funcOp,
diff --git a/frontends/pytorch/csrc/builder/ivalue_importer.cpp b/frontends/pytorch/csrc/builder/ivalue_importer.cpp
index 6a2e4a94d..5eefbd47e 100644
--- a/frontends/pytorch/csrc/builder/ivalue_importer.cpp
+++ b/frontends/pytorch/csrc/builder/ivalue_importer.cpp
@@ -270,8 +270,10 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
       elems.push_back(importIValue(elem));
     }
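+    // Import the list as a typed !torch.list<T>, mapping the element type
+    // from the TorchScript list type instead of type-erasing it.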
     MlirOperation operation =
-        createMlirOperationAtEnd(importBlock, "basicpy.build_list", loc,
-                                 npcompListTypeGet(context), elems);
+        createMlirOperationAtEnd(importBlock, "torch.prim.ListConstruct", loc,
+                                 npcompListTypeGet(
+                                     typeMapper.mapFromTorchType(
+                                         loc, list.elementType())), elems);
     return mlirOperationGetResult(operation, 0);
   }
   if (ivalue.isTuple()) {
diff --git a/frontends/pytorch/csrc/builder/node_importer.cpp b/frontends/pytorch/csrc/builder/node_importer.cpp
index d0a63b340..5ae631f30 100644
--- a/frontends/pytorch/csrc/builder/node_importer.cpp
+++ b/frontends/pytorch/csrc/builder/node_importer.cpp
@@ -82,6 +82,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
   // Builtin interpreter ops with no operator/schema.
   switch (kind) {
   case c10::prim::ListUnpack:
+  case c10::prim::ListConstruct:
     createAndMapTrivialNode(node,
                             "torch.prim." + std::string(kind.toUnqualString()));
     return;
@@ -96,10 +97,6 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
   // Ops trivially lowered through `basicpy` dialect.
   switch (kind) {
-  case c10::prim::ListConstruct: {
-    createAndMapTrivialNode(node, "basicpy.build_list");
-    return;
-  }
   case c10::prim::TupleConstruct: {
     createAndMapTrivialNode(node, "basicpy.build_tuple");
     return;
diff --git a/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp b/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp
index 5238a48bd..dc259810b 100644
--- a/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp
+++ b/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp
@@ -181,8 +181,9 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
     return npcompBoolTypeGet(context);
   }
   case TypeKind::ListType: {
-    // TODO: Don't lose the element type information.
-    return npcompListTypeGet(context);
+    return npcompListTypeGet(
+        mapFromTorchType(
+            loc, torchType->cast<c10::ListType>()->getElementType()));
   }
   case TypeKind::TupleType: {
     // TODO: Don't lose the element type information.
diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py
index d39c94149..e2b494a68 100644
--- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py
+++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py
@@ -224,7 +224,7 @@ TORCH_TYPE_TO_ODS_TYPE = {
     "bool": "AnyTorchBoolType",
     "bool[]": "AnyTorchBoolListType",
     "float": "AnyFloat",
-    "t[]": "Basicpy_ListType",
+    "t[]": "AnyTorchListType",
     "t": "AnyTorchType",
     "t1": "AnyTorchType",
     "t2": "AnyTorchType",
@@ -455,6 +455,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::len.t : (t[]) -> (int)",
         has_folder=True,
         has_canonicalizer=True)
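+    # __getitem__ gets a canonicalizer so that a constant index into a
+    # locally constructed list can fold to the underlying element value.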
+    emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True)
+    emit("aten::_set_item.t : (t[], int, t) -> (t[])")
 
 
 def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
diff --git a/frontends/pytorch/test/acap_export/test_export_cat.py b/frontends/pytorch/test/acap_export/test_export_cat.py
index 1bb72a9ae..f26026413 100644
--- a/frontends/pytorch/test/acap_export/test_export_cat.py
+++ b/frontends/pytorch/test/acap_export/test_export_cat.py
@@ -40,7 +40,7 @@ with mb.capture_function("conv_cat", [inputs, target]) as f:
 # CHECK: "aten.convolution"
 # CHECK: "aten.convolution"
-# CHECK: basicpy.build_list
+# CHECK: torch.prim.ListConstruct
 # CHECK: "aten._cat"
 # CHECK: "aten._log_softmax"
 # CHECK: "aten.nll_loss2d_forward"
diff --git a/frontends/pytorch/test/acap_export/test_export_conv2d_fwd.py b/frontends/pytorch/test/acap_export/test_export_conv2d_fwd.py
index db1f65837..122781073 100644
--- a/frontends/pytorch/test/acap_export/test_export_conv2d_fwd.py
+++ b/frontends/pytorch/test/acap_export/test_export_conv2d_fwd.py
@@ -45,11 +45,11 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
 # CHECK: %[[VAL_10:.*]] = constant 1 : i64
 # CHECK: %[[VAL_11:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
 # CHECK: %[[VAL_12:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
-# CHECK: %[[VAL_13:.*]] = basicpy.build_list %[[VAL_1]], %[[VAL_2]] : (i64, i64) -> !basicpy.ListType
-# CHECK: %[[VAL_14:.*]] = basicpy.build_list %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !basicpy.ListType
-# CHECK: %[[VAL_15:.*]] = basicpy.build_list %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !basicpy.ListType
-# CHECK: %[[VAL_16:.*]] = basicpy.build_list %[[VAL_8]], %[[VAL_9]] : (i64, i64) -> !basicpy.ListType
-# CHECK: %[[VAL_17:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_7]], %[[VAL_16]], %[[VAL_10]]) : (!torch.tensor<[3,16,10,10],f32>, !torch.tensor<[4,16,3,3],f32>, !torch.tensor<[4],f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType, !basicpy.ListType, i64) -> !torch.tensor<[3,4,8,8],f32>
+# CHECK: %[[VAL_13:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_2]] : (i64, i64) -> !torch.list<i64>
+# CHECK: %[[VAL_14:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (i64, i64) -> !torch.list<i64>
+# CHECK: %[[VAL_15:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (i64, i64) -> !torch.list<i64>
+# CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_9]] : (i64, i64) -> !torch.list<i64>
+# CHECK: %[[VAL_17:.*]] = torch.operator "aten.convolution"(%[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_7]], %[[VAL_16]], %[[VAL_10]]) : (!torch.tensor<[3,16,10,10],f32>, !torch.tensor<[4,16,3,3],f32>, !torch.tensor<[4],f32>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, !basicpy.BoolType, !torch.list<i64>, i64) -> !torch.tensor<[3,4,8,8],f32>
 # CHECK: return %[[VAL_17]] : !torch.tensor<[3,4,8,8],f32>
 # CHECK: }
diff --git a/frontends/pytorch/test/ivalue_import/list.py b/frontends/pytorch/test/ivalue_import/list.py
index f78bec4ff..a2509a3c3 100644
--- a/frontends/pytorch/test/ivalue_import/list.py
+++ b/frontends/pytorch/test/ivalue_import/list.py
@@ -16,14 +16,13 @@ class TestModule(torch.nn.Module):
     super().__init__()
     self.l = [1, 2]
 # CHECK: torch.class_type @[[CLASSTYPE:.*]] {
-# TODO: Don't lose element type.
-# CHECK:   torch.attr "l" : !basicpy.ListType
+# CHECK:   torch.attr "l" : !torch.list<i64>
 # CHECK: }
 # CHECK: %[[N1:.*]] = basicpy.numeric_constant 1 : i64
 # CHECK: %[[N2:.*]] = basicpy.numeric_constant 2 : i64
-# CHECK: %[[LIST:.*]] = basicpy.build_list %[[N1]], %[[N2]] : (i64, i64) -> !basicpy.ListType
+# CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[N1]], %[[N2]] : (i64, i64) -> !torch.list<i64>
 # CHECK: torch.nn_module {
-# CHECK:   torch.slot "l", %[[LIST]] : !basicpy.ListType
+# CHECK:   torch.slot "l", %[[LIST]] : !torch.list<i64>
 # CHECK: } : !torch.nn.Module<"[[CLASSTYPE]]">
diff --git a/frontends/pytorch/test/ivalue_import/object-identity-torch-bug.py b/frontends/pytorch/test/ivalue_import/object-identity-torch-bug.py
index fd310a49e..c77169195 100644
--- a/frontends/pytorch/test/ivalue_import/object-identity-torch-bug.py
+++ b/frontends/pytorch/test/ivalue_import/object-identity-torch-bug.py
@@ -21,8 +21,8 @@ mb = torch_mlir.ModuleBuilder()
 class TestModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        # CHECK: %[[L2:.*]] = basicpy.build_list
-        # CHECK: %[[L1:.*]] = basicpy.build_list
+        # CHECK: %[[L2:.*]] = torch.prim.ListConstruct
+        # CHECK: %[[L1:.*]] = torch.prim.ListConstruct
         # CHECK: torch.nn_module {
         # CHECK:   torch.slot "l2", %[[L2]]
         # CHECK:   torch.slot "l1", %[[L1]]
diff --git a/frontends/pytorch/test/node_import/list.py b/frontends/pytorch/test/node_import/list.py
index fb072f572..9528be7ec 100644
--- a/frontends/pytorch/test/node_import/list.py
+++ b/frontends/pytorch/test/node_import/list.py
@@ -11,9 +11,9 @@ mb = torch_mlir.ModuleBuilder()
 
 # CHECK-LABEL: func @__torch__.f(
 # CHECK-SAME: %[[T0:.*]]: !torch.tensor,
-# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !basicpy.ListType {
-# CHECK: %[[RET:.*]] = basicpy.build_list %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !basicpy.ListType
-# CHECK: return %[[RET]] : !basicpy.ListType
+# CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> !torch.list<!torch.tensor> {
+# CHECK: %[[RET:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]] : (!torch.tensor, !torch.tensor) -> !torch.list<!torch.tensor>
+# CHECK: return %[[RET]] : !torch.list<!torch.tensor>
 
 @mb.import_function
 @torch.jit.script
diff --git a/frontends/pytorch/test/node_import/prim.py b/frontends/pytorch/test/node_import/prim.py
index 74d87e975..4f6b62d67 100644
--- a/frontends/pytorch/test/node_import/prim.py
+++ b/frontends/pytorch/test/node_import/prim.py
@@ -84,8 +84,8 @@ def prim_TupleIndex(tup: typing.Tuple[int, int]):
     return tup[0]
 
 # CHECK-LABEL: func @__torch__.prim_ListUnpack(
-# CHECK-SAME: %[[ARG:.*]]: !basicpy.ListType) -> i64 {
-# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !basicpy.ListType -> i64, i64
+# CHECK-SAME: %[[ARG:.*]]: !torch.list<i64>) -> i64 {
+# CHECK: %[[RET:.*]]:3 = torch.prim.ListUnpack %[[ARG]] : !torch.list<i64> -> i64, i64
 # CHECK: return %[[RET]]#1 : i64
 @mb.import_function
 @torch.jit.script
@@ -122,11 +122,11 @@ def prim_device(x):
 
 # CHECK-LABEL: func @__torch__.prim_min(
 # CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
-# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
-# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !basicpy.ListType -> i64
+# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
+# CHECK: %[[MIN1:.*]] = torch.prim.min.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
 # CHECK: %[[MIN2:.*]] = torch.prim.min.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
-# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
-# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
+# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
+# CHECK: %[[MIN3:.*]] = torch.prim.min.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
 # CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MIN1]], %[[MIN2]], %[[MIN3]] : (i64, i64, i64) -> !basicpy.TupleType
 # CHECK: return %[[RET]] : !basicpy.TupleType
 @mb.import_function
@@ -136,11 +136,11 @@ def prim_min(x: int):
 
 # CHECK-LABEL: func @__torch__.prim_max(
 # CHECK-SAME: %[[ARG:.*]]: i64) -> !basicpy.TupleType {
-# CHECK: %[[SINGLETON:.*]] = basicpy.build_list %[[ARG]] : (i64) -> !basicpy.ListType
-# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !basicpy.ListType -> i64
+# CHECK: %[[SINGLETON:.*]] = torch.prim.ListConstruct %[[ARG]] : (i64) -> !torch.list<i64>
+# CHECK: %[[MAX1:.*]] = torch.prim.max.self_int %[[SINGLETON]] : !torch.list<i64> -> i64
 # CHECK: %[[MAX2:.*]] = torch.prim.max.int %[[ARG]], %[[ARG]] : i64, i64 -> i64
-# CHECK: %[[ARG_3_TIMES:.*]] = basicpy.build_list %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !basicpy.ListType
-# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !basicpy.ListType -> i64
+# CHECK: %[[ARG_3_TIMES:.*]] = torch.prim.ListConstruct %[[ARG]], %[[ARG]], %[[ARG]] : (i64, i64, i64) -> !torch.list<i64>
+# CHECK: %[[MAX3:.*]] = torch.prim.max.self_int %[[ARG_3_TIMES]] : !torch.list<i64> -> i64
 # CHECK: %[[RET:.*]] = basicpy.build_tuple %[[MAX1]], %[[MAX2]], %[[MAX3]] : (i64, i64, i64) -> !basicpy.TupleType
 # CHECK: return %[[RET]] : !basicpy.TupleType
 @mb.import_function
diff --git a/include/npcomp-c/Types.h b/include/npcomp-c/Types.h
index 72ed03c1d..f69eb2150 100644
--- a/include/npcomp-c/Types.h
+++ b/include/npcomp-c/Types.h
@@ -62,10 +62,10 @@ MlirType npcompDictTypeGet(MlirContext context);
 /*============================================================================*/
 
 /** Checks whether the given type is the Python "list" type. */
-int npcompTypeIsAList(MlirType t);
+int npcompTypeIsABasicpyList(MlirType t);
 
 /** Gets the generic Python "list" type. */
-MlirType npcompListTypeGet(MlirContext context);
+MlirType npcompBasicpyListTypeGet(MlirContext context);
 
 /*============================================================================*/
 /* NDArray type.                                                              */
@@ -137,6 +137,16 @@ int npcompTypeIsAOptional(MlirType t);
 /** Gets the !torch.optional<T> type with subtype T. */
 MlirType npcompOptionalTypeGet(MlirType containedType);
 
+/*============================================================================*/
+/* torch.list type.                                                           */
+/*============================================================================*/
+
+/** Checks whether the given type is a !torch.list<T> type */
+int npcompTypeIsAList(MlirType t);
+
+/** Gets the !torch.list<T> type with contained T. */
+MlirType npcompListTypeGet(MlirType containedType);
+
 /*============================================================================*/
 /* torch.Device type.                                                         */
 /*============================================================================*/
diff --git a/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td b/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td
index 7e63ea40d..f6f559f7f 100644
--- a/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -382,7 +382,7 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
   ]> {
   let summary = "Generated op for `aten::len.t : (t[]) -> (int)`";
   let arguments = (ins
-    Basicpy_ListType:$a
+    AnyTorchListType:$a
   );
   let results = (outs
     AnyTorchIntType:$result
@@ -392,3 +392,33 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
+
+def Torch_Aten__Getitem__TOp : Torch_Op<"aten.__getitem__.t", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::__getitem__.t : (t[], int) -> (t)`";
+  let arguments = (ins
+    AnyTorchListType:$list,
+    AnyTorchIntType:$idx
+  );
+  let results = (outs
+    AnyTorchType:$result
+  );
+  let assemblyFormat = "$list `,` $idx attr-dict `:` type($list) `,` type($idx) `->` type($result)";
+  let hasCanonicalizer = 1;
+}
+
+def Torch_Aten_SetItemTOp : Torch_Op<"aten._set_item.t", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::_set_item.t : (t[], int, t) -> (t[])`";
+  let arguments = (ins
+    AnyTorchListType:$l,
+    AnyTorchIntType:$idx,
+    AnyTorchType:$el
+  );
+  let results = (outs
+    AnyTorchListType:$result
+  );
+  let assemblyFormat = "$l `,` $idx `,` $el attr-dict `:` type($l) `,` type($idx) `,` type($el) `->` type($result)";
+}
+
diff --git a/include/npcomp/Dialect/Torch/IR/TorchOps.td b/include/npcomp/Dialect/Torch/IR/TorchOps.td
index 05b97008f..3665dd941 100644
--- a/include/npcomp/Dialect/Torch/IR/TorchOps.td
+++ b/include/npcomp/Dialect/Torch/IR/TorchOps.td
@@ -322,6 +322,26 @@ def Torch_PrimListUnpackOp: Torch_Op<"prim.ListUnpack",
   }];
 }
 
+def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
+    NoSideEffect,
+    AllowsTypeRefinement,
+  ]> {
+  let summary = "TorchScript prim::ListConstruct op";
+
+  let arguments = (ins
+    Variadic<AnyTorchType>:$elements
+  );
+  let results = (outs
+    AnyTorchListType:$result
+  );
+
+  let verifier = "return ::verify(*this);";
+
+  let assemblyFormat = [{
+    $elements attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def Torch_PrimGetAttrOp : Torch_Op<"prim.GetAttr", []> {
   let summary = "TorchScript prim::GetAttr op";
diff --git a/include/npcomp/Dialect/Torch/IR/TorchTypes.td b/include/npcomp/Dialect/Torch/IR/TorchTypes.td
index 7d2fe60b3..60c759d56 100644
--- a/include/npcomp/Dialect/Torch/IR/TorchTypes.td
+++ b/include/npcomp/Dialect/Torch/IR/TorchTypes.td
@@ -21,6 +21,31 @@ class Torch_Type<string name, string typeMnemonic>
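+// Common machinery for parameterized types that wrap a single contained
+// type (currently !torch.optional<T> and !torch.list<T>): one Type
+// parameter plus a shared parser, printer, and context-inferring builder.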
+class Torch_TypeWithContainedType<string name, string typeMnemonic>
+    : Torch_Type<name, typeMnemonic> {
+  let parameters = (ins "::mlir::Type":$containedType);
+
+  let printer = [{
+    $_printer << getMnemonic() << "<" << getImpl()->containedType << ">";
+  }];
+
+  let parser = [{
+    if (parser.parseLess())
+      return Type();
+    Type containedType;
+    if ($_parser.parseType(containedType))
+      return Type();
+    if ($_parser.parseGreater())
+      return Type();
+    return get($_ctxt, containedType);
+  }];
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
+      return Base::get(containedType.getContext(), containedType);
+    }]>
+  ];
+}
+
 def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
   let summary = "torch.nn.Module";
   let description = [{
@@ -195,32 +220,16 @@ def AnyTorchTensorType : Type<
 // For now, we only need it as a stand-in type to allow importing
 // the `_is_full_backward_hook` optional bool type that Torch puts on
 // all classes.
-def Torch_OptionalType : Torch_Type<"Optional", "optional"> {
+def Torch_OptionalType : Torch_TypeWithContainedType<"Optional", "optional"> {
   let summary = "!torch.optional";
   let description = [{
   }];
-  let parameters = (ins "::mlir::Type":$containedType);
+}
 
-  let printer = [{
-    $_printer << "optional<" << getImpl()->containedType << ">";
+def Torch_ListType : Torch_TypeWithContainedType<"List", "list"> {
+  let summary = "!torch.list";
+  let description = [{
   }];
-
-  let parser = [{
-    if (parser.parseLess())
-      return Type();
-    Type containedType;
-    if ($_parser.parseType(containedType))
-      return Type();
-    if ($_parser.parseGreater())
-      return Type();
-    return get($_ctxt, containedType);
-  }];
-
-  let builders = [
-    TypeBuilderWithInferredContext<(ins "::mlir::Type":$containedType), [{
-      return Base::get(containedType.getContext(), containedType);
-    }]>
-  ];
 }
 
 def Torch_DeviceType : Torch_Type<"Device", "Device"> {
@@ -279,6 +288,13 @@ def AnyTorchScalarType : AnyTypeOf<[
   AnySignlessInteger,
 ], "Any primitive type suitable to be passed as a Torch Scalar">;
 
+def IsListTypePred : CPred<"$_self.isa<::mlir::NPCOMP::Torch::ListType>()">;
+
+class ListOf<list<Type> allowedTypes, string descr> :
+    ContainerType<AnyTypeOf<allowedTypes>, IsListTypePred,
+                  "$_self.cast<::mlir::NPCOMP::Torch::ListType>().getContainedType()",
+                  descr, "::mlir::NPCOMP::Torch::ListType">;
+
 def AnyTorchNumberType : AnyTypeOf<[
   AnySignedInteger,
   AnyFloat,
@@ -293,33 +309,29 @@ def AnyTorchBoolType : AnyTypeOf<[
   Basicpy_BoolType,
 ], "Any permissible bool type">;
 
-def AnyTorchBoolListType : AnyTypeOf<[
-  Basicpy_ListType,
-  // TODO: Support typed list when available.
-], "Any bool list type (bool[])">;
+def AnyTorchBoolListType : ListOf<[AnyTorchBoolType], "Any bool list type (bool[])">;
 
 def AnyTorchIntType : AnyTypeOf<[
   AnySignedInteger,
   AnySignlessInteger,
 ], "Any primitive integer type suitable to be passed as a Torch 'int'">;
 
-def AnyTorchIntListType : AnyTypeOf<[
-  Basicpy_ListType,
-  // TODO: Support typed list when available.
-], "Any int list type (int[])">;
+def AnyTorchIntListType : ListOf<[AnyTorchIntType], "Any int list type (int[])">;
 
 def AnyTorchType : AnyTypeOf<[
   AnyTorchBoolType,
   AnyTorchScalarType,
   AnyTorchTensorType,
-  Basicpy_ListType,
   Basicpy_TupleType,
   Basicpy_NoneType,
   Basicpy_BytesType,
   Torch_NnModuleType,
   Torch_OptionalType,
+  Torch_ListType,
   Torch_DeviceType,
   Torch_LinearParamsType,
 ], "Any type that is legal to pass to a Torch kernel">;
 
+def AnyTorchListType : ListOf<[AnyType], "Any Torch list type">;
+
 #endif // TORCH_TYPES
diff --git a/lib/CAPI/Types.cpp b/lib/CAPI/Types.cpp
index 06d34055c..a16491b40 100644
--- a/lib/CAPI/Types.cpp
+++ b/lib/CAPI/Types.cpp
@@ -69,10 +69,10 @@ MlirType npcompDictTypeGet(MlirContext context) {
 /*============================================================================*/
 
 /** Checks whether the given type is the Python "list" type. */
-int npcompTypeIsAList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
+int npcompTypeIsABasicpyList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
 
 /** Gets the generic Python "dict" type. */
-MlirType npcompListTypeGet(MlirContext context) {
+MlirType npcompBasicpyListTypeGet(MlirContext context) {
   return wrap(Basicpy::ListType::get(unwrap(context)));
 }
@@ -175,6 +175,19 @@ MlirType npcompOptionalTypeGet(MlirType containedType) {
   return wrap(Torch::OptionalType::get(unwrap(containedType)));
 }
 
+/*============================================================================*/
+/* torch.list type.                                                           */
+/*============================================================================*/
+
+/** Checks whether the given type is a !torch.list<T> type */
+int npcompTypeIsAList(MlirType t) {
+  return unwrap(t).isa<Torch::ListType>();
+}
+
+/** Gets the !torch.list<T> type with contained T. */
+MlirType npcompListTypeGet(MlirType containedType) {
+  return wrap(Torch::ListType::get(unwrap(containedType)));
+}
+
 /*============================================================================*/
 /* torch.Device type.                                                         */
 /*============================================================================*/
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 316d508cc..50ea823d0 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
@@ -131,6 +132,21 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PrimListConstructOp
+//===----------------------------------------------------------------------===//
+
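+// Check that every operand has the same type kind (TypeID) as the contained
+// type of the !torch.list result, e.g. all-i64 operands for !torch.list<i64>.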
+static LogicalResult verify(PrimListConstructOp op) {
+  auto resultType = op.getResult().getType();
+  auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
+  auto matchResultElementType = [&](Type type) {
+    return type.getTypeID() == resultElementType.getTypeID();
+  };
+  if (llvm::all_of(op->getOperandTypes(), matchResultElementType))
+    return success();
+  else
+    return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // ClassTypeOp
 //===----------------------------------------------------------------------===//
@@ -246,9 +262,9 @@ OpFoldResult AtenDimOp::fold(ArrayRef<Attribute> operands) {
 
 OpFoldResult AtenLenTOp::fold(ArrayRef<Attribute> operands) {
   // `len([1,1,1])` -> `3`
-  if (auto buildList = getOperand().getDefiningOp<Basicpy::BuildListOp>()) {
+  if (auto listConstruct = getOperand().getDefiningOp<PrimListConstructOp>()) {
     return IntegerAttr::get(IntegerType::get(getContext(), 64),
-                            buildList.getNumOperands());
+                            listConstruct.getNumOperands());
   }
   return nullptr;
 }
@@ -288,8 +304,8 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
       listElements.push_back(rewriter.create<::mlir::ConstantOp>(
           op->getLoc(), rewriter.getI64IntegerAttr(size)));
     }
-    rewriter.replaceOpWithNewOp<Basicpy::BuildListOp>(
-        op, Basicpy::ListType::get(rewriter.getContext()), listElements);
+    rewriter.replaceOpWithNewOp<PrimListConstructOp>(
+        op, Torch::ListType::get(rewriter.getI64Type()), listElements);
     return success();
   });
   // One-off pattern to erase if dead.
@@ -425,5 +441,30 @@ LogicalResult FromBuiltinTensorOp::inferReturnTypes(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Aten__Getitem__TOp
+//===----------------------------------------------------------------------===//
+
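+// Replace `__getitem__` of a torch.prim.ListConstruct with a constant index
+// by the corresponding list element. The one-use check is what makes this
+// sound: lists are mutable (e.g. via `aten._set_item.t`), so the fold is
+// only safe when this op is the list's sole user. The constant index is
+// assumed to be in bounds.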
+void Aten__Getitem__TOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                     MLIRContext *context) {
+  patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
+    auto torchList = op.getOperand(0);
+    if (!torchList.hasOneUse())
+      return failure();
+
+    auto listConstruct = torchList.getDefiningOp<PrimListConstructOp>();
+    if (!listConstruct)
+      return failure();
+
+    APInt indexAP;
+    if (!matchPattern(op.getOperand(1), m_ConstantInt(&indexAP)))
+      return failure();
+
+    auto index = indexAP.getSExtValue();
+    rewriter.replaceOp(op, {listConstruct.getOperand(index)});
+    return success();
+  });
+}
+
 #define GET_OP_CLASSES
 #include "npcomp/Dialect/Torch/IR/TorchOps.cpp.inc"
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir
index 5a97cd3c2..204eb76fc 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/initializers.mlir
@@ -2,20 +2,20 @@
 
 // CHECK that multiple nested initialization ops are properly handled.
 
-// CHECK-LABEL: torch.global_slot @l : !basicpy.ListType {
-// CHECK: %[[L0:.*]] = basicpy.build_list : () -> !basicpy.ListType
-// CHECK: %[[L1:.*]] = basicpy.build_list %[[L0]], %[[L0]] : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
-// CHECK: %[[L2:.*]] = basicpy.build_list %[[L1]], %[[L1]] : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
-// CHECK: torch.global_slot.init %[[L2]] : !basicpy.ListType
+// CHECK-LABEL: torch.global_slot @l : !torch.list<!torch.list<!torch.list<i64>>> {
+// CHECK: %[[L0:.*]] = torch.prim.ListConstruct : () -> !torch.list<i64>
+// CHECK: %[[L1:.*]] = torch.prim.ListConstruct %[[L0]], %[[L0]] : (!torch.list<i64>, !torch.list<i64>) -> !torch.list<!torch.list<i64>>
+// CHECK: %[[L2:.*]] = torch.prim.ListConstruct %[[L1]], %[[L1]] : (!torch.list<!torch.list<i64>>, !torch.list<!torch.list<i64>>) -> !torch.list<!torch.list<!torch.list<i64>>>
+// CHECK: torch.global_slot.init %[[L2]] : !torch.list<!torch.list<!torch.list<i64>>>
 // CHECK: }
 
 torch.class_type @c {
-  torch.attr "l" : !basicpy.ListType
+  torch.attr "l" : !torch.list<!torch.list<!torch.list<i64>>>
 }
 
-%l0 = basicpy.build_list : () -> !basicpy.ListType
-%l1 = basicpy.build_list %l0, %l0 : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
-%l2 = basicpy.build_list %l1, %l1 : (!basicpy.ListType, !basicpy.ListType) -> !basicpy.ListType
+%l0 = torch.prim.ListConstruct : () -> !torch.list<i64>
+%l1 = torch.prim.ListConstruct %l0, %l0 : (!torch.list<i64>, !torch.list<i64>) -> !torch.list<!torch.list<i64>>
+%l2 = torch.prim.ListConstruct %l1, %l1 : (!torch.list<!torch.list<i64>>, !torch.list<!torch.list<i64>>) -> !torch.list<!torch.list<!torch.list<i64>>>
 torch.nn_module {
-  torch.slot "l", %l2 : !basicpy.ListType
+  torch.slot "l", %l2 : !torch.list<!torch.list<!torch.list<i64>>>
 } : !torch.nn.Module<"c">
diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/module-uses-error.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/module-uses-error.mlir
index a78a475ac..4ee4b9f96 100644
--- a/test/Dialect/Torch/GlobalizeObjectGraph/module-uses-error.mlir
+++ b/test/Dialect/Torch/GlobalizeObjectGraph/module-uses-error.mlir
@@ -6,7 +6,7 @@ torch.class_type @parent {
 
 func private @module_type_return(%arg0: !torch.nn.Module<"parent">) {
   // expected-error @+1 {{unsupported use of a torch.nn.Module. Expected only method calls or attribute get/set}}
-  basicpy.build_list %arg0 : (!torch.nn.Module<"parent">) -> !basicpy.ListType
+  torch.prim.ListConstruct %arg0 : (!torch.nn.Module<"parent">) -> !torch.list<!torch.nn.Module<"parent">>
   return
 }
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index e9cf5bb0e..f676fb672 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -3,31 +3,31 @@
 // CHECK-LABEL: func @torch.aten.__is__
 // CHECK: %[[FALSE:.*]] = basicpy.bool_constant false
 // CHECK: return %[[FALSE]] : !basicpy.BoolType
-func @torch.aten.__is__(%arg0: !basicpy.ListType, %arg1: !basicpy.NoneType) -> !basicpy.BoolType{
-  %0 = torch.aten.__is__ %arg0, %arg1 : !basicpy.ListType, !basicpy.NoneType -> !basicpy.BoolType
+func @torch.aten.__is__(%arg0: !torch.list<i64>, %arg1: !basicpy.NoneType) -> !basicpy.BoolType{
+  %0 = torch.aten.__is__ %arg0, %arg1 : !torch.list<i64>, !basicpy.NoneType -> !basicpy.BoolType
   return %0 : !basicpy.BoolType
 }
 
 // CHECK-LABEL: func @torch.aten.size$canonicalize_to_list(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !basicpy.ListType {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list<i64> {
 // CHECK: %[[C2:.*]] = constant 2 : i64
 // CHECK: %[[C3:.*]] = constant 3 : i64
-// CHECK: %[[LIST:.*]] = basicpy.build_list %[[C2]], %[[C3]] : (i64, i64) -> !basicpy.ListType
-// CHECK: return %[[LIST]] : !basicpy.ListType
-func @torch.aten.size$canonicalize_to_list(%arg0: !torch.vtensor<[2,3],f32>) -> !basicpy.ListType {
-  %0 = torch.aten.size %arg0 : !torch.vtensor<[2,3],f32> -> !basicpy.ListType
-  return %0 : !basicpy.ListType
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C2]], %[[C3]] : (i64, i64) -> !torch.list<i64>
+// CHECK: return %[[LIST]] : !torch.list<i64>
+func @torch.aten.size$canonicalize_to_list(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.list<i64> {
+  %0 = torch.aten.size %arg0 : !torch.vtensor<[2,3],f32> -> !torch.list<i64>
+  return %0 : !torch.list<i64>
 }
 
 // One size unknown, so cannot canonicalize.
 // TODO: For unknown sizes, insert the equivalent of a "dim" op.
 // Then this will only require static rank.
 // CHECK-LABEL: func @torch.aten.size$unknown_size(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3],f32>) -> !basicpy.ListType {
-// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor<[?,3],f32> -> !basicpy.ListType
-func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !basicpy.ListType {
-  %0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !basicpy.ListType
-  return %0 : !basicpy.ListType
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<i64> {
+// CHECK: %[[SIZE:.*]] = torch.aten.size %[[ARG]] : !torch.vtensor<[?,3],f32> -> !torch.list<i64>
+func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<i64> {
+  %0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<i64>
+  return %0 : !torch.list<i64>
 }
 
 // CHECK-LABEL: func @torch.aten.len.t$of_size(
@@ -35,8 +35,8 @@ func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !basicpy
 // CHECK: %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> i64
 // CHECK: return %[[DIM]] : i64
 func @torch.aten.len.t$of_size(%arg0: !torch.vtensor<*,f32>) -> i64 {
-  %0 = torch.aten.size %arg0 : !torch.vtensor<*,f32> -> !basicpy.ListType
-  %1 = torch.aten.len.t %0 : !basicpy.ListType -> i64
+  %0 = torch.aten.size %arg0 : !torch.vtensor<*,f32> -> !torch.list<i64>
+  %1 = torch.aten.len.t %0 : !torch.list<i64> -> i64
   return %1 : i64
 }
 
@@ -54,8 +54,8 @@ func @torch.aten.dim$with_shape(%arg0: !torch.vtensor<[?,?,?],f32>) -> i64 {
 // CHECK: %[[LEN:.*]] = constant 4 : i64
 // CHECK: return %[[LEN]] : i64
 func @torch.aten.len.t$of_build_list(%arg0: i64) -> i64 {
-  %0 = basicpy.build_list %arg0, %arg0, %arg0, %arg0 : (i64, i64, i64, i64) -> !basicpy.ListType
-  %1 = torch.aten.len.t %0 : !basicpy.ListType -> i64
+  %0 = torch.prim.ListConstruct %arg0, %arg0, %arg0, %arg0 : (i64, i64, i64, i64) -> !torch.list<i64>
+  %1 = torch.aten.len.t %0 : !torch.list<i64> -> i64
   return %1 : i64
 }
 
@@ -75,3 +75,41 @@ func @torch.copy.tensor$unnecessary_intermediate_nonval_tensor(%arg0: !torch.vte
   %1 = torch.copy.tensor %0 : !torch.tensor -> !torch.vtensor
   return %1 : !torch.vtensor
 }
+
+// CHECK-LABEL: func @torch.aten.__getitem__.t(
+// CHECK: %[[C5:.*]] = constant 5 : i64
+// CHECK: return %[[C5]] : i64
+func @torch.aten.__getitem__.t() -> i64 {
+  %c4_i64 = constant 4 : i64
+  %c5_i64 = constant 5 : i64
+  %c1_i64 = constant 1 : i64
+  %0 = torch.prim.ListConstruct %c4_i64, %c5_i64 : (i64, i64) -> !torch.list<i64>
+  %1 = torch.aten.__getitem__.t %0, %c1_i64 : !torch.list<i64>, i64 -> i64
+  return %1 : i64
+}
+
+// Not canonicalized because the index is not a compile-time constant.
+// CHECK-LABEL: func @torch.aten.__getitem__.t$no_change_test0(
+// CHECK: %[[C4:.*]] = constant 4 : i64
+// CHECK: %[[C5:.*]] = constant 5 : i64
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C4]], %[[C5]] : (i64, i64) -> !torch.list<i64>
+// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %[[LIST]], %arg0 : !torch.list<i64>, i64 -> i64
+// CHECK: return %[[ITEM]] : i64
+func @torch.aten.__getitem__.t$no_change_test0(%arg0: i64) -> i64 {
+  %c4_i64 = constant 4 : i64
+  %c5_i64 = constant 5 : i64
+  %0 = torch.prim.ListConstruct %c4_i64, %c5_i64 : (i64, i64) -> !torch.list<i64>
+  %1 = torch.aten.__getitem__.t %0, %arg0 : !torch.list<i64>, i64 -> i64
+  return %1 : i64
+}
+
+// Not canonicalized because the list is a block argument, not a visible
+// torch.prim.ListConstruct.
+// CHECK-LABEL: func @torch.aten.__getitem__.t$no_change_test1(
+// CHECK: %[[C5:.*]] = constant 5 : i64
+// CHECK: %[[ITEM:.*]] = torch.aten.__getitem__.t %arg0, %[[C5]] : !torch.list<i64>, i64 -> i64
+// CHECK: return %[[ITEM]] : i64
+func @torch.aten.__getitem__.t$no_change_test1(%arg0: !torch.list<i64>) -> i64 {
+  %c5_i64 = constant 5 : i64
+  %0 = torch.aten.__getitem__.t %arg0, %c5_i64 : !torch.list<i64>, i64 -> i64
+  return %0 : i64
+}
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index e203cee22..c4bdfc68b 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -72,10 +72,10 @@ func @f(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg
 func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!torch.vtensor {
   %c0_i64 = constant 0 : i64
   %c1_i64 = constant 1 : i64
-  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
-  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
-  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
-  %3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->!torch.vtensor
+  %0 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
+  %1 = torch.prim.ListConstruct %c0_i64, %c0_i64 : (i64, i64) -> !torch.list<i64>
+  %2 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
+  %3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor, !torch.vtensor, !torch.vtensor, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, i64 ->!torch.vtensor
   return %3 :!torch.vtensor
 }
 
@@ -86,10 +86,10 @@ func @f(%arg0:!torch.vtensor, %arg1:!torch.vtensor, %arg2:!torch.vtensor) ->!tor
 func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:!torch.vtensor<*,f32>) ->!torch.vtensor {
   %c0_i64 = constant 0 : i64
   %c1_i64 = constant 1 : i64
-  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
-  %1 = basicpy.build_list %c0_i64, %c0_i64 : (i64, i64) -> !basicpy.ListType
-  %2 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
-  %3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, i64 ->!torch.vtensor
+  %0 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
+  %1 = torch.prim.ListConstruct %c0_i64, %c0_i64 : (i64, i64) -> !torch.list<i64>
+  %2 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
+  %3 = torch.aten.conv2d %arg0, %arg1, %arg2, %0, %1, %2, %c1_i64 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.vtensor<*,f32>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, i64 ->!torch.vtensor
   return %3 :!torch.vtensor
 }
 
@@ -101,12 +101,12 @@ func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
   %c3_i64 = constant 3 : i64
   %c2_i64 = constant 2 : i64
   %bool_false = basicpy.bool_constant false
-  %21 = basicpy.build_list %c3_i64, %c3_i64 : (i64, i64) -> !basicpy.ListType
-  %22 = basicpy.build_list %c2_i64, %c2_i64 : (i64, i64) -> !basicpy.ListType
-  %23 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
-  %24 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %21 = torch.prim.ListConstruct %c3_i64, %c3_i64 : (i64, i64) -> !torch.list<i64>
+  %22 = torch.prim.ListConstruct %c2_i64, %c2_i64 : (i64, i64) -> !torch.list<i64>
+  %23 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
+  %24 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
   // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
-  %27 = torch.aten.max_pool2d %arg0, %21, %22, %23, %24, %bool_false : !torch.vtensor<[?,?,?,?],f32>, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.ListType, !basicpy.BoolType -> !torch.vtensor
+  %27 = torch.aten.max_pool2d %arg0, %21, %22, %23, %24, %bool_false : !torch.vtensor<[?,?,?,?],f32>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, !torch.list<i64>, !basicpy.BoolType -> !torch.vtensor
   return %27 : !torch.vtensor
 }
 
@@ -115,9 +115,9 @@ func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
 // CHECK-LABEL: func @f
 func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
   %c1_i64 = constant 1 : i64
-  %0 = basicpy.build_list %c1_i64, %c1_i64 : (i64, i64) -> !basicpy.ListType
+  %0 = torch.prim.ListConstruct %c1_i64, %c1_i64 : (i64, i64) -> !torch.list<i64>
   // CHECK: torch.aten.adaptive_avg_pool2d{{.*}} -> !torch.vtensor<[?,?,?,?],f32>
-  %1 = torch.aten.adaptive_avg_pool2d %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !basicpy.ListType -> !torch.vtensor
+  %1 = torch.aten.adaptive_avg_pool2d %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<i64> -> !torch.vtensor
   return %1 : !torch.vtensor
 }