[StableHLO] Add support for AtenPowTensorScalar. (#1883)

* [MHLO] Add support for AtenPowTensorScalar.

* Update.

---------

Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
pull/1890/head snapshot-20230217.752
Ziheng Jiang 2023-02-16 20:26:46 -08:00 committed by GitHub
parent e85def790c
commit 38ed559398
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 0 deletions

View File

@ -136,6 +136,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseClampModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwisePowModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseLeakyReluModule_basic",

View File

@ -771,6 +771,44 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
return success();
}
// AtenPowTensorScalarOp
template <>
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
AtenPowTensorScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf();
auto lhsType = lhs.getType().dyn_cast<TensorType>();
Value rhs = adaptor.getExponent();
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO");
auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"only floating-point or integer datatype legalization supported");
}
if (!rhsType) {
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs,
outElemTy);
}
DenseIntElementsAttr bcastDimensions;
lhs = hlo::promoteType(rewriter, lhs, outType);
rhs = hlo::promoteType(rewriter, rhs, outType);
auto loc = op.getLoc();
Value result =
rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs, bcastDimensions);
rewriter.replaceOp(op, result);
return success();
}
// PrimNumToTensorScalarOp
template <>
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
@ -1464,6 +1502,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);