From 99093d0623c3d3e6fb961ad0c8fda5ee9fc264fe Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Wed, 18 May 2022 21:59:04 +0530
Subject: [PATCH] [TORCH] Add decomposition of `aten.linear` op

This commit adds decomposition of the `aten.linear` op into `aten.matmul`
and `aten.add` ops. Due to limited support for dynamic dimensions in the
TOSA backend, this decomposition is currently disabled for the TOSA
backend.

Signed-off-by: Gaurav Shukla
---
 lib/Conversion/TorchToLinalg/Linear.cpp       | 187 ++----------------
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  51 +++++
 python/torch_mlir/__init__.py                 |   2 +-
 3 files changed, 67 insertions(+), 173 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 9482187e5..9c01db32c 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -296,6 +296,21 @@ public:
             op, "unable to perform broadcast operation");
       }
 
+      if (maxRank == 3) {
+        Value zeroTensor = createZeroInitTensor(
+            rewriter, loc,
+            ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
+            elementType);
+        Value matmul =
+            rewriter
+                .create<linalg::BatchMatmulOp>(
+                    loc, zeroTensor.getType(),
+                    ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
+                .getResult(0);
+        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
+        return success();
+      }
+
       // Check if the result of the matrix multiplication has more than one
       // dynamic batch dimensions.
       ArrayRef<int64_t> batchDimsInt = resultType.getShape().drop_back(2);
@@ -454,176 +469,6 @@ public:
 };
 } // namespace
 
-namespace {
-// See comments at in convertMmOp and the heading for this section for general
-// considerations. This function needs to be auto-generated.
-class ConvertAtenLinearOp : public OpConversionPattern<AtenLinearOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenLinearOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *context = op->getContext();
-    Location loc = op->getLoc();
-    Value input = adaptor.input();
-    Value weight = adaptor.weight();
-    Value bias = adaptor.bias();
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-      return failure();
-    auto inputType = input.getType().cast<RankedTensorType>();
-    auto weightType = weight.getType().cast<RankedTensorType>();
-
-    if (inputType.getRank() != 2 && inputType.getRank() != 3) {
-      return rewriter.notifyMatchFailure(
-          op, "expected input to be rank 2 or rank 3");
-    }
-
-    if (!bias.getType().isa<Torch::NoneType>()) {
-      auto biasType = bias.getType().cast<RankedTensorType>();
-      // Only handle the case of rank 2 `weight` for now.
-      // TODO: Insert the appropriate reshape to collapse any leading dimensions.
-      if (weightType.getRank() != 2 || biasType.getRank() != 1) {
-        return rewriter.notifyMatchFailure(
-            op, "expected weight to be rank 2 and bias to be rank 1");
-      }
-      // TODO: Handle type promotion. What are ATen's promotion rules?
-      if (inputType.getElementType() != weightType.getElementType() ||
-          inputType.getElementType() != biasType.getElementType()) {
-        return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
-      }
-      // TODO: We can handle a static size 1 here at some complexity cost, but the
-      // dynamic case is not representable in linalg. We don't handle either for
-      // now. Biases are generally statically shaped for most models (since for
-      // inference they are constants, and for training they don't change shape
-      // typically), so this is not too constraining.
-    auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
-    if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: size-1 broadcasting for aten::LinearOp");
-    }
-
-
-    Value batchDim = nullptr;
-    int restDim = 0;
-    if (inputType.getRank() == 3) {
-      batchDim = getDimOp(rewriter, loc, input, 0);
-      restDim = 1;
-    }
-
-    Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0);
-    Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
-    Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
-    Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
-    Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
-    rewriter.create<cf::AssertOp>(
-        loc, contractingDimEqual,
-        rewriter.getStringAttr(
-            "mismatching contracting dimension for aten.linear"));
-
-    if (!bias.getType().isa<Torch::NoneType>()) {
-      Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
-      // Here we take advantage of ruling out the size-1 case above.
-      // In the static-size-1 case, we will not emit this check at all.
-      Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
-      rewriter.create<cf::AssertOp>(
-          loc, biasSizeCorrect,
-          rewriter.getStringAttr("mismatching bias size for aten.linear"));
-    }
-
-    Value initTensor;
-    SmallVector<AffineMap> broadcastIndexingMaps;
-    Value transposedWeightInitTensor;
-    if (inputType.getRank() > 2) {
-      initTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, ValueRange{batchDim, inputDim0, weightDim0},
-          inputType.getElementType());
-      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, ValueRange{batchDim, weightDim1, weightDim0},
-          weightType.getElementType());
-      broadcastIndexingMaps = {
-          AffineMap::get(
-              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
-              {rewriter.getAffineDimExpr(1 + restDim)}, context),
-          rewriter.getMultiDimIdentityMap(inputType.getRank())};
-    } else {
-      initTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, ValueRange{inputDim0, weightDim0},
-          inputType.getElementType());
-      transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
-      broadcastIndexingMaps = {
-          AffineMap::get(
-              /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
-              {rewriter.getAffineDimExpr(1)}, context),
-          rewriter.getMultiDimIdentityMap(inputType.getRank())};
-    }
-
-    SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
-    Value broadcasted;
-    if (!bias.getType().isa<Torch::NoneType>()) {
-      broadcasted =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, initTensor.getType(), bias, initTensor,
-                  /*indexingMaps=*/broadcastIndexingMaps,
-                  /*iteratorTypes=*/iteratorTypes,
-                  [](OpBuilder &b, Location loc, ValueRange args) {
-                    b.create<linalg::YieldOp>(loc, args[0]);
-                  })
-              .getResult(0);
-    } else {
-      Type elementType =
-          input.getType().cast<RankedTensorType>().getElementType();
-      Value c0float = rewriter.create<arith::ConstantOp>(
-          loc, FloatAttr::get(elementType, 0.0));
-      broadcasted = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
-                        .getResult(0);
-    }
-    // We need a matmul with dimension ordering (N, K) * (M, K), so transpose
-    // the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
-    // TODO: This whole aten.linear lowering should eventually be generated from
-    // a single linalg ODS generator statement. Both the bias and matmul part.
-    SmallVector<AffineMap> transposeIndexingMaps = {
-        AffineMap::get(
-            /*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
-            {rewriter.getAffineDimExpr(1 + restDim),
-             rewriter.getAffineDimExpr(0 + restDim)},
-            context),
-        rewriter.getMultiDimIdentityMap(inputType.getRank())};
-    Value transposedWeights =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, transposedWeightInitTensor.getType(), weight,
-                transposedWeightInitTensor,
-                /*indexingMaps=*/transposeIndexingMaps,
-                /*iteratorTypes=*/iteratorTypes,
-                [](OpBuilder &b, Location loc, ValueRange args) {
-                  b.create<linalg::YieldOp>(loc, args[0]);
-                })
-            .getResult(0);
-    Value matmul;
-    if (batchDim)
-      matmul = rewriter
-                   .create<linalg::BatchMatmulOp>(
-                       loc, broadcasted.getType(),
-                       ValueRange{input, transposedWeights}, broadcasted)
-                   .getResult(0);
-    else
-      matmul = rewriter
-                   .create<linalg::MatmulOp>(
-                       loc, broadcasted.getType(),
-                       ValueRange{input, transposedWeights}, broadcasted)
-                   .getResult(0);
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 public:
@@ -996,8 +841,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
   patterns.add(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
-  target.addIllegalOp<AtenLinearOp>();
-  patterns.add<ConvertAtenLinearOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
 }
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 98a7a4a19..2c09d6dd9 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2060,6 +2060,55 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.linear` op into `aten.matmul` and `aten.add` ops.
+class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLinearOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.input();
+    Value weight = op.weight();
+    Value bias = op.bias();
+
+    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
+      return rewriter.notifyMatchFailure(
+          op, "expected input to be rank 2 or greater");
+
+    BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
+    // `weight` must be a rank 2 matrix.
+    if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
+      return rewriter.notifyMatchFailure(op, "expected weight to be rank 2");
+
+    SmallVector<int64_t> transposeShape =
+        llvm::to_vector(llvm::reverse(weightType.getSizes()));
+    Type transposeType = weightType.getWithSizesAndDtype(
+        llvm::makeArrayRef(transposeShape), weightType.getDtype());
+    Value transposeWeight =
+        rewriter.create<AtenTOp>(loc, transposeType, weight);
+
+    Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
+                                                 transposeWeight);
+    if (bias.getType().isa<Torch::NoneType>()) {
+      rewriter.replaceOp(op, matmul);
+      return success();
+    }
+
+    BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
+    if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
+
+    Value alpha =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
+    rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
+                                                 op.bias(), alpha);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.full_like` op into `aten.empty_like` and `aten.fill` ops.
 class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
@@ -2837,6 +2886,8 @@ public:
   target.addIllegalOp();
   patterns.add(context);
   target.addIllegalOp();
+  patterns.add<DecomposeAtenLinearOp>(context);
+  target.addIllegalOp<AtenLinearOp>();
   patterns.add(context);
   target.addIllegalOp();
   patterns.add(context);
diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py
index 26e6c3f51..099e50234 100644
--- a/python/torch_mlir/__init__.py
+++ b/python/torch_mlir/__init__.py
@@ -126,7 +126,7 @@ class TensorPlaceholder:
 # ops in the backend contract, and move these lists somewhere deeper in the
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm'],
+    OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm','torch.aten.linear'],
     OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
     OutputType.MHLO: [],
 }
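As a quick sanity check of the identity this decomposition relies on, namely linear(x, W, b) = matmul(x, transpose(W)) + b, the following PyTorch eager-mode snippet reproduces it. This is only an illustrative sketch and not part of the patch; the tensor shapes and tolerance below are arbitrary choices for the example.

import torch

# Rank-3 input, rank-2 weight, rank-1 bias: the shapes the decomposition
# accepts (input rank >= 2, weight rank == 2, bias rank == 1).
x = torch.randn(4, 3, 8)
w = torch.randn(16, 8)
b = torch.randn(16)

reference = torch.nn.functional.linear(x, w, b)
# What the decomposition emits: transpose the weight, matmul, then add the bias.
decomposed = torch.matmul(x, w.t()) + b

assert torch.allclose(reference, decomposed, atol=1e-5)
print(decomposed.shape)  # torch.Size([4, 3, 16])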