mirror of https://github.com/llvm/torch-mlir
[stablehlo] add aten.remainder.Tensor op conversion support (#3197)
parent
b01245c0e8
commit
ea0ecb67be
|
@ -1817,6 +1817,22 @@ LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// AtenRemainderTensorOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
|
||||
AtenRemainderTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value lhs = adaptor.getSelf();
|
||||
Value rhs = adaptor.getOther();
|
||||
|
||||
auto resultType =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
|
||||
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::RemOp>(op, lhs, rhs);
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
|
@ -1959,6 +1975,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||
|
|
|
@ -664,9 +664,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"ElementwiseReciprocalIntModule_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_basic",
|
||||
"ElementwiseRsqrtIntModule_basic",
|
||||
"ElementwiseSigmoidIntModule_basic",
|
||||
"ElementwiseSinIntModule_basic",
|
||||
|
@ -1074,6 +1071,9 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwisePreluStaticModule_basic",
|
||||
"ElementwiseReciprocalModule_basic",
|
||||
"ElementwiseReluModule_basic",
|
||||
"ElementwiseRemainderTensorModule_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_basic",
|
||||
"ElementwiseRsqrtModule_basic",
|
||||
"ElementwiseSigmoidModule_basic",
|
||||
"ElementwiseSinModule_basic",
|
||||
|
|
Loading…
Reference in New Issue