mirror of https://github.com/llvm/torch-mlir
[StableHLO] Add support for AtenPowTensorScalar. (#1883)
* [MHLO] Add support for AtenPowTensorScalar. * Update. --------- Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>pull/1890/head snapshot-20230217.752
parent
e85def790c
commit
38ed559398
|
@ -136,6 +136,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseClampModule_basic",
|
"ElementwiseClampModule_basic",
|
||||||
"ElementwiseClampMinModule_basic",
|
"ElementwiseClampMinModule_basic",
|
||||||
"ElementwiseClampMaxModule_basic",
|
"ElementwiseClampMaxModule_basic",
|
||||||
|
"ElementwisePowModule_basic",
|
||||||
"ElementwiseExpModule_basic",
|
"ElementwiseExpModule_basic",
|
||||||
"ElementwiseFlattenBroadcastModule_basic",
|
"ElementwiseFlattenBroadcastModule_basic",
|
||||||
"ElementwiseLeakyReluModule_basic",
|
"ElementwiseLeakyReluModule_basic",
|
||||||
|
|
|
@ -771,6 +771,44 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AtenPowTensorScalarOp
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||||
|
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value lhs = adaptor.getSelf();
|
||||||
|
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
||||||
|
Value rhs = adaptor.getExponent();
|
||||||
|
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||||
|
|
||||||
|
if (!lhsType)
|
||||||
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
|
||||||
|
auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
|
||||||
|
->convertType(op.getType())
|
||||||
|
.template cast<TensorType>();
|
||||||
|
|
||||||
|
Type outElemTy = outType.getElementType();
|
||||||
|
if (!outElemTy.isIntOrFloat()) {
|
||||||
|
return op.emitError(
|
||||||
|
"only floating-point or integer datatype legalization supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!rhsType) {
|
||||||
|
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs,
|
||||||
|
outElemTy);
|
||||||
|
}
|
||||||
|
DenseIntElementsAttr bcastDimensions;
|
||||||
|
lhs = hlo::promoteType(rewriter, lhs, outType);
|
||||||
|
rhs = hlo::promoteType(rewriter, rhs, outType);
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
Value result =
|
||||||
|
rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs, bcastDimensions);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// PrimNumToTensorScalarOp
|
// PrimNumToTensorScalarOp
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
||||||
|
@ -1464,6 +1502,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
|
|
||||||
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
|
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
|
||||||
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue