mirror of https://github.com/llvm/torch-mlir
[StableHLO] Add support for AtenPowTensorScalar. (#1883)

* [MHLO] Add support for AtenPowTensorScalar.
* Update.

Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>

parent e85def790c
commit 38ed559398
@@ -136,6 +136,7 @@ STABLEHLO_PASS_SET = {
     "ElementwiseClampModule_basic",
     "ElementwiseClampMinModule_basic",
     "ElementwiseClampMaxModule_basic",
+    "ElementwisePowModule_basic",
     "ElementwiseExpModule_basic",
     "ElementwiseFlattenBroadcastModule_basic",
     "ElementwiseLeakyReluModule_basic",
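The added entry opts the elementwise pow test into the set of e2e tests expected to pass through the StableHLO backend. For orientation, a minimal plain-PyTorch sketch (not the actual ElementwisePowModule source) of the scalar-exponent pow overload this kind of test exercises:

import torch

# Minimal sketch: a tensor base raised to a Python-scalar exponent dispatches
# to the aten.pow.Tensor_Scalar overload handled by the new pattern below.
x = torch.randn(3, 4)
y = torch.pow(x, 2.0)                         # aten.pow.Tensor_Scalar
z = torch.ops.aten.pow.Tensor_Scalar(x, 2.0)  # same overload, called directly
assert torch.allclose(y, z)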
@@ -771,6 +771,44 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
   return success();
 }
 
+// AtenPowTensorScalarOp
+template <>
+LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
+    AtenPowTensorScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value lhs = adaptor.getSelf();
+  auto lhsType = lhs.getType().dyn_cast<TensorType>();
+  Value rhs = adaptor.getExponent();
+  TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
+
+  if (!lhsType)
+    return op.emitError("only Tensor types supported in StableHLO");
+
+  auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
+                     ->convertType(op.getType())
+                     .template cast<TensorType>();
+
+  Type outElemTy = outType.getElementType();
+  if (!outElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "only floating-point or integer datatype legalization supported");
+  }
+
+  if (!rhsType) {
+    rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs,
+                                       outElemTy);
+  }
+  DenseIntElementsAttr bcastDimensions;
+  lhs = hlo::promoteType(rewriter, lhs, outType);
+  rhs = hlo::promoteType(rewriter, rhs, outType);
+  auto loc = op.getLoc();
+  Value result =
+      rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs, bcastDimensions);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 // PrimNumToTensorScalarOp
 template <>
 LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
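In the pattern above, a non-tensor exponent is first materialized as a 0-d StableHLO constant via hlo::scalarToStablehloTensor, both operands are then promoted to the element type of the converted result, and the power itself is emitted as chlo::BroadcastPowOp so CHLO's implicit-broadcast semantics handle any shape mismatch. The promotion step matters because PyTorch's type-promotion rules can make the result wider than the input; a short plain-PyTorch sketch of the behavior the lowering has to reproduce:

import torch

# PyTorch result dtypes for pow(tensor, scalar); the conversion promotes both
# operands to the converted result type (outType) to match these rules.
x_int = torch.arange(4, dtype=torch.int64)
print(torch.pow(x_int, 2).dtype)    # torch.int64   (integer exponent keeps integer)
print(torch.pow(x_int, 2.5).dtype)  # torch.float32 (float scalar widens to default float)

x_f32 = torch.ones(4, dtype=torch.float32)
print(torch.pow(x_f32, 2).dtype)    # torch.float32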
@@ -1464,6 +1502,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
 
   INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
+  INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
   INSERT_ATENOP_PATTERN(AtenContiguousOp);
 
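With the pattern registered via INSERT_ATENOP_PATTERN, aten.pow.Tensor_Scalar is rewritten by the ConvertAtenOp specialization added above during the Torch-to-StableHLO conversion. A hypothetical end-to-end sketch using the torch_mlir Python API; the class name is illustrative and the output_type spelling ("stablehlo" vs. the older "mhlo") depends on the torch-mlir version, so treat the exact call as an assumption:

import torch
import torch_mlir

# Hypothetical usage sketch, not part of this commit: compile a scalar-exponent
# pow through the StableHLO backend so the new lowering is exercised.
class PowTensorScalarSketch(torch.nn.Module):
    def forward(self, x):
        return torch.pow(x, 2.0)

compiled = torch_mlir.compile(PowTensorScalarSketch(), torch.ones(3, 4),
                              output_type="stablehlo")  # assumption: name is version-dependent
print(compiled)  # lowered module exercising the pattern registered above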