mirror of https://github.com/llvm/torch-mlir
parent
7c6961bcbf
commit
bba0f5891b
|
@ -474,6 +474,8 @@ STABLEHLO_PASS_SET = {
|
|||
"ExpandModule_basic",
|
||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||
"Fill_TensorFloat64WithInt64Static_basic",
|
||||
"FlipModuleStaticShape_basic",
|
||||
"FlipNegativeIndexModule_basic",
|
||||
"FullLikeModuleDefaultDtype_basic",
|
||||
"FullLikeModuleFalsePinMemory_basic",
|
||||
"FullLikeModuleFloat2D_basic",
|
||||
|
|
|
@ -113,6 +113,13 @@ public:
|
|||
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only constant dim lists supported");
|
||||
for (unsigned i = 0, e = axis.size(); i < e; i++) {
|
||||
axis[i] = toPositiveDim(axis[i], selfRank);
|
||||
if (!isValidDim(axis[i], selfRank)) {
|
||||
return rewriter.notifyMatchFailure(op, "axis is statically invalid");
|
||||
}
|
||||
}
|
||||
|
||||
// Only used to calculate flipped values, i.e. those on the flip axes. Other
|
||||
// dims won't be used.
|
||||
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
|
||||
|
|
|
@ -1560,6 +1560,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// AtenFillScalarOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
||||
AtenFillScalarOp op, OpAdaptor adaptor,
|
||||
|
@ -1575,6 +1576,31 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// AtenFlipOp
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
||||
AtenFlipOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
|
||||
SmallVector<int64_t> dims;
|
||||
if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) {
|
||||
return rewriter.notifyMatchFailure(op, "dims must be a list of const int");
|
||||
}
|
||||
for (unsigned i = 0, e = dims.size(); i < e; i++) {
|
||||
dims[i] = toPositiveDim(dims[i], outType.getRank());
|
||||
if (!isValidDim(dims[i], outType.getRank())) {
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ReverseOp>(
|
||||
op, outType, self, rewriter.getI64TensorAttr(dims));
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
|
@ -1700,6 +1726,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
||||
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||
|
|
|
@ -2892,6 +2892,48 @@ class FlipModule(torch.nn.Module):
|
|||
def FlipModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class FlipModuleStaticShape(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 2, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.flip(x, [1, 2])
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: FlipModuleStaticShape())
|
||||
def FlipModuleStaticShape_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class FlipNegativeIndexModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 2, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.flip(x, [-1])
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: FlipNegativeIndexModule())
|
||||
def FlipNegativeIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue