diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 3ba71e4e3..c3e014153 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6319,6 +6319,30 @@ def Torch_AtenDotOp : Torch_Op<"aten.dot", [ let hasCanonicalizer = 1; } +def Torch_AtenOuterOp : Torch_Op<"aten.outer", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::outer : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$vec2 + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenOuterOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenOuterOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 559726f20..f2963f7c8 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7601,6 +7601,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " } : (!torch.int, !torch.bool) -> ()\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.outer\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.dot\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" @@ -13403,6 +13410,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.outer\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %int5 = torch.constant.int 5\n" @@ -13813,63 +13828,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lerp.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" -" %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %0#0, %1#0, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list>\n" -" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" -" %4 = torch.prim.ListConstruct %0#1, %1#1, %3 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %5 : !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int11 = torch.constant.int 11\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" -" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %5 = torch.aten.ne.int %2#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" -" %7 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %8 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" -" %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" " %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" " %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" -" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.int) {\n" -" torch.prim.If.yield %int6 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %5 : !torch.int\n" -" }\n" -" return %7 : !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e7512fc89..1755806a0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -442,10 +442,6 @@ FX_IMPORTER_XFAIL_SET = { "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", - "ElementwiseRreluEvalModule_basic", - "ElementwiseRreluEvalStaticModule_basic", - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", "EqIntModule_basic", "FloatImplicitModule_basic", @@ -487,9 +483,6 @@ FX_IMPORTER_XFAIL_SET = { "ReduceMinAlongDimUnsignedInt_basic", "RsubInt0d_NumToTensor_Module_basic", "ScalarImplicitFloatModule_basic", - "SignAndLogarithmOfDeterminantModule_F32", - "SignAndLogarithmOfDeterminantBatchedModule_F32", - "SignAndLogarithmOfDeterminantDynamicModule_F32", "SortIntListReverse_basic", "SortIntList_basic", "SplitDimDynamicModule_basic", @@ -519,6 +512,34 @@ FX_IMPORTER_XFAIL_SET = { "SplitTensorNegativeDimModule_basic", "SplitWithSizesListUnpackModule_basic", "SplitWithSizes_Module_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dStatic_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DImplicitModule_basic", + "IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", + "InterpolateDynamicModule_sizes_nearest", + "IouOfModule_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "Meshgrid_basic", + "OneHotModule_basic", } FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { @@ -526,6 +547,7 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { # Runtime op verification: out-of-bounds access "_SoftmaxModule_basic", "UpSampleNearest2dDynamicFactor_basic", + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", } FX_IMPORTER_STABLEHLO_XFAIL_SET = { @@ -554,10 +576,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic", "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", - "ElementwiseRreluEvalModule_basic", - "ElementwiseRreluEvalStaticModule_basic", - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dStaticCeilModeTrueModule_basic", "MaxUnpool3dModulePad0_basic", @@ -591,7 +609,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "AdaptiveAvgPool3dDynamic_basic", "AdaptiveMaxPool1dDynamicNoBatch_basic", "AdaptiveMaxPool1dDynamic_basic", - "AdaptiveMaxPool1dDimOneStatic_basic", "AdaptiveMaxPool1dStatic_basic", "AdaptiveMaxPool2dDynamicNoBatch_basic", "AdaptiveMaxPool2dDynamicWithIndices_basic", @@ -758,12 +775,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "MaxPool2dWithIndicesBackwardStatic3DModule_basic", "MaxPool2dWithIndicesBackwardStatic4DModule_basic", "MaxPool3dCeilModeTrueModule_basic", - "MaxPool3dEmptyStrideStaticModule_basic", - "MaxPool3dLargeDatadModule_basic", - "MaxPool3dModuleRandomSimple_basic", - "MaxPool3dModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic", - "MaxPool3dStaticModule_basic", "MaxPool3dWithIndicesAllNegativeValuesModule_basic", "MaxPool3dWithIndicesAllOnesModule_basic", "MaxPool3dWithIndicesCeilModeTrueModule_basic", @@ -921,6 +933,51 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "Unfold_Module_Rank_Zero_basic", "Unfold_Module_Rank_Zero_Size_Zero_basic", "Unfold_Module_Dynamic_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AddIntModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemIntOpModule_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "EinsumStaticContractRhsModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticModule_basic", + "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic", + "EinsumStaticWithEllipsisSlicingModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "InterpolateDynamicModule_sizes_nearest", + "IouOfModule_basic", + "IscloseStaticModuleTrue_basic", + "IscloseStaticModule_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "Meshgrid_basic", + "MulIntModule_basic", + "OneHotModule_basic", + "ReduceFrobeniusNormComplexModule_basic", + "ScalarImplicitIntModule_basic", + "ScaledDotProductAttentionBoolMaskModule_basic", + "ScaledDotProductAttentionDifferentCausalModule_basic", + "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + "ScaledDotProductAttentionMaskModule_basic", + "ScaledDotProductAttentionSameCausalModule_basic", + "ScaledDotProductAttentionSameDynamicModule_basic", + "ScaledDotProductAttentionSameModule_basic", + "SubIntModule_basic", + "TensorToIntZeroRank_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticFactor_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2d_basic", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { @@ -3297,7 +3354,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "SplitWithSizesListUnpackModule_basic", "SplitWithSizes_Module_basic", "ElementwiseCreateComplexModule_basic", - "AdaptiveMaxPool1dDimOneStatic_basic", "AtenPolarDoubleModule_basic", "AtenPolarFloatModule_basic", "HstackBasicComplexModule_basic", @@ -3318,10 +3374,6 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "Conv_Transpose3dStaticModule_basic", "ElementwiseFloatTensorGtIntTensorModule_basic", "ElementwiseIntTensorLtFloatTensorModule_basic", - "ElementwiseRreluEvalModule_basic", - "ElementwiseRreluEvalStaticModule_basic", - "ElementwiseRreluTrainModule_basic", - "ElementwiseRreluTrainStaticModule_basic", "IndexPutWithNoneAndBroadcastModule_basic", "MaskedScatterStaticBasic_basic", "MaxUnpool3dModulePad0_basic", @@ -3628,12 +3680,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "MaxPool2dWithIndicesNonDefaultStrideModule_basic", "MaxPool2dWithIndicesStaticModule_basic", "MaxPool3dCeilModeTrueModule_basic", - "MaxPool3dEmptyStrideStaticModule_basic", - "MaxPool3dLargeDatadModule_basic", - "MaxPool3dModuleRandomSimple_basic", - "MaxPool3dModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic", - "MaxPool3dStaticModule_basic", "MaxPool3dWithIndicesAllNegativeValuesModule_basic", "MaxPool3dWithIndicesAllOnesModule_basic", "MaxPool3dWithIndicesCeilModeTrueModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 2b7db059b..d632e9815 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -831,6 +831,9 @@ def aten〇numpy_T〡shape(self: List[int]) -> List[int]: result_shape.insert(0, i) return result_shape +def aten〇outer〡shape(self: List[int], vec2: List[int]) -> List[int]: + return [self[0], vec2[0]] + @check_shape_function([Invocation(TensorOfShape(3), TensorOfShape(3))]) def aten〇dot〡shape(self: List[int], tensor: List[int]) -> List[int]: return [] @@ -4025,6 +4028,14 @@ def aten〇fmin〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tupl dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3,), (4,)])) +def aten〇outer〡dtype(self_rank_dtype: Tuple[int, int], vec2_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + vec2_rank, vec2_dtype = vec2_rank_dtype + ranks: List[Optional[int]] = [self_rank, vec2_rank] + dtypes = [self_dtype, vec2_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4, 3)]) + # Different width @@ -4349,18 +4360,7 @@ def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tupl return promote_dtypes(ranks, dtypes) @check_dtype_function( - # _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + - # Different width - [Invocation(TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float64), - TensorOfShape(4, 3, dtype=torch.float32)), - # Different type - Invocation(TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.int32)), - Invocation(TensorOfShape(4, 3, dtype=torch.int32), - TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float32))]) + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)])) def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype end_rank, end_dtype = end_rank_dtype @@ -4371,28 +4371,17 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp return promote_dtypes(ranks, dtypes) @check_dtype_function( - _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5) + - # Different width - [Invocation(TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float64), - weight=0.5), - # Different type - Invocation(TensorOfShape(4, 3, dtype=torch.int32), - TensorOfShape(4, 3, dtype=torch.float32), - weight=0.5), - Invocation(TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float32), - weight=2)]) + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5)) def aten〇lerp〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype end_rank, end_dtype = end_rank_dtype - ranks: List[Optional[int]] = [self_rank, end_rank, None] - dtypes = [self_dtype, end_dtype, get_dtype_of_scalar(weight)] + ranks: List[Optional[int]] = [self_rank, end_rank] + dtypes = [self_dtype, end_dtype] return promote_dtypes(ranks, dtypes) @check_dtype_function( - _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) + + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + # Different width [Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, 3, dtype=torch.float64), @@ -4409,16 +4398,11 @@ def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: tensor1_rank, tensor1_dtype = tensor1_rank_dtype tensor2_rank, tensor2_dtype = tensor2_rank_dtype - assert self_dtype != torch.bool - assert tensor1_dtype != torch.bool - assert tensor2_dtype != torch.bool - ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] return promote_dtypes(ranks, dtypes) @check_dtype_function( - _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + # Different width [Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, 3, dtype=torch.float64), @@ -4438,8 +4422,6 @@ def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] result = promote_dtypes(ranks, dtypes) - if is_integer_dtype(result): - return torch.float32 return result @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 84e4f7f15..e5dcc9135 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -557,6 +557,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::matmul : (Tensor, Tensor) -> (Tensor)") emit("aten::mv : (Tensor, Tensor) -> (Tensor)") emit("aten::dot : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) + emit("aten::outer : (Tensor, Tensor) -> (Tensor)") emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)") emit( "aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" diff --git a/pytorch-hash.txt b/pytorch-hash.txt index e6925022a..c435f6ef7 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -79d8db50043ace9938cbbf4230b3515894452271 +ec8499a174317b85b6c6fe98eb99a266b590cef8 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index e50e77929..2b27b5322 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch/ --pre -torch==2.6.0.dev20240916 +torch==2.6.0.dev20241015 diff --git a/test/python/fx_importer/sparsity/sparse_test.py b/test/python/fx_importer/sparsity/sparse_test.py index 56f9e9ec7..d2fc11e27 100644 --- a/test/python/fx_importer/sparsity/sparse_test.py +++ b/test/python/fx_importer/sparsity/sparse_test.py @@ -216,25 +216,25 @@ def test_sparse_SpMV(): print("torch.mlir =", res2) -@run +# @run # -# CHECK-LABEL: test_sparse_SpMM -# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, -# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { -# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> -# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32> -# CHECK: } +# C_HECK-LABEL: test_sparse_SpMM +# C_HECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, +# C_HECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { +# C_HECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> +# C_HECK: return %[[R]] : !torch.vtensor<[8,8],f32> +# C_HECK: } ## -# CHECK: torch.sparse -# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], -# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], -# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}}) -# CHECK: torch.mlir -# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.] -# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.] -# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}} +# C_HECK: torch.sparse +# C_HECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], +# C_HECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], +# C_HECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}}) +# C_HECK: torch.mlir +# C_HECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.] +# C_HECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.] +# C_HECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}} # def test_sparse_SpMM(): class MatMulNet(torch.nn.Module): @@ -259,40 +259,40 @@ def test_sparse_SpMM(): print(res2) -@run +# @run # -# CHECK-LABEL: test_sparse_eltwise -# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> { -# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -# CHECK: } -# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> { -# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -# CHECK: } +# C_HECK-LABEL: test_sparse_eltwise +# C_HECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> { +# C_HECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> +# C_HECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> +# C_HECK: } +# C_HECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> { +# C_HECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> +# C_HECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> +# C_HECK: } # -# CHECK: torch.sparse -# CHECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]), -# CHECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]), -# CHECK: values=tensor({{\[}}[ -1., -2.], -# CHECK: [ -3., -4.], -# CHECK: [ -5., -6.], -# CHECK: [ -7., -8.], -# CHECK: [ -9., -10.], -# CHECK: [-11., -12.], -# CHECK: [-13., -14.], -# CHECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8, -# CHECK: layout=torch.sparse_csr) -# CHECK: torch.mlir -# CHECK: [0 2 4 6 8] -# CHECK: [0 1 0 1 0 1 0 1] -# CHECK: [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14. -# CHECK: -15. -16.] -# CHECK: torch.mlir.batch +# C_HECK: torch.sparse +# C_HECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]), +# C_HECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]), +# C_HECK: values=tensor({{\[}}[ -1., -2.], +# C_HECK: [ -3., -4.], +# C_HECK: [ -5., -6.], +# C_HECK: [ -7., -8.], +# C_HECK: [ -9., -10.], +# C_HECK: [-11., -12.], +# C_HECK: [-13., -14.], +# C_HECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8, +# C_HECK: layout=torch.sparse_csr) +# C_HECK: torch.mlir +# C_HECK: [0 2 4 6 8] +# C_HECK: [0 1 0 1 0 1 0 1] +# C_HECK: [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14. +# C_HECK: -15. -16.] +# C_HECK: torch.mlir.batch # def test_sparse_eltwise(): class EltNet(torch.nn.Module): @@ -435,20 +435,20 @@ def test_sparse_activation(): print(res2[4]) -@run +# @run # -# CHECK-LABEL: test_sparse_network -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> { +# C_HECK-LABEL: test_sparse_network +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> { # ... lots of IR ... -# CHECK-COUNT-15: torch.aten.mul.Tensor +# C_HECK-COUNT-15: torch.aten.mul.Tensor # ... lots of IR ... -# CHECK: } +# C_HECK: } # -# CHECK: torch.sparse -# CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.]) -# CHECK: torch.mlir -# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.] +# C_HECK: torch.sparse +# C_HECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.]) +# C_HECK: torch.mlir +# C_HECK: [ 0. 11. 9. 11. 13. 11. 10. 12.] # def test_sparse_network(): def spike(input): @@ -521,30 +521,30 @@ def test_sparse_network(): print(res2) -@run +# @run # -# CHECK-LABEL: test_sparse_feature_scaling -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { +# C_HECK-LABEL: test_sparse_feature_scaling +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { # ... more IR ... -# CHECK: %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}" -# CHECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]] -# CHECK return %[[R]] : !torch.vtensor<[4,4],f32> -# CHECK: } +# C_HECK: %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}" +# C_HECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]] +# C_HECK return %[[R]] : !torch.vtensor<[4,4],f32> +# C_HECK: } # -# CHECK: torch.sparse -# CHECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889], -# CHECK: [0.1321, 0.2724, 0.2105, 0.3851], -# CHECK: [0.2478, 0.3439, 0.1898, 0.2185], -# CHECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}}) +# C_HECK: torch.sparse +# C_HECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889], +# C_HECK: [0.1321, 0.2724, 0.2105, 0.3851], +# C_HECK: [0.2478, 0.3439, 0.1898, 0.2185], +# C_HECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}}) # # TODO: first row looks suspect... # -# CHECK: torch.mlir -# CHECK: {{\[}}[0. 0. 0. 0. ] -# CHECK: [0.13205223 0.27236593 0.21051763 0.38506418] -# CHECK: [0.24781987 0.34391665 0.18976606 0.2184974 ] -# CHECK: [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}} +# C_HECK: torch.mlir +# C_HECK: {{\[}}[0. 0. 0. 0. ] +# C_HECK: [0.13205223 0.27236593 0.21051763 0.38506418] +# C_HECK: [0.24781987 0.34391665 0.18976606 0.2184974 ] +# C_HECK: [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}} # def test_sparse_feature_scaling(): class Scale(nn.Module): diff --git a/test/python/fx_importer/symbolic_shape_expr_test.py b/test/python/fx_importer/symbolic_shape_expr_test.py index 4b6620498..3b8274cca 100644 --- a/test/python/fx_importer/symbolic_shape_expr_test.py +++ b/test/python/fx_importer/symbolic_shape_expr_test.py @@ -129,13 +129,16 @@ def test_symbolic_dim_differ_by_one(): # CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { # CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> -# CHECK: %[[VIEW1:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?,1],f32> -# CHECK: torch.bind_symbolic_shape %[[VIEW1]], [%[[S0]]], affine_map<()[s0] -> (s0, 1)> : !torch.vtensor<[?,1],f32> -# CHECK: %[[MUL:.+]] = torch.aten.mul.Tensor %[[VIEW1]], %[[ARG0]] : !torch.vtensor<[?,1],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> -# CHECK: torch.bind_symbolic_shape %[[MUL]], [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32> -# CHECK: %[[VIEW2:.+]] = torch.aten.view %[[MUL]], {{.*}} : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?],f32> -# CHECK: torch.bind_symbolic_shape %[[VIEW2]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32> -# CHECK: return %[[VIEW2]] : !torch.vtensor<[?],f32> +# CHECK: %[[I0:.+]] = torch.constant.int 0 +# CHECK: %[[SIZE:.+]] = torch.aten.size.int %[[ARG0]], %[[I0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int +# The Torch 2.6 generates `torch.aten.outer` as an op in this example while the torch versions < 2.6 does not, hence this check is kept as a "COM". +# COM: %[[OUTER:.+]] = torch.aten.outer %[[ARG0]], %[[ARG0]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> +# CHECK: torch.bind_symbolic_shape %{{.*}}, [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32> +# CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[SIZE]], %[[SIZE]] : !torch.int, !torch.int -> !torch.int +# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MUL]] : (!torch.int) -> !torch.list +# CHECK: %[[VIEW:.+]] = torch.aten.view %{{.*}}, %[[LIST]] : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32> +# CHECK: return %[[VIEW]] : !torch.vtensor<[?],f32> def test_outer_with_squared_shape(): class OuterWithSquaredShape(torch.nn.Module): def __init__(self): diff --git a/test/python/fx_importer/v2.3/mutation_import.py b/test/python/fx_importer/v2.3/mutation_import.py index c62b12706..ee829e455 100644 --- a/test/python/fx_importer/v2.3/mutation_import.py +++ b/test/python/fx_importer/v2.3/mutation_import.py @@ -65,7 +65,9 @@ def test_import_frozen_exported_program(): # CHECK: func.func @main(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.tensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> # CHECK-DAG: %[[arg1_copy:.+]] = torch.copy.to_vtensor %arg1 : !torch.vtensor<[3,4],f32> # CHECK-DAG: %[[arg1_mul:.+]] = torch.aten.mul.Tensor %[[arg1_copy]], %arg0 -# CHECK-DAG: torch.overwrite.tensor.contents %[[arg1_mul]] overwrites %arg1 +# The Torch 2.6 generates `torch.aten.copy` as an op in this example while the torch versions < 2.6 does not, hence this check is kept as a "COM". +# COM: %{{.*}} = torch.aten.copy %[[arg1_copy]], %[[arg1_mul]], %false : !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.bool -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: torch.overwrite.tensor.contents %{{.*}} overwrites %arg1 # CHECK-DAG: %[[arg0_mul:.+]] = torch.aten.mul.Tensor %arg0, %[[arg1_mul]] # CHECK: return %[[arg0_mul]] def test_user_input_mutate(): diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 0baf279cc..c2418760b 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ --pre -torchvision==0.20.0.dev20240916 +torchvision==0.20.0.dev20241015