From 38ed5593983fff4de09ab3040f3cfd2d0ffc3013 Mon Sep 17 00:00:00 2001
From: Ziheng Jiang <ziheng.jiang@bytedance.com>
Date: Thu, 16 Feb 2023 20:26:46 -0800
Subject: [PATCH] [StableHLO] Add support for AtenPowTensorScalar. (#1883)

* [MHLO] Add support for AtenPowTensorScalar.

* Update.

---------

Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
---
 e2e_testing/xfail_sets.py                 |  1 +
 lib/Conversion/TorchToStablehlo/Basic.cpp | 39 +++++++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index edcba2624..bd19de899 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -136,6 +136,7 @@ STABLEHLO_PASS_SET = {
     "ElementwiseClampModule_basic",
     "ElementwiseClampMinModule_basic",
     "ElementwiseClampMaxModule_basic",
+    "ElementwisePowModule_basic",
     "ElementwiseExpModule_basic",
     "ElementwiseFlattenBroadcastModule_basic",
     "ElementwiseLeakyReluModule_basic",
diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index 585c6006f..d84fbaf9b 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -771,6 +771,44 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
   return success();
 }
 
+// AtenPowTensorScalarOp
+template <>
+LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
+    AtenPowTensorScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value lhs = adaptor.getSelf();
+  auto lhsType = lhs.getType().dyn_cast<TensorType>();
+  Value rhs = adaptor.getExponent();
+  TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
+
+  if (!lhsType)
+    return op.emitError("only Tensor types supported in StableHLO");
+
+  auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
+                     ->convertType(op.getType())
+                     .template cast<TensorType>();
+
+  Type outElemTy = outType.getElementType();
+  if (!outElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+
+  if (!rhsType) {
+    rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
+  }
+  DenseIntElementsAttr bcastDimensions;
+  lhs = hlo::promoteType(rewriter, lhs, outType);
+  rhs = hlo::promoteType(rewriter, rhs, outType);
+  auto loc = op.getLoc();
+  Value result =
+      rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
+                                            bcastDimensions);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 // PrimNumToTensorScalarOp
 template <>
 LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
@@ -1464,6 +1502,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
 
   INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
+  INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
   INSERT_ATENOP_PATTERN(AtenContiguousOp);