From 68f568b7041dd216a0d34901fcee7375e57840c6 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 21 Nov 2022 14:08:47 +0530 Subject: [PATCH] [MLIR][TORCH] Add E2E support for prims.convert_element_type op Signed-Off By: Vivek Khandelwal --- e2e_testing/xfail_sets.py | 3 ++- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 +++++++++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 21 ++++++++++++++++ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 6 +++++ lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 4 ++++ .../jit_ir/build_tools/shape_lib_gen.py | 3 +++ .../jit_ir/build_tools/torch_ods_gen.py | 6 +++++ .../test_suite/type_conversion.py | 19 +++++++++++++++ 8 files changed, 85 insertions(+), 1 deletion(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 5ab4ab965..6281ff9dc 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -647,5 +647,6 @@ LTC_XFAIL_SET = { "ConvolutionBackwardModule2D_basic", "ConvolutionBackwardModule2DPadded_basic", "VarMeanCorrectionModule_basic", - "VarMeanCorrectionNoneModule_basic" + "VarMeanCorrectionNoneModule_basic", + "PrimsConvertElementTypeModule_basic", } diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 5e3fa4e62..4a1ada21f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -10389,6 +10389,30 @@ def Torch_PrimAbsScalarOp : Torch_Op<"prim.abs.Scalar", [ }]; } +def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `prims::convert_element_type : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$a, + Torch_IntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult PrimsConvertElementTypeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void PrimsConvertElementTypeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [ HasValueSemantics, AllowsTypeRefinement, diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 2623aa4fa..af7b91d17 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3150,6 +3150,25 @@ public: }; } // namespace +namespace { +// Decompose `prims.convert_element_type` op into `aten.to.dtype` op. +class DecomposePrimsConvertElementTypeOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimsConvertElementTypeOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value cstFalse = rewriter.create(loc, false); + Value cstNone = rewriter.create(loc); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.a(), op.dtype(), /*non_blocking=*/cstFalse, + /*copy=*/cstFalse, /*memory_format=*/cstNone); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -3355,6 +3374,8 @@ public: target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); for (std::string opName : legalOps) { target.addLegalOp(OperationName(opName, context)); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 83c7b6378..1520c2baf 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -1074,6 +1074,12 @@ void TypeAnalysis::visitOperation(Operation *op, return; } + if (auto primsConvertElementType = dyn_cast(op)) { + visitAtenToDtypeLikeOp(primsConvertElementType, + operands); + return; + } + if (auto toDtypeLayout = dyn_cast(op)) { visitAtenToDtypeLikeOp(toDtypeLayout, operands); return; diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index b73386132..7d5cd420f 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -5622,6 +5622,10 @@ StringRef mlir::torch::Torch::getShapeLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.prims.convert_element_type\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.to.dtype_layout\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index c64aa6459..19d3181d3 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -433,6 +433,9 @@ def aten〇rsub〇Scalar(self: List[int], other: float, alpha: float = 1) -> Lis def aten〇to〇dtype(self: List[int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) +def prims〇convert_element_type(a: List[int], dtype: int) -> List[int]: + return upstream_shape_functions.unary(a) + def aten〇to〇dtype_layout(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]: return self diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 4b53e330d..627c48794 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -653,6 +653,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("prim::tolist : (...) -> (...)") emit("prim::abs.Scalar : (Scalar) -> (Scalar)") + # ========================================================================== + # `prims::` namespace. + # ========================================================================== + + emit("prims::convert_element_type : (Tensor, int) -> (Tensor)") + # ========================================================================== # `quantized::` namespace. # ========================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 53f2d2e0a..6e15da5a4 100644 --- a/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -234,3 +234,22 @@ class TypeAsSameModule(torch.nn.Module): @register_test_case(module_factory=lambda: TypeAsSameModule()) def TypeAsSameModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5), tu.rand(3, 5)) + + +# ============================================================================== + + +class PrimsConvertElementTypeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.prims.convert_element_type(x, dtype=torch.int64) + + +@register_test_case(module_factory=lambda: PrimsConvertElementTypeModule()) +def PrimsConvertElementTypeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5))