[stablehlo] add aten.remainder.Tensor op conversion support (#3197)

pull/3198/head
penguin_wwy 2024-04-21 00:03:37 +08:00 committed by GitHub
parent b01245c0e8
commit ea0ecb67be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 3 deletions

View File

@ -1817,6 +1817,22 @@ LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
return success();
}
// AtenRemainderTensorOp
template <>
LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
AtenRemainderTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther();
auto resultType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
rewriter.replaceOpWithNewOp<stablehlo::RemOp>(op, lhs, rhs);
return success();
}
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
@ -1959,6 +1975,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \

View File

@ -664,9 +664,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
"ElementwiseRemainderTensorModule_Float_basic",
"ElementwiseRemainderTensorModule_Int_Float_basic",
"ElementwiseRemainderTensorModule_Int_basic",
"ElementwiseRsqrtIntModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic",
@ -1074,6 +1071,9 @@ STABLEHLO_PASS_SET = {
"ElementwisePreluStaticModule_basic",
"ElementwiseReciprocalModule_basic",
"ElementwiseReluModule_basic",
"ElementwiseRemainderTensorModule_Float_basic",
"ElementwiseRemainderTensorModule_Int_Float_basic",
"ElementwiseRemainderTensorModule_Int_basic",
"ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSinModule_basic",