[Stablehlo] add conversion for AtenFlipOp (#2163)

pull/2235/head snapshot-20230615.870
Yuanqiang Liu 2023-06-15 10:27:34 +08:00 committed by GitHub
parent 7c6961bcbf
commit bba0f5891b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 78 additions and 0 deletions

View File

@ -474,6 +474,8 @@ STABLEHLO_PASS_SET = {
"ExpandModule_basic", "ExpandModule_basic",
"Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithFloat32Static_basic",
"Fill_TensorFloat64WithInt64Static_basic", "Fill_TensorFloat64WithInt64Static_basic",
"FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic",
"FullLikeModuleDefaultDtype_basic", "FullLikeModuleDefaultDtype_basic",
"FullLikeModuleFalsePinMemory_basic", "FullLikeModuleFalsePinMemory_basic",
"FullLikeModuleFloat2D_basic", "FullLikeModuleFloat2D_basic",

View File

@ -113,6 +113,13 @@ public:
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis))) if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only constant dim lists supported"); "only constant dim lists supported");
for (unsigned i = 0, e = axis.size(); i < e; i++) {
axis[i] = toPositiveDim(axis[i], selfRank);
if (!isValidDim(axis[i], selfRank)) {
return rewriter.notifyMatchFailure(op, "axis is statically invalid");
}
}
// Only used to calculate flipped values, i.e. those on the flip axes. Other // Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used. // dims won't be used.
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self); SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);

View File

@ -1560,6 +1560,7 @@ public:
}; };
} // namespace } // namespace
// AtenFillScalarOp
template <> template <>
LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
AtenFillScalarOp op, OpAdaptor adaptor, AtenFillScalarOp op, OpAdaptor adaptor,
@ -1575,6 +1576,31 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
return success(); return success();
} }
// AtenFlipOp
template <>
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
AtenFlipOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
SmallVector<int64_t> dims;
if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) {
return rewriter.notifyMatchFailure(op, "dims must be a list of const int");
}
for (unsigned i = 0, e = dims.size(); i < e; i++) {
dims[i] = toPositiveDim(dims[i], outType.getRank());
if (!isValidDim(dims[i], outType.getRank())) {
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
}
}
rewriter.replaceOpWithNewOp<stablehlo::ReverseOp>(
op, outType, self, rewriter.getI64TensorAttr(dims));
return success();
}
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) { ConversionTarget &target, const TorchToStablehloOptions &options) {
@ -1700,6 +1726,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenUniformOp); INSERT_ATENOP_PATTERN(AtenUniformOp);
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
INSERT_ATENOP_PATTERN(AtenFillScalarOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \

View File

@ -2892,6 +2892,48 @@ class FlipModule(torch.nn.Module):
def FlipModule_basic(module, tu: TestUtils): def FlipModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4)) module.forward(tu.rand(3, 2, 4))
# ==============================================================================
class FlipModuleStaticShape(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 2, 4], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.flip(x, [1, 2])
@register_test_case(module_factory=lambda: FlipModuleStaticShape())
def FlipModuleStaticShape_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4))
# ==============================================================================
class FlipNegativeIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 2, 4], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.flip(x, [-1])
@register_test_case(module_factory=lambda: FlipNegativeIndexModule())
def FlipNegativeIndexModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2, 4))
# ============================================================================== # ==============================================================================