From 5c43daa3bf344a495979071ecd9498fae2eae03b Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Thu, 31 Aug 2023 12:20:22 +0000
Subject: [PATCH] [MLIR][TORCH] Add e2e support for aten.pow.Scalar op

Signed-Off By: Vivek Khandelwal
---
 .../TorchToLinalg/Uncategorized.cpp           | 53 ++++++++++++-------
 .../Transforms/AbstractInterpLibrary.cpp      | 13 +++++
 python/torch_mlir/__init__.py                 |  1 -
 .../build_tools/abstract_interp_lib_gen.py    |  9 ++++
 .../test_suite/elementwise.py                 | 22 ++++++++
 5 files changed, 77 insertions(+), 21 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 5291869ec..8684b68d9 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -660,6 +660,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     divTensorMode.emitError("invalid rounding mode");
     return nullptr;
   }
+
+  if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
+    Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
+    if (!dtype.isa<mlir::FloatType>()) {
+      pow.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Value selfPromoted = convertScalarToDtype(b, loc, operands[0], dtype);
+    Value expPromoted = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    return b.create<math::PowFOp>(loc, selfPromoted, expPromoted);
+  }
+
   if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
     if (!pow.getType()
              .cast<ValueTensorType>()
              .getDtype()
@@ -1162,20 +1174,20 @@ public:
           AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
           AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
           AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
-          AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
+          AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp,
           AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
           AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp, AtenReciprocalOp,
           AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
           AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
           AtenLeScalarOp, AtenWhereSelfOp,
-          AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
-          AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
-          AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
-          AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
-          AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-          AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
-          AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
-          AtenAtanOp, AtenRealOp, AtenImagOp>(op))
+          AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
+          AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
+          AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
+          AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
+          AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
+          AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp,
+          AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
+          AtenFillTensorOp, AtenAtanOp, AtenRealOp, AtenImagOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1697,17 +1709,18 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op,
       AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
       AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp,
-      AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp,
-      AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-      AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
-      AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
-      AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
-      AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
-      AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
-      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
-      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
-      AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
-      AtenFillTensorOp, AtenRealOp, AtenImagOp>();
+      AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
+      AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
+      AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
+      AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
+      AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
+      AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
+      AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
+      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
+      AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp,
+      AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp,
+      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
+      AtenRealOp, AtenImagOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
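The following eager-mode PyTorch snippet is not part of the patch; it is a rough sketch of what the new AtenPowScalarOp payload computes (the helper name pow_scalar_reference is invented for illustration). The scalar base and each exponent element are promoted to the result dtype, combined with an elementwise powf, and non-floating-point result dtypes are rejected, mirroring the emitError guard above.

import torch

def pow_scalar_reference(base: float, exponent: torch.Tensor) -> torch.Tensor:
    # Result dtype follows scalar/tensor promotion (see the dtype function later in the patch).
    result_dtype = torch.result_type(base, exponent)
    if not result_dtype.is_floating_point:
        # Mirrors the lowering's "unimplemented: non-floating point dtype" guard.
        raise NotImplementedError("non-floating point dtype")
    # Promote both operands, then take the elementwise pow (math.powf in the generated payload).
    return torch.full_like(exponent, base, dtype=result_dtype) ** exponent.to(result_dtype)

exp = torch.rand(3, 4)
assert torch.allclose(pow_scalar_reference(2.0, exp), torch.pow(2.0, exp))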
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 769118df5..7556dec9d 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6522,6 +6522,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.pow.Scalar\"(%arg0: !torch.float, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -9747,6 +9751,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %6 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
+"    %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
+"    %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
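As a quick eager-mode cross-check of the shape and dtype functions registered above (illustrative only, not part of the patch): the result takes the exponent tensor's shape, and its dtype follows scalar/tensor promotion, which PyTorch exposes as torch.result_type.

import torch

exp_i32 = torch.ones(3, 4, dtype=torch.int32)
# Shape follows the exponent tensor (the "unary" shape rule).
assert torch.pow(2.0, exp_i32).shape == exp_i32.shape
# A float scalar promotes an integer exponent tensor to the default float dtype ...
assert torch.pow(2.0, exp_i32).dtype == torch.result_type(2.0, exp_i32) == torch.float32
# ... while an integer scalar keeps the tensor's integer dtype.
assert torch.pow(2, exp_i32).dtype == torch.int32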
diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py
index 40c7ac87a..8de6cc1a1 100644
--- a/python/torch_mlir/__init__.py
+++ b/python/torch_mlir/__init__.py
@@ -446,7 +446,6 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
 
     if output_type == OutputType.RAW:
         return mb.module
-    # mb.module.dump()
     option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
                     " extra-library=" + extra_library_file_name + "}"
     run_pipeline_with_repro_report(
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index aeefa8f7e..7cf86ccfc 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -267,6 +267,9 @@ def aten〇remainder〇Scalar〡shape(self: List[int], other: float) -> List[int]:
 def aten〇floor_divide〇Scalar〡shape(self: List[int], other: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇pow〇Scalar〡shape(self: float, exponent: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(exponent)
+
 def aten〇pow〇Tensor_Scalar〡shape(self: List[int], exponent: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -2704,6 +2707,12 @@ def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
     dtypes = [self_dtype, get_dtype_of_scalar(other)]
     return promote_dtypes(ranks, dtypes)
 
+def aten〇pow〇Scalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int:
+    exponent_rank, exponent_dtype = exponent_rank_dtype
+    ranks: List[Optional[int]] = [None, exponent_rank]
+    dtypes = [get_dtype_of_scalar(self), exponent_dtype]
+    return promote_dtypes(ranks, dtypes)
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) +
                       _check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0))
 def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float, complex]) -> int:
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 40bb9975a..cacc6e04a 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1468,6 +1468,22 @@ def ElementwisePowTensorBroadcastStaticModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwisePowScalarModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.float32, True),
+    ])
+    def forward(self, exp):
+        return torch.pow(2.0, exp)
+
+
+@register_test_case(module_factory=lambda: ElementwisePowScalarModule())
+def ElementwisePowScalarModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+
+# ==============================================================================
+
+
 class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
 
     def __init__(self):
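Beyond the e2e suite, the new op can also be exercised directly through torch_mlir.compile once a build containing this patch is installed. The sketch below is illustrative only (the module and variable names are invented); it lowers a torch.pow(scalar, tensor) module to linalg-on-tensors via the TorchToLinalg pattern added above.

import torch
import torch_mlir

class PowScalar(torch.nn.Module):
    def forward(self, exp):
        # aten.pow.Scalar: scalar base, tensor exponent.
        return torch.pow(2.0, exp)

compiled = torch_mlir.compile(PowScalar(), torch.rand(3, 4),
                              output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
print(compiled)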