diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index 541c02a07..71b675b5e 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -161,12 +161,70 @@ public: using ConvertAtenOp::ConvertAtenOp; using OpAdaptor = typename AtenOpT::Adaptor; + unsigned getBitWidth(Type type) const { + if (auto complexTy = dyn_cast(type)) + return 2 * getBitWidth(complexTy.getElementType()); + return type.getIntOrFloatBitWidth(); + } + LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto rankType = dyn_cast(adaptor.getSelf().getType()); if (!rankType) - return op.emitError("Only ranked tensor types are currently supported"); + return op.emitError("Only ranked tensor types are currently supported."); + auto loc = op.getLoc(); + + // support AtenViewDtypeOp + if (isa(op)) { + auto self = adaptor.getSelf(); + auto baseResultTy = dyn_cast(op.getType()); + + // infer the result shape + auto operandElt = rankType.getElementType(); + auto targetElt = baseResultTy.getDtype(); + auto operandEltBitWidth = getBitWidth(operandElt); + auto targetEltBitWidth = getBitWidth(targetElt); + auto operandSizes = rankType.getShape(); + SmallVector castShape(operandSizes); + if (operandEltBitWidth > targetEltBitWidth) { + int64_t last_size = operandEltBitWidth / targetEltBitWidth; + castShape.push_back(last_size); + } else if (operandEltBitWidth < targetEltBitWidth) { + int64_t last_size = targetEltBitWidth / operandEltBitWidth; + if (!ShapedType::isDynamic(castShape.back()) and + last_size != castShape.back()) { + return rewriter.notifyMatchFailure( + op, "The last dim size is not equal to targetEltBitWidth / " + "operandEltBitWidth."); + } else { + castShape.pop_back(); + } + } + + auto resultType = + OpConversionPattern::getTypeConverter()->convertType( + baseResultTy); + if (!dyn_cast(resultType).hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "Currently only support static output shape."); + } + + auto castType = + baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype()); + auto cast = rewriter.create( + loc, + OpConversionPattern::getTypeConverter()->convertType( + castType), + self); + + auto reshape = + rewriter.create(loc, resultType, cast); + + rewriter.replaceOp(op, reshape); + + return success(); + } // collect Value of dims SmallVector dimSizes; @@ -174,7 +232,6 @@ public: return op.emitError("Dims size must be a list of Scalar"); } - auto loc = op.getLoc(); if (dimSizes.size() == 0 || rankType.getRank() == 0) { rewriter.replaceOpWithNewOp( op, @@ -236,6 +293,13 @@ public: SmallVector &dimSizes) const; }; +template <> +bool ConvertAtenViewOp::getAtenViewOpSizes( + AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + SmallVector &dimSizes) const { + return false; +} + template <> bool ConvertAtenViewOp::getAtenViewOpSizes( AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, @@ -496,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( #define INSERT_VIEW_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) + INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp); INSERT_VIEW_OP_PATTERN(AtenViewOp); INSERT_VIEW_OP_PATTERN(AtenReshapeOp); #undef INSERT_VIEW_OP_PATTERN diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 052eceb5a..326a7afe8 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -506,6 +506,7 @@ FX_IMPORTER_XFAIL_SET = { "UpSampleNearest2dDynamicFactor_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", + "ViewDtypeStaticModule_basic", "WeightNormInterfaceModule_basic", # Error: `aten.as_strided` op is not supported "ChunkListUnpackDynamic_Module_basic", @@ -3169,6 +3170,7 @@ ONNX_XFAIL_SET = { "Unfold_Module_Rank_Zero_basic", "Unfold_Module_Rank_Zero_Size_Zero_basic", "Unfold_Module_Dynamic_basic", + "ViewDtypeStaticModule_basic", } if torch_version_for_comparison() < version.parse("2.3.0.dev"): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index ee9cbbf05..9e2d2693b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils): # ============================================================================== +class ViewDtypeStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([12, 1], torch.float32, True), + ] + ) + def forward(self, a): + res = a.view(torch.int8) + return res + + +@register_test_case(module_factory=lambda: ViewDtypeStaticModule()) +def ViewDtypeStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(12, 1)) + + +# ============================================================================== + + class ReshapeAliasCollapseModule(torch.nn.Module): def __init__(self): super().__init__()