diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp
index 08da5c073..e9983dd0f 100644
--- a/lib/Conversion/TorchToMhlo/Linear.cpp
+++ b/lib/Conversion/TorchToMhlo/Linear.cpp
@@ -373,10 +373,10 @@ public:
     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
 
-    if (lhsRank < 1)
-      return op.emitError("aten.Linear called but input rank 0");
-    if (rhsRank != 2)
-      return op.emitError("aten.Linear called but weight rank not 2");
+    if (lhsRank != 2 && lhsRank != 3)
+      return op.emitError("aten.Linear called but input rank not 2 or 3");
+    if (rhsRank != 2 && rhsRank != 3)
+      return op.emitError("aten.Linear called but weight rank not 2 or 3");
 
     return success();
   }
@@ -406,59 +406,33 @@ public:
     auto lhsTy = lhs.getType().cast<RankedTensorType>();
     auto rhsTy = rhs.getType().cast<RankedTensorType>();
 
-    auto lhsRank = lhsTy.getRank();
+    auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
+                                rhsTy.getRank() - lhsTy.getRank());
 
-    auto loc = op->getLoc();
-    Value dotLhs;
-    SmallVector<Value> resultDims;
-    // vector * matrix or matrix * matrix can directly use mhlo.dot_general
-    if (lhsTy.getRank() <= 2) {
-      dotLhs = lhs;
-    } else {
-      // Broadcast weight and then use bmm would lead to too much data copy,
-      // and more compute, then decreace the performance
-      // Instead, reshape input to 2-D tensor, then use dot to perform
-      // matrix-matrix multiply, and finnaly reshape to the output shape,
-      // would get better performance
+    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
+    getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
+                    options.dimSizeIndexBits);
+    auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
+    auto nBatchDims = resultRank - 2;
+    auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
 
-      // [x_1, x_2, ..., x_n, in_features] * [in_features, out_features]
-      // -> [x_1 * x_2 * ... * x_n , in_features] * [in_features, out_features]
-      auto dotLhsTy = RankedTensorType::get(
-          {ShapedType::kDynamicSize, lhsTy.getShape()[lhsRank - 1]},
-          lhsTy.getElementType());
-      const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
-      Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
-      Value numel = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(intType, 1));
+    auto lhsResultDim = nBatchDims;
+    auto rhsResultDim = nBatchDims + 1;
+    auto lhsContractingDim = nBatchDims + 1;
+    auto rhsContractingDim = nBatchDims;
 
-      for (int i = 0; i < lhsRank - 1; ++i) {
-        Value dimValue = rewriter.create<tensor::DimOp>(loc, lhs, i);
-        resultDims.push_back(dimValue);
-        numel = rewriter.create<arith::MulIOp>(
-            loc, numel,
-            rewriter.create<arith::IndexCastOp>(loc, intType, dimValue));
-      }
-      Value lhsLastRankDim = rewriter.create<arith::IndexCastOp>(
-          loc, intType, rewriter.create<tensor::DimOp>(loc, lhs, lhsRank - 1));
-      resultDims.push_back(rewriter.create<tensor::DimOp>(loc, rhs, 1));
-      Value reshapeDim =
-          rewriter
-              .create<tensor::FromElementsOp>(
-                  op->getLoc(), ValueRange{numel, lhsLastRankDim})
-              .getResult();
-      dotLhs = rewriter.create<mhlo::DynamicReshapeOp>(loc, dotLhsTy, lhs,
-                                                       reshapeDim);
-    }
-    Value matmulOutput =
-        rewriter.create<mhlo::DotOp>(loc, dotLhs, rhs, nullptr);
     auto outTy =
-        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
-    // reshape to [x_1, x_2, ..., x_n, out_features]
-    if (dotLhs != lhs) {
-      matmulOutput = rewriter.create<mhlo::DynamicReshapeOp>(
-          loc, outTy, matmulOutput,
-          rewriter.create<tensor::FromElementsOp>(loc, resultDims));
-    }
+        castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
+                           lhsContractingDim, rhsContractingDim);
+    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
+        mhlo::DotDimensionNumbersAttr::get(
+            rewriter.getContext(),
+            /*lhsBatchingDimensions=*/batchDims,
+            /*rhsBatchingDimensions=*/batchDims,
+            /*lhsContractingDimensions=*/{lhsContractingDim},
+            /*rhsContractingDimensions=*/{rhsContractingDim});
+    Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
+        op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
 
     Value matmulPlusBias = matmulOutput;
     if (!biasTy.template isa<Torch::NoneType>()) {
@@ -469,7 +443,9 @@ public:
                           .getResult();
     }
 
-    rewriter.replaceOp(op, matmulPlusBias);
+    auto resultTy =
+        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, matmulPlusBias);
     return success();
   }
};
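
The net effect of the Linear.cpp change: instead of flattening a rank-n input to 2-D, running mhlo.dot, and reshaping back, the lowering now broadcasts both operands to a common batch shape (getBmmBroadcast) and emits a single mhlo.dot_general whose batching dimensions cover all leading dims. As a rough, hand-written sketch of the IR this should produce for the plain 2-D case (value names invented for illustration, generic op syntax used; not generated output):

  // Weight [out, in] is transposed to [in, out] before the product, as before.
  %wt = "mhlo.transpose"(%weight) {permutation = dense<[1, 0]> : tensor<2xi64>}
      : (tensor<5x3xf32>) -> tensor<3x5xf32>
  // No batching dims in the 2-D case; contract input dim 1 with weight dim 0.
  %mm = "mhlo.dot_general"(%input, %wt) {
      dot_dimension_numbers = #mhlo.dot<
        lhs_batching_dimensions = [],
        rhs_batching_dimensions = [],
        lhs_contracting_dimensions = [1],
        rhs_contracting_dimensions = [0]>}
      : (tensor<4x3xf32>, tensor<3x5xf32>) -> tensor<4x5xf32>
  // Bias is still added via the implicitly broadcasting CHLO op.
  %out = "chlo.broadcast_add"(%mm, %bias)
      : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4x5xf32>

For a rank-3 input the same dot_general simply gains batching dimension 0 on both sides, which is what the batchDims/llvm::seq computation sets up.
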
diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToMhlo/linear.mlir
index ca0bed695..165c874ea 100644
--- a/test/Conversion/TorchToMhlo/linear.mlir
+++ b/test/Conversion/TorchToMhlo/linear.mlir
@@ -499,60 +499,3 @@ func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7,7],f32>
   %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int2 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,15,15],f32>
   return %3 : !torch.vtensor<[1,4,15,15],f32>
 }
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.linear(
-// CHECK-NOT: mhlo.dynamic_reshape
-// CHECK: mhlo.transpose
-// CHECK: mhlo.dot
-// CHECK: chlo.broadcast_add
-func.func @torch.aten.linear(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[4,5],f32> {
-  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[4,5],f32>
-  return %1 : !torch.vtensor<[4,5],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.linear$nobias(
-// CHECK-NOT: mhlo.dynamic_reshape
-// CHECK: mhlo.transpose
-// CHECK: mhlo.dot
-// CHECK-NOT: chlo.broadcast_add
-func.func @torch.aten.linear$nobias(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[4,5],f32> {
-  %none = torch.constant.none
-  %1 = torch.aten.linear %arg0, %arg1, %none : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.none -> !torch.vtensor<[4,5],f32>
-  return %1 : !torch.vtensor<[4,5],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.linear$dynamic(
-// CHECK: mhlo.transpose
-// CHECK: arith.muli
-// CHECK: arith.muli
-// CHECK: tensor.from_elements
-// CHECK: mhlo.dynamic_reshape
-// CHECK: mhlo.dot
-// CHECK: mhlo.dynamic_reshape
-// CHECK: chlo.broadcast_add
-func.func @torch.aten.linear$dynamic(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,5],f32> {
-  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,5],f32>
-  return %1 : !torch.vtensor<[?,?,5],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.linear$dynamic4D(
-// CHECK: mhlo.transpose
-// CHECK: arith.muli
-// CHECK: arith.muli
-// CHECK: tensor.from_elements
-// CHECK: mhlo.dynamic_reshape
-// CHECK: mhlo.dot
-// CHECK: mhlo.dynamic_reshape
-// CHECK: chlo.broadcast_add
-func.func @torch.aten.linear$dynamic4D(%arg0: !torch.vtensor<[?,?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,?,5],f32> {
-  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,?,5],f32>
-  return %1 : !torch.vtensor<[?,?,?,5],f32>
-}
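
The aten.linear FileCheck tests are dropped here rather than updated. If a static-shape test were re-added against the new lowering, it would presumably match on mhlo.dot_general instead of a reshape-wrapped mhlo.dot; a hypothetical sketch (not part of this patch, CHECK lines are a guess at the new output):

  // CHECK-LABEL: func.func @torch.aten.linear(
  // CHECK-NOT: mhlo.dynamic_reshape
  // CHECK: mhlo.transpose
  // CHECK: mhlo.dot_general
  // CHECK: chlo.broadcast_add
  func.func @torch.aten.linear(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[4,5],f32> {
    %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[4,5],f32>
    return %1 : !torch.vtensor<[4,5],f32>
  }
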