[stablehlo] support aten.view.dtype lowering (#3778)

pull/3781/head
yyp0 2024-10-10 15:50:17 +08:00 committed by GitHub
parent 94f5410913
commit d0041dc310
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 93 additions and 2 deletions

View File

@ -161,12 +161,70 @@ public:
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor;
unsigned getBitWidth(Type type) const {
if (auto complexTy = dyn_cast<ComplexType>(type))
return 2 * getBitWidth(complexTy.getElementType());
return type.getIntOrFloatBitWidth();
}
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
if (!rankType)
return op.emitError("Only ranked tensor types are currently supported");
return op.emitError("Only ranked tensor types are currently supported.");
auto loc = op.getLoc();
// support AtenViewDtypeOp
if (isa<AtenViewDtypeOp>(op)) {
auto self = adaptor.getSelf();
auto baseResultTy = dyn_cast<BaseTensorType>(op.getType());
// infer the result shape
auto operandElt = rankType.getElementType();
auto targetElt = baseResultTy.getDtype();
auto operandEltBitWidth = getBitWidth(operandElt);
auto targetEltBitWidth = getBitWidth(targetElt);
auto operandSizes = rankType.getShape();
SmallVector<int64_t> castShape(operandSizes);
if (operandEltBitWidth > targetEltBitWidth) {
int64_t last_size = operandEltBitWidth / targetEltBitWidth;
castShape.push_back(last_size);
} else if (operandEltBitWidth < targetEltBitWidth) {
int64_t last_size = targetEltBitWidth / operandEltBitWidth;
if (!ShapedType::isDynamic(castShape.back()) and
last_size != castShape.back()) {
return rewriter.notifyMatchFailure(
op, "The last dim size is not equal to targetEltBitWidth / "
"operandEltBitWidth.");
} else {
castShape.pop_back();
}
}
auto resultType =
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
baseResultTy);
if (!dyn_cast<ShapedType>(resultType).hasStaticShape()) {
return rewriter.notifyMatchFailure(
op, "Currently only support static output shape.");
}
auto castType =
baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype());
auto cast = rewriter.create<stablehlo::BitcastConvertOp>(
loc,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
castType),
self);
auto reshape =
rewriter.create<stablehlo::ReshapeOp>(loc, resultType, cast);
rewriter.replaceOp(op, reshape);
return success();
}
// collect Value of dims
SmallVector<Value, 4> dimSizes;
@ -174,7 +232,6 @@ public:
return op.emitError("Dims size must be a list of Scalar");
}
auto loc = op.getLoc();
if (dimSizes.size() == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op,
@ -236,6 +293,13 @@ public:
SmallVector<Value, 4> &dimSizes) const;
};
template <>
bool ConvertAtenViewOp<AtenViewDtypeOp>::getAtenViewOpSizes(
AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
SmallVector<Value, 4> &dimSizes) const {
return false;
}
template <>
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp);
INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN

View File

@ -506,6 +506,7 @@ FX_IMPORTER_XFAIL_SET = {
"UpSampleNearest2dDynamicFactor_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"ViewDtypeStaticModule_basic",
"WeightNormInterfaceModule_basic",
# Error: `aten.as_strided` op is not supported
"ChunkListUnpackDynamic_Module_basic",
@ -3169,6 +3170,7 @@ ONNX_XFAIL_SET = {
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Dynamic_basic",
"ViewDtypeStaticModule_basic",
}
if torch_version_for_comparison() < version.parse("2.3.0.dev"):

View File

@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils):
# ==============================================================================
class ViewDtypeStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([12, 1], torch.float32, True),
]
)
def forward(self, a):
res = a.view(torch.int8)
return res
@register_test_case(module_factory=lambda: ViewDtypeStaticModule())
def ViewDtypeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(12, 1))
# ==============================================================================
class ReshapeAliasCollapseModule(torch.nn.Module):
def __init__(self):
super().__init__()