diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 432a3ad6e..8f949d9ba 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15909,6 +15909,34 @@ def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
   let hasFolder = 1;
 }
 
+def Torch_PrimsIotaOp : Torch_Op<"prims.iota", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `prims::iota : (int, int, int, int, Device, bool) -> (Tensor)`";
+  let arguments = (ins
+    Torch_IntType:$length,
+    Torch_IntType:$start,
+    Torch_IntType:$step,
+    Torch_IntType:$dtype,
+    Torch_DeviceType:$device,
+    Torch_BoolType:$requires_grad
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult PrimsIotaOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void PrimsIotaOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
     HasValueSemantics,
     AllowsTypeRefinement,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 5ad9b8216..a1cc7ddf6 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8653,6 +8653,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.int {\n"
+"    return %arg3 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list<int> {\n"
 "    %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 97bd85063..87f93ba9c 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4789,6 +4789,35 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
 };
 } // namespace
 
+namespace {
+// The `prims.iota` op is converted to `aten.arange.start_step` op.
+class DecomposePrimsIotaOp : public OpRewritePattern<PrimsIotaOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimsIotaOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    int64_t length, start, step;
+    if (!matchPattern(op.getLength(), m_TorchConstantInt(&length)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: length must be a constant integer");
+    if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: start must be a constant integer");
+    if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: step must be a constant integer");
+    auto endVal = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(start + length * step));
+    auto none = rewriter.create<ConstantNoneOp>(loc);
+    rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
+        op, op.getType(), op.getStart(), endVal, op.getStep(), op.getDtype(),
+        none, op.getDevice(), none);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose constant tensor full like ops.
 template <typename OpTy, int fillVal>
@@ -7605,6 +7634,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 1fd88ce6e..0607a9720 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1228,6 +1228,7 @@ STABLEHLO_PASS_SET = {
     "PrimMinIntDynamicModule_basic",
     "PrimMinIntModule_basic",
     "PrimsConvertElementTypeModule_basic",
+    "PrimsIotaModule_basic",
     "PrimsSqueezeEmptyDimensionsModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
@@ -1789,6 +1790,7 @@ TOSA_PASS_SET = {
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
     "PrimListUnpackNumMismatchModule_basic",
+    "PrimsIotaModule_basic",
     "PrimsSqueezeEmptyDimensionsModule_basic",
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
@@ -2683,6 +2685,9 @@ ONNX_XFAIL_SET = {
     "SqueezeModule_allUnitDim",
     "SqueezeModule_broadcast",
     "SqueezeModule_static",
+
+    # RuntimeError: unsupported input type: Device
+    "PrimsIotaModule_basic",
 
     # Failure - unknown
     "BernoulliModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 5d63d7a7d..8b32ff602 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1319,6 +1319,12 @@ def prims〇view_of〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
     _, a_dtype = a_rank_dtype
     return a_dtype
 
+def prims〇iota〡shape(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> List[int]:
+    return [length]
+
+def prims〇iota〡dtype(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> int:
+    return dtype
+
 def prim〇NumToTensor〇Scalar〡shape(a: float) -> List[int]:
     return []
 
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index e6258ece8..5c4e4d214 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -897,6 +897,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
     emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
     emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
+    emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")
 
     # ==========================================================================
     # `quantized::` namespace.
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
index fff3e60c4..948901307 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
@@ -380,3 +380,20 @@ class LinspaceTwoSizeModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
 def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
     module.forward()
+
+
+class PrimsIotaModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu',
+                                    requires_grad=False)
+
+@register_test_case(module_factory=lambda: PrimsIotaModule())
+def PrimsIotaModule_basic(module, tu: TestUtils):
+    module.forward()