diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 23e476d9b..eb2cb68b3 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -748,6 +748,7 @@ STABLEHLO_PASS_SET = {
     "NewEmptyModuleNonDefaultFloatDtype_basic",
     "NewEmptyModuleNonDefaultIntDtype_basic",
     "NewEmptyStridedModuleDefaultDtype_basic",
+    "EmptyStridedModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
@@ -1421,4 +1422,5 @@ LTC_XFAIL_SET = {
     "ScatterValueIntModule_basic",
     "UniformStaticShapeModule_basic",
     "AtenEmbeddingBagStaticModule_basic",
+    "EmptyStridedModule_basic",
 }
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 12a2c9d74..f0d0a238a 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8335,6 +8335,34 @@ def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
   }];
 }
 
+def Torch_AtenEmptyStridedOp : Torch_Op<"aten.empty_strided", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchListOfTorchIntType:$size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchOptionalIntType:$dtype,
+    AnyTorchOptionalIntType:$layout,
+    AnyTorchOptionalDeviceType:$device,
+    AnyTorchOptionalBoolType:$pin_memory
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenEmptyStridedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenEmptyStridedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenExpandOp : Torch_Op<"aten.expand", [
     AllowsTypeRefinement,
     ReadOnly
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index ba1471694..697ad6bbd 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7220,6 +7220,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.empty.memory_format\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.full\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -10533,6 +10536,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
+"    %int6 = torch.constant.int 6\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int6 : !torch.int\n"
+"    } else {\n"
+"      %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
+"      torch.prim.If.yield %2 : !torch.int\n"
+"    }\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
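Note: `%int6` in the generated dtype function above is the c10::ScalarType code for `torch.float32`, PyTorch's default dtype when the caller passes `dtype=None`. These C++ string literals are generated from the Python functions added in abstract_interp_lib_gen.py further down in this patch; a rough standalone rendering of the lowered logic, for orientation only and not part of the patch:

    import torch

    def empty_strided_dtype(dtype=None):
        # %int6 in the IR above is torch.float32 (c10::ScalarType::Float == 6).
        return torch.float32 if dtype is None else dtype
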
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index c3dde6de4..63ce4f837 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4416,6 +4416,49 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenEmptyStridedOp
+    : public OpRewritePattern<AtenEmptyStridedOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t> sizeListInts, strideListInts;
+    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
+      return rewriter.notifyMatchFailure(
+          op, "all size list elements must be constant ints");
+    if (!matchPattern(op.getStride(),
+                      m_TorchListOfConstantInts(strideListInts)))
+      return rewriter.notifyMatchFailure(
+          op, "all stride list elements must be constant ints");
+
+    // We only support cases with default stride values.
+    // For example: aten.empty_strided(size=[2, 3, 4], stride=[12, 4, 1]),
+    // where stride[0] == size[1] * size[2], stride[1] == size[2], and
+    // stride[2] == 1.
+    bool isDefaultStride = true;
+    for (unsigned i = 0; i < strideListInts.size(); i++) {
+      int64_t defaultStride = 1;
+      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
+        defaultStride *= sizeListInts[j];
+      if (defaultStride != strideListInts[i]) {
+        isDefaultStride = false;
+        break;
+      }
+    }
+    if (!isDefaultStride)
+      return rewriter.notifyMatchFailure(
+          op, "only default strides supported for empty_strided op");
+
+    Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
+    rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
+        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposePrimsSqueezeOp : public OpRewritePattern<PrimsSqueezeOp> {
 public:
@@ -5251,6 +5294,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index baefafbc4..aadd27cb9 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -480,6 +480,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenEmptyStridedOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
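For reference, the "default" strides that DecomposeAtenEmptyStridedOp accepts are exactly the contiguous strides implied by the sizes: stride[i] is the product of all sizes after index i, with the innermost stride equal to 1. A standalone Python sketch of the same check (helper names here are illustrative, not part of the patch):

    def default_strides(sizes):
        # stride[i] == product of sizes[i+1:]; the innermost stride is 1.
        strides, running = [], 1
        for size in reversed(sizes):
            strides.append(running)
            running *= size
        return list(reversed(strides))

    assert default_strides([2, 3, 4]) == [12, 4, 1]

Any other stride list (for example, a column-major layout such as [1, 2, 6]) makes the pattern report a match failure instead of decomposing to aten.empty.memory_format.
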
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 11d1e7b26..d6f064f74 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -643,7 +643,8 @@ def aten〇ones〡shape(size: List[int], dtype: Optional[int] = None, layout: Op
 
 def aten〇empty〇memory_format〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return size
-
+def aten〇empty_strided〡shape(size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
+    return size
 
 def aten〇full〡shape(size: List[int], fill_value: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return size
@@ -3237,6 +3238,13 @@ def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[
     self_rank, self_dtype = self_rank_dtype
     return self_dtype if dtype is None else dtype
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1]) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.float16) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.int32) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=0, size=[1], stride=[1], dtype=torch.complex64))
+def aten〇empty_strided〡dtype(size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
+    return torch.float32 if dtype is None else dtype
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0) +
     _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0) +
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 65c113498..9945db7e0 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -542,6 +542,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
     emit("aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
     emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
+    emit("aten::empty_strided : (int[], int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
     emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
index 6d2315c3f..27cf2eb4a 100644
--- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
+++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -1628,3 +1628,27 @@ class NewEmptyStridedModuleDefaultDtype(torch.nn.Module):
 @register_test_case(module_factory=lambda: NewEmptyStridedModuleDefaultDtype())
 def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
+
+
+# ==============================================================================
+
+
+class EmptyStridedModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, 4], torch.float32, True),
+    ])
+    def forward(self, a):
+        x = torch.ops.aten.empty_strided(a.size(), stride=[12, 4, 1])
+        y = x.copy_(a)
+        return y
+
+
+@register_test_case(module_factory=lambda: EmptyStridedModule())
+def EmptyStridedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
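
Since the contents of a tensor produced by empty_strided are indeterminate, the test immediately overwrites them with copy_ so the module's output is deterministic and comparable against eager PyTorch. The same pattern can be sanity-checked in eager mode (a quick illustration, not part of the test suite):

    import torch

    a = torch.rand(2, 3, 4)
    # [12, 4, 1] are the default (contiguous) strides for a 2x3x4 tensor,
    # which is the only case the decomposition above supports.
    x = torch.ops.aten.empty_strided(a.size(), stride=[12, 4, 1])
    y = x.copy_(a)
    assert torch.equal(y, a)
    assert x.stride() == (12, 4, 1)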