mirror of https://github.com/llvm/torch-mlir
parent
7c6961bcbf
commit
bba0f5891b
|
@ -474,6 +474,8 @@ STABLEHLO_PASS_SET = {
|
||||||
"ExpandModule_basic",
|
"ExpandModule_basic",
|
||||||
"Fill_TensorFloat64WithFloat32Static_basic",
|
"Fill_TensorFloat64WithFloat32Static_basic",
|
||||||
"Fill_TensorFloat64WithInt64Static_basic",
|
"Fill_TensorFloat64WithInt64Static_basic",
|
||||||
|
"FlipModuleStaticShape_basic",
|
||||||
|
"FlipNegativeIndexModule_basic",
|
||||||
"FullLikeModuleDefaultDtype_basic",
|
"FullLikeModuleDefaultDtype_basic",
|
||||||
"FullLikeModuleFalsePinMemory_basic",
|
"FullLikeModuleFalsePinMemory_basic",
|
||||||
"FullLikeModuleFloat2D_basic",
|
"FullLikeModuleFloat2D_basic",
|
||||||
|
|
|
@ -113,6 +113,13 @@ public:
|
||||||
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
|
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis)))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only constant dim lists supported");
|
"only constant dim lists supported");
|
||||||
|
for (unsigned i = 0, e = axis.size(); i < e; i++) {
|
||||||
|
axis[i] = toPositiveDim(axis[i], selfRank);
|
||||||
|
if (!isValidDim(axis[i], selfRank)) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "axis is statically invalid");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Only used to calculate flipped values, i.e. those on the flip axes. Other
|
// Only used to calculate flipped values, i.e. those on the flip axes. Other
|
||||||
// dims won't be used.
|
// dims won't be used.
|
||||||
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
|
SmallVector<Value> dims = getTensorSizes(rewriter, loc, self);
|
||||||
|
|
|
@ -1560,6 +1560,7 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// AtenFillScalarOp
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
||||||
AtenFillScalarOp op, OpAdaptor adaptor,
|
AtenFillScalarOp op, OpAdaptor adaptor,
|
||||||
|
@ -1575,6 +1576,31 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AtenFlipOp
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
||||||
|
AtenFlipOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
Value self = adaptor.getSelf();
|
||||||
|
auto outType =
|
||||||
|
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||||
|
|
||||||
|
SmallVector<int64_t> dims;
|
||||||
|
if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "dims must be a list of const int");
|
||||||
|
}
|
||||||
|
for (unsigned i = 0, e = dims.size(); i < e; i++) {
|
||||||
|
dims[i] = toPositiveDim(dims[i], outType.getRank());
|
||||||
|
if (!isValidDim(dims[i], outType.getRank())) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<stablehlo::ReverseOp>(
|
||||||
|
op, outType, self, rewriter.getI64TensorAttr(dims));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||||
|
@ -1700,6 +1726,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
INSERT_ATENOP_PATTERN(AtenUniformOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
|
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||||
|
|
|
@ -2892,6 +2892,48 @@ class FlipModule(torch.nn.Module):
|
||||||
def FlipModule_basic(module, tu: TestUtils):
|
def FlipModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 2, 4))
|
module.forward(tu.rand(3, 2, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class FlipModuleStaticShape(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 2, 4], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.flip(x, [1, 2])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: FlipModuleStaticShape())
|
||||||
|
def FlipModuleStaticShape_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 2, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class FlipNegativeIndexModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 2, 4], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.flip(x, [-1])
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: FlipNegativeIndexModule())
|
||||||
|
def FlipNegativeIndexModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 2, 4))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue