From e0a5adb1db38b0072c44b87570bc530eb3b324ad Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Mon, 27 May 2024 15:49:50 +0800
Subject: [PATCH] [Torch] fix aten.linear's decomposition (#3391)

* support aten.linear with more ranks.
---
 .../Torch/Transforms/DecomposeComplexOps.cpp |  69 ++++++----
 projects/pt1/e2e_testing/xfail_sets.py       |  11 ++
 .../torch_mlir_e2e_test/test_suite/matmul.py | 125 ++++++++++++++++++
 3 files changed, 180 insertions(+), 25 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 6ca4fb205..d3c9b8f2f 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5513,38 +5513,57 @@ public:
     Value bias = op.getBias();
 
     BaseTensorType inputType = cast<BaseTensorType>(input.getType());
-    if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
-      return rewriter.notifyMatchFailure(
-          op, "expected input to be rank 2 or greater");
+    if (!inputType.hasSizes())
+      return rewriter.notifyMatchFailure(op, "expected input to have sizes");
     BaseTensorType weightType = cast<BaseTensorType>(weight.getType());
-    // `weight` must be a rank 2 matrix.
-    if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
-      return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2");
+    if (!weightType.hasSizes())
+      return rewriter.notifyMatchFailure(op, "expected weight to have sizes");
 
-    SmallVector<int64_t> transposeShape =
-        llvm::to_vector(llvm::reverse(weightType.getSizes()));
-    Type transposeType = weightType.getWithSizesAndDtype(
-        llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
-    Value transposeWeight =
-        rewriter.create<AtenTOp>(loc, transposeType, weight);
+    auto transposeWeight = [&]() -> Value {
+      SmallVector<int64_t> transposeShape =
+          llvm::to_vector(llvm::reverse(weightType.getSizes()));
+      Type transposeType = weightType.getWithSizesAndDtype(
+          llvm::ArrayRef(transposeShape), weightType.getOptionalDtype());
+      Value transposeWeight =
+          rewriter.create<AtenTOp>(loc, transposeType, weight);
+      return transposeWeight;
+    };
 
-    Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
-                                                 transposeWeight);
     if (bias.getType().isa<Torch::NoneType>()) {
-      rewriter.replaceOp(op, matmul);
+      auto weightRank = weightType.getSizes().size();
+      if (weightRank > 2 || weightRank <= 0)
+        return rewriter.notifyMatchFailure(
+            op, "expected weight's rank <= 2 && >= 1");
+      if (weightRank == 1) {
+        rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), input,
+                                                  weight);
+        return success();
+      } else if (weightRank == 2) {
+        rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, op.getType(), input,
+                                                  transposeWeight());
+        return success();
+      }
+      llvm_unreachable("unsupported weightRank");
+    } else {
+      BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
+      if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
+        return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
+
+      // `weight` must be a rank 2 matrix.
+      auto weightRank = weightType.getSizes().size();
+      if (weightRank != 2)
+        return rewriter.notifyMatchFailure(op,
+                                           "expected weight to be a rank 2");
+
+      Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
+                                                   transposeWeight());
+      Value alpha =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+      rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
+                                                   op.getBias(), alpha);
       return success();
     }
-
-    BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
-    if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
-      return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
-
-    Value alpha =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
-    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
-                                                 op.getBias(), alpha);
-    return success();
   }
 };
 } // namespace
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 30dd72312..578af98d1 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -814,6 +814,12 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
 }
 
 STABLEHLO_PASS_SET = {
+    "AtenLinear1D_basic",
+    "AtenLinear2D_basic",
+    "AtenLinear3DBias_basic",
+    "AtenLinearMatVec_basic",
+    "AtenLinearVecMatBias_basic",
+    "AtenLinearVecMat_basic",
     "ReduceAminSingleDim_basic",
     "AtenDotModule_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
@@ -1447,6 +1453,8 @@ STABLEHLO_CRASHING_SET = set()
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "AtenLinear2D_basic",
+    "AtenLinear3DBias_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
     "ElementwiseDivTensorFloatModule_basic",
     "ElementwiseMulTensorFloatModule_basic",
@@ -1911,6 +1919,9 @@ MAKE_FX_TOSA_PASS_SET = (
     TOSA_PASS_SET
     | {
         ### Tests additionally passing in make_fx_tosa
+        "AtenLinear1D_basic",
+        "AtenLinearMatVec_basic",
+        "AtenLinearVecMatBias_basic",
         "MaxPool1dEmptyStrideStaticModule_basic",
         "MaxPool1dStaticCeilModeTrueModule_basic",
         "MaxPool1dStaticModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index 3b9f022fa..6c556a07a 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -622,6 +622,131 @@ def AtenMatmulQMixedSigni8Transpose_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class AtenLinear1D(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3], torch.float32, True),
+            ([3], torch.float32, True),
+        ]
+    )
+    def forward(self, a, b):
+        return torch.ops.aten.linear(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenLinear1D())
+def AtenLinear1D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3))
+
+
+# ==============================================================================
+
+
+class AtenLinearMatVec(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 4], torch.float32, True),
+            ([4], torch.float32, True),
+        ]
+    )
+    def forward(self, a, b):
+        return torch.ops.aten.linear(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenLinearMatVec())
+def AtenLinearMatVec_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4), tu.rand(4))
+
+
+# ==============================================================================
+
+
+class AtenLinearVecMat(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4], torch.float32, True),
+            ([3, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, a, b):
+        return torch.ops.aten.linear(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenLinearVecMat())
+def AtenLinearVecMat_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4), tu.rand(3, 4))
+
+
+class AtenLinearVecMatBias(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4], torch.float32, True),
+            ([3, 4], torch.float32, True),
+            ([3], torch.float32, True),
+        ]
+    )
+    def forward(self, a, b, c):
+        return torch.ops.aten.linear(a, b, c)
+
+
+@register_test_case(module_factory=lambda: AtenLinearVecMatBias())
+def AtenLinearVecMatBias_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4), tu.rand(3, 4), tu.rand(3))
+
+
+# ==============================================================================
+
+
+class AtenLinear2D(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 4], torch.float32, True),
+            ([5, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, a, b):
+        return torch.ops.aten.linear(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenLinear2D())
+def AtenLinear2D_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4), tu.rand(5, 4))
+
+
+# ==============================================================================
+
+
+class AtenLinear3DBias(torch.nn.Module):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 6, 4], torch.float32, True),
+            ([5, 4], torch.float32, True),
+            ([5], torch.float32, True),
+        ]
+    )
+    def forward(self, a, b, c):
+        return torch.ops.aten.linear(a, b, c)
+
+
+@register_test_case(module_factory=lambda: AtenLinear3DBias())
+def AtenLinear3DBias_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 6, 4), tu.rand(5, 4), tu.rand(5))
+
+
+# ==============================================================================
+
+
 class AtenLinalgCrossInt(torch.nn.Module):
     @export
     @annotate_args(
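Note (not part of the patch): the shape combinations exercised by the new e2e tests follow the usual aten.linear semantics, which is also what the revised decomposition emits: a matmul against the transposed weight when the weight is rank 2, a plain matmul (mat-vec or dot product) when the weight is rank 1, and an elementwise add of the optional bias with alpha == 1. Below is a minimal PyTorch sketch for cross-checking those cases; the helper name linear_reference is illustrative only and appears nowhere in the patch.

import torch

def linear_reference(x, w, b=None):
    # Mirrors the decomposition: matmul against the transposed weight
    # (no transpose for a rank-1 weight), then an elementwise bias add.
    y = torch.matmul(x, w.t() if w.dim() == 2 else w)
    return y if b is None else y + b

# Shape combinations matching the new test cases (bias only where the test has one).
cases = [
    (torch.rand(3), torch.rand(3), None),                    # AtenLinear1D
    (torch.rand(3, 4), torch.rand(4), None),                 # AtenLinearMatVec
    (torch.rand(4), torch.rand(3, 4), None),                 # AtenLinearVecMat
    (torch.rand(4), torch.rand(3, 4), torch.rand(3)),        # AtenLinearVecMatBias
    (torch.rand(3, 4), torch.rand(5, 4), None),              # AtenLinear2D
    (torch.rand(3, 6, 4), torch.rand(5, 4), torch.rand(5)),  # AtenLinear3DBias
]
for a, w, b in cases:
    assert torch.allclose(torch.ops.aten.linear(a, w, b), linear_reference(a, w, b))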