//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { static void getZeroPoint(Value value, Value &zeropoint) { if (auto make = value.getDefiningOp()) { zeropoint = make.getZeroPoint(); } } // for uint8 types, we shift down by 128 so that we can faithfully // represent the quantization with signed i8 types. static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, Value &zp, bool isUnsignedType, int64_t numBits) { if (!isUnsignedType) return; int64_t minSI = -(1 << (numBits - 1)); Value minSIValue = rewriter.create(loc, minSI, 32); zp = rewriter.create(loc, zp, minSIValue); minSIValue = rewriter.create(loc, minSI, numBits); arg = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, ValueRange{arg}, cast(arg.getType()).getElementType(), [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { Value result = rewriter.create(loc, payloadArgs[0], minSIValue); b.create(loc, result); }); } static Value transposeValue(Location loc, Value value, ArrayRef perms, PatternRewriter &rewriter) { auto valueTy = cast(value.getType()); auto inShape = valueTy.getShape(); llvm::SmallVector outShape; llvm::SmallVector dynDims; for (size_t i = 0; i < perms.size(); ++i) { outShape.push_back(inShape[perms[i]]); if (ShapedType::isDynamic(inShape[perms[i]])) { dynDims.push_back(rewriter.create(loc, value, perms[i])); } } auto outTy = RankedTensorType::get(outShape, valueTy.getElementType()); Value empty = rewriter.create(loc, outTy, dynDims); Value transpose = rewriter.create(loc, value, empty, perms) ->getResult(0); return transpose; } class ConvertAtenMmOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenMmOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value lhs = adaptor.getSelf(); Value rhs = adaptor.getMat2(); // A user can write an errorneous program where `aten.mm` is in fact called // with operands of invalid rank or dtype. We cannot convert to linalg in // this case or we will get a verifier error, which corresponds to breaking // of *internal* compiler invariants, and for a user manifests as a compiler // crash in the worst case (such as we try to canonicalize/fold/print the // invalid op before the verifier gets to see it -- also release builds of a // mature compiler usually have the verifier turned off for compile time // reasons). // // The compiler cannot crash even if the user wrote an erroneous program! if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); RankedTensorType lhsType = cast(lhs.getType()); RankedTensorType rhsType = cast(rhs.getType()); if (lhsType.getRank() != 2 || rhsType.getRank() != 2) { return rewriter.notifyMatchFailure( op, "expected both operands to aten.mm to be rank 2"); } ValueTensorType lhsTorchType = cast(op.getSelf().getType()); ValueTensorType rhsTorchType = cast(op.getMat2().getType()); Value lhsZeroPoint, rhsZeroPoint; getZeroPoint(op.getSelf(), lhsZeroPoint); getZeroPoint(op.getMat2(), rhsZeroPoint); if (static_cast(lhsZeroPoint) != static_cast(rhsZeroPoint)) { return rewriter.notifyMatchFailure( op, "unsupported: aten.mm with mixed quantization"); } if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) { if (!lhsZeroPoint) { return rewriter.notifyMatchFailure( op, "unsupported: aten.mm with different input element types"); } // Allows quantized types to mismatch since they will be cast to the same // type. } bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType); bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType); Value lhsDim0 = rewriter.create(loc, lhs, 0); Value rhsDim1 = rewriter.create(loc, rhs, 1); if (!isAssumingStrictSymbolicShapes(rewriter)) { Value lhsDim1 = rewriter.create(loc, lhs, 1); Value rhsDim0 = rewriter.create(loc, rhs, 0); Value contractingDimEqual = rewriter.create( loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); rewriter.create( loc, contractingDimEqual, rewriter.getStringAttr( "mismatching contracting dimension for torch.aten.mm")); } auto resultTy = cast(op.getType()); auto resultDTy = resultTy.toBuiltinTensor().getElementType(); Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = cast(newResultType).getElementType(); auto accumulatorDType = getDefaultAccType(rewriter, resultDTy); if (accumulatorDType != resultDTy) { elementType = accumulatorDType; } Value zeroFill = createZeroInitTensor( rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); Value matmul; if (lhsZeroPoint) { lhsZeroPoint = typeConverter->materializeTargetConversion( rewriter, loc, getTypeConverter()->convertType(lhsZeroPoint.getType()), lhsZeroPoint); rhsZeroPoint = typeConverter->materializeTargetConversion( rewriter, loc, getTypeConverter()->convertType(rhsZeroPoint.getType()), rhsZeroPoint); lhsZeroPoint = rewriter.create( loc, rewriter.getI32Type(), lhsZeroPoint); rhsZeroPoint = rewriter.create( loc, rewriter.getI32Type(), rhsZeroPoint); // change uint8 quantization -> int8 quantization int64_t numBits = cast(lhsType.getElementType()).getWidth(); signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits); numBits = cast(rhsType.getElementType()).getWidth(); signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits); matmul = rewriter .create( loc, zeroFill.getType(), ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill) .getResult(0); } else if (isUnsigned) { matmul = rewriter .create( loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill) .getResult(0); } else { matmul = rewriter .create(loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill) .getResult(0); } if (accumulatorDType != resultDTy) { Type resultElementType = cast(newResultType).getElementType(); matmul = torch_to_linalg::convertTensorToElementType( rewriter, loc, matmul, resultElementType); } // When constructed with just dynamic sizes, EmptyOp will have a result // type which has all `?`'s for dimensions, which might not be the result // type of `op`. The constraints on later linalg ops means that the result // of the MatmulOp will have this type too. So cast it to the desired type // so that in the end we have the original result type. rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } }; } // namespace namespace { class ConvertAtenFlipOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenFlipOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); auto selfRank = cast(adaptor.getSelf().getType()).getRank(); Type elementType = cast(adaptor.getSelf().getType()).getElementType(); Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector axis; if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis))) return rewriter.notifyMatchFailure(op, "only constant dim lists supported"); for (unsigned i = 0, e = axis.size(); i < e; i++) { axis[i] = toPositiveDim(axis[i], selfRank); if (!isValidDim(axis[i], selfRank)) { return rewriter.notifyMatchFailure(op, "axis is statically invalid"); } } // Only used to calculate flipped values, i.e. those on the flip axes. Other // dims won't be used. SmallVector dims = getTensorSizes(rewriter, loc, self); for (auto flipDim : axis) dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); Value initTensor = createZeroInitTensor( rewriter, loc, getTensorSizes(rewriter, loc, self), elementType); SmallVector iteratorTypes( selfRank, utils::IteratorType::parallel); SmallVector indexingMaps( 2, AffineMap::getMultiDimIdentityMap(selfRank, context)); Value flipped = rewriter .create( loc, self.getType(), self, initTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { SmallVector indices; for (auto i = 0; i < selfRank; i++) indices.push_back(b.create(loc, i)); for (auto flipDim : axis) { indices[flipDim] = b.create( loc, dims[flipDim], indices[flipDim]); } Value res = b.create(loc, self, indices) .getResult(); b.create(loc, res); }) .getResult(0); rewriter.replaceOpWithNewOp(op, self.getType(), flipped); return success(); } }; } // namespace namespace { class ConvertAtenMatmulOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenMatmulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) { return failure(); } auto lhsType = cast(lhs.getType()); auto rhsType = cast(rhs.getType()); auto lhsTorchType = cast(op.getSelf().getType()); auto rhsTorchType = cast(op.getOther().getType()); // Get the rank of both matrix. unsigned lhsRank = lhsType.getRank(); unsigned rhsRank = rhsType.getRank(); Value lhsZeroPoint, rhsZeroPoint; getZeroPoint(op.getSelf(), lhsZeroPoint); getZeroPoint(op.getOther(), rhsZeroPoint); if (static_cast(lhsZeroPoint) != static_cast(rhsZeroPoint)) { return rewriter.notifyMatchFailure( op, "unsupported: aten.matmul with mixed quantization"); } bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType); bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType); if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) { // Allows quantized types to mismatch return rewriter.notifyMatchFailure( op, "unsupported: aten.matmul with different input element types"); } Type newResultType = getTypeConverter()->convertType(op.getType()); auto resultType = cast(newResultType); Type elementType = resultType.getElementType(); if (lhsZeroPoint) { // get each zero point ready to pass to a quantized_matmul lhsZeroPoint = typeConverter->materializeTargetConversion( rewriter, loc, getTypeConverter()->convertType(lhsZeroPoint.getType()), lhsZeroPoint); rhsZeroPoint = typeConverter->materializeTargetConversion( rewriter, loc, getTypeConverter()->convertType(rhsZeroPoint.getType()), rhsZeroPoint); lhsZeroPoint = rewriter.create( loc, rewriter.getI32Type(), lhsZeroPoint); rhsZeroPoint = rewriter.create( loc, rewriter.getI32Type(), rhsZeroPoint); // change uint8 quantization -> int8 quantization int64_t numBits = cast(lhsType.getElementType()).getWidth(); signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits); numBits = cast(rhsType.getElementType()).getWidth(); signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits); // for quantized vec-vec, vec-mat, and mat-vec cases, lower to // expand/collapse + quantized_matmul bool lhsVec = (lhsRank == 1 && rhsRank <= 2); bool rhsVec = (lhsRank <= 2 && rhsRank == 1); if (lhsVec || rhsVec) { SmallVector reassociation(1); reassociation[0].push_back(0); reassociation[0].push_back(1); if (lhsVec) { // unsqueeze lhs to a matrix int64_t lhsDim = lhsType.getShape()[0]; auto lhsUnsqueezeType = RankedTensorType::get( ArrayRef{1, lhsDim}, lhsType.getElementType()); lhs = rewriter.create(loc, lhsUnsqueezeType, lhs, reassociation); } if (rhsVec) { // unsqueeze rhs to a matrix int64_t rhsDim = rhsType.getShape()[0]; auto rhsUnsqueezeType = RankedTensorType::get( ArrayRef{rhsDim, 1}, rhsType.getElementType()); rhs = rewriter.create(loc, rhsUnsqueezeType, rhs, reassociation); } // get quantized_matmul and squeeze result Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1); checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0); Value zeroTensor = createZeroInitTensor( rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); Value matmul = rewriter .create( loc, zeroTensor.getType(), ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroTensor) .getResult(0); int64_t resultRank = resultType.getRank(); if (resultRank == 0) { // in vec-vec case, need to collapse result to a scalar reassociation.clear(); } matmul = rewriter.create( loc, resultType, matmul, reassociation); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } // the remaining quantized cases (Mat-Mat and broadcast -> BMM) are // covered in the relevant section below } // The different cases of torch_matmul op is mentioned here: // https://pytorch.org/docs/stable/generated/torch.matmul.html // First Case: Dot Product. if (lhsRank == 1 && rhsRank == 1) { Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0); Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType); Value dotProd = rewriter .create(loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, dotProd); return success(); } // Second Case: Vec-Mat Multiplication. if (lhsRank == 1 && rhsRank == 2) { Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1); checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0); Value zeroTensor = createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType); Value matmul = rewriter .create(loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } // Third Case: Matrix-Vec Multiplication. if (lhsRank == 2 && rhsRank == 1) { Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0); Value zeroTensor = createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType); Value matmul = rewriter .create(loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } // Fourth Case: Mat-Mat Multiplication. if (lhsRank == 2 && rhsRank == 2) { Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1); checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0); Value zeroTensor = createZeroInitTensor( rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); Value matmul; if (lhsZeroPoint) { matmul = rewriter .create( loc, zeroTensor.getType(), ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroTensor) .getResult(0); } else { matmul = rewriter .create(loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor) .getResult(0); } rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } // Fifth Case: Batch-Matrix Multiplication. // TODO: Handle batch matrix multiplication when one of the matrix is unity // rank and the other has batch dimension. if (lhsRank > 1 && rhsRank > 1) { unsigned maxRank = std::max(lhsRank, rhsRank); unsigned minRank = std::min(lhsRank, rhsRank); unsigned batchRank = maxRank - 2; // At least one of the matrix must have rank greater than 2. if (batchRank <= 0) { return rewriter.notifyMatchFailure(op, "expected batch dimensions"); } // The `broadcastedBatchShape` contains batch dimensions of the resultant // matrix. SmallVector broadcastedBatchShape(batchRank); Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs; Value maxDim; // Compute broadcasted batch dimensions if the batch dimensions of // the matrices are broadcastable. for (unsigned i = 1; i <= batchRank; i++) { if (i <= minRank - 2) { Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i); Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i); maxDim = rewriter.createOrFold(loc, lhsDim, rhsDim); } else { maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i); } broadcastedBatchShape[batchRank - i] = maxDim; } Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2); Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1); Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2); Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1); checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0); // Compute broadcasted shape of both the matrices in integer format. SmallVector lhsBroadcastToShape(broadcastedBatchShape); lhsBroadcastToShape.push_back(lhsDim0); lhsBroadcastToShape.push_back(lhsDim1); SmallVector rhsBroadcastToShape(broadcastedBatchShape); rhsBroadcastToShape.push_back(rhsDim0); rhsBroadcastToShape.push_back(rhsDim1); for (unsigned i = 0; i < maxRank; i++) { lhsBroadcastToShape[i] = castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]); rhsBroadcastToShape[i] = castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]); } // Broadcast the batch dimensions of both the matrices. Value broadcastedLhs, broadcastedRhs; // TODO: Improve usage of static shape information. SmallVector lhsTargetShape(lhsBroadcastToShape.size(), ShapedType::kDynamic); auto lhsBroadcastType = RankedTensorType::get( lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding()); if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType, broadcastedLhs))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } SmallVector rhsTargetShape(rhsBroadcastToShape.size(), ShapedType::kDynamic); auto rhsBroadcastType = RankedTensorType::get( rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding()); if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType, broadcastedRhs))) { return rewriter.notifyMatchFailure( op, "unable to perform broadcast operation"); } if (maxRank == 3) { Value zeroTensor = createZeroInitTensor( rewriter, loc, ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1}, elementType); Value matmul; if (lhsZeroPoint) { matmul = rewriter .create( loc, zeroTensor.getType(), ValueRange{broadcastedLhs, broadcastedRhs, lhsZeroPoint, rhsZeroPoint}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } matmul = rewriter .create( loc, zeroTensor.getType(), ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } // Check if the result of the matrix multiplication has more than one // dynamic batch dimensions. SmallVector batchDimsInt = makeShapeTorchCompatible(resultType.getShape()); batchDimsInt.pop_back(); batchDimsInt.pop_back(); bool multipleDynamicBatchDims = llvm::count(batchDimsInt, kUnknownSize) > 1; // TODO: Lowering to `linalg.BatchMatmul` is only possible when there is // at most one dynamic batch dimension due to limited support of the // `tensor.ExpandShape` op. if (!multipleDynamicBatchDims) { // Collapse the batch dimensions into one dimension. The resultant rank // will always be 3. SmallVector reassociation(3); for (unsigned i = 0, j = 0; i < maxRank; i++) { if (i >= batchRank) j++; reassociation[j].push_back(i); } Value collapsedLhs = rewriter.create( op->getLoc(), broadcastedLhs, reassociation); Value collapsedRhs = rewriter.create( op->getLoc(), broadcastedRhs, reassociation); // Compute the result shape after collapsing the batch dimensions. SmallVector collapsedResultShape; collapsedResultShape.push_back(broadcastedBatchShape[0]); for (unsigned i = 1; i < batchRank; i++) { collapsedResultShape[0] = rewriter.createOrFold( loc, collapsedResultShape[0], broadcastedBatchShape[i]); } collapsedResultShape.push_back(lhsDim0); collapsedResultShape.push_back(rhsDim1); SmallVector updatedCollapseResultShape = getAsOpFoldResult(collapsedResultShape); Value initTensor = rewriter.create( loc, updatedCollapseResultShape, elementType); Value c0 = rewriter.create( loc, rewriter.getZeroAttr(elementType)); Value zeroTensor = rewriter.create(loc, c0, initTensor).getResult(0); Value batchMatMul; if (lhsZeroPoint) { batchMatMul = rewriter .create( loc, zeroTensor.getType(), ValueRange{collapsedLhs, collapsedRhs, lhsZeroPoint, rhsZeroPoint}, zeroTensor) .getResult(0); } else { batchMatMul = rewriter .create( loc, zeroTensor.getType(), ValueRange{collapsedLhs, collapsedRhs}, zeroTensor) .getResult(0); } Value expandResult = rewriter.create( loc, resultType, batchMatMul, reassociation); rewriter.replaceOpWithNewOp(op, newResultType, expandResult); return success(); } SmallVector lhsExpr; SmallVector rhsExpr; SmallVector outExpr; SmallVector iteratorTypes( batchRank, utils::IteratorType::parallel); for (unsigned i = 0; i < batchRank; i++) { lhsExpr.push_back(rewriter.getAffineDimExpr(i)); rhsExpr.push_back(rewriter.getAffineDimExpr(i)); outExpr.push_back(rewriter.getAffineDimExpr(i)); } lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank), rewriter.getAffineDimExpr(batchRank + 1)}); rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1), rewriter.getAffineDimExpr(batchRank + 2)}); outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank), rewriter.getAffineDimExpr(batchRank + 2)}); SmallVector resultShape(broadcastedBatchShape); resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1}); Value zeroTensor = createZeroInitTensor(rewriter, loc, resultShape, elementType); auto indexingMaps = AffineMap::inferFromExprList( {lhsExpr, rhsExpr, outExpr}, rewriter.getContext()); iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::parallel, utils::IteratorType::reduction, utils::IteratorType::parallel}); Value finalRes = rewriter .create( loc, zeroTensor.getType(), ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value l = args[0], r = args[1], res = args[2]; Value mul = b.create(loc, l, r); Value add = b.create(loc, mul, res); b.create(loc, add); }) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, finalRes); return success(); } return failure(); } }; } // namespace namespace { class ConvertAtenBmmOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenBmmOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value lhs = adaptor.getSelf(); Value rhs = adaptor.getMat2(); RankedTensorType lhsType = cast(lhs.getType()); RankedTensorType rhsType = cast(rhs.getType()); Type newResultType = getTypeConverter()->convertType(op.getType()); Type resultElementType = cast(newResultType).getElementType(); Type lhsElementType = cast(lhsType).getElementType(); Type rhsElementType = cast(rhsType).getElementType(); if (lhsType.getRank() != 3 || rhsType.getRank() != 3) { return rewriter.notifyMatchFailure( op, "expected both operands to aten.bmm to be rank 3"); } // Convert the inputs element type equivalent to the result' element type. if (lhsElementType != rhsElementType) { if (lhsElementType != resultElementType) { // True if the lhs element type is not equal to the result' element // type. lhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, lhs, resultElementType); } else { // True if the rhs element type is not equal to the result' element // type. rhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, rhs, resultElementType); } } Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); Value lhsDim2 = getDimOp(rewriter, loc, lhs, 2); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1); Value rhsDim2 = getDimOp(rewriter, loc, rhs, 2); // Check the batch numbers are equal. checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0); // Check the matrixs shapes are valid for mulplication. checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1); Value initTensor0 = createZeroInitTensor( rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType); Value bmm = rewriter .create(loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, bmm); return success(); } }; } // namespace namespace { class ConvertAtenConvolutionOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); Value input = adaptor.getInput(); /* in form of N*C*H*W */ Value weight = adaptor.getWeight(); /* in form of F*C*H*W */ Value bias = adaptor.getBias(); auto resultTy = cast(op.getType()); Value inputZp, weightZp; if (auto make = op.getInput() .getDefiningOp()) { input = make.getSelf(); inputZp = make.getZeroPoint(); input = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(input.getType()), input); inputZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(inputZp.getType()), inputZp); } if (auto make = op.getWeight() .getDefiningOp()) { weight = make.getSelf(); weightZp = make.getZeroPoint(); weight = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(weight.getType()), weight); weightZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(weightZp.getType()), weightZp); } if (static_cast(inputZp) != static_cast(weightZp)) { return rewriter.notifyMatchFailure( op, "lhs and rhs of convolution must either be both int or fp"); } if (inputZp && weightZp && !isa(bias.getType())) { auto biasDTy = cast(bias.getType()).getElementType(); if (!biasDTy.isInteger(32)) { return rewriter.notifyMatchFailure( op, "quantized result ty should be i32 accumulator"); } } bool transposed = true; if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) return rewriter.notifyMatchFailure( op, "unimplemented: only constant transposed supported"); auto inputDTy = cast(input.getType()).getElementType(); auto weightDTy = cast(weight.getType()).getElementType(); auto resultDTy = resultTy.toBuiltinTensor().getElementType(); if (!isa(inputDTy) || !isa(weightDTy) || !isa(resultDTy)) return op.emitError("unimplemented: non-fp not-int type"); size_t inRank = cast(input.getType()).getRank(); size_t numSpatialDims = inRank - 2; if (numSpatialDims < 1 || numSpatialDims > 3) return rewriter.notifyMatchFailure( op, "unimplemented: only 1d-3d convolution currently supported"); Type intType = IntegerType::get(context, 64); auto castIndexToInt = [&](Value v) { return rewriter.create(loc, intType, v); }; SmallVector paddingIntValues; if (!getListConstructElements(op.getPadding(), paddingIntValues)) return rewriter.notifyMatchFailure( op, "only support padding from a list construct"); paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), paddingIntValues); SmallVector outputPaddingIntValues; if (!getListConstructElements(op.getOutputPadding(), outputPaddingIntValues)) return rewriter.notifyMatchFailure( op, "only support output_padding from a list construct"); outputPaddingIntValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), outputPaddingIntValues); SmallVector strideInts; if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) return rewriter.notifyMatchFailure(op, "only support constant int strides"); SmallVector dilationInts; if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); Value inBatch = getDimOp(rewriter, loc, input, 0); Value inChannels = getDimOp(rewriter, loc, input, 1); SmallVector inDims; for (size_t i = 2; i < inRank; i++) inDims.push_back(getDimOp(rewriter, loc, input, i)); Value weightBatch = getDimOp(rewriter, loc, weight, 0); Value weightChannels = getDimOp(rewriter, loc, weight, 1); SmallVector weightDims; for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); // Checks for valid group size int64_t groupSize; if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groupSize))) return rewriter.notifyMatchFailure(op, "only constant group size supported."); Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); auto validate = [&](Value toValidate, std::string err) { Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); Value inputValid = rewriter.create( loc, arith::CmpIPredicate::eq, c0, rewriter.create(loc, toValidate, groups)); rewriter.create(loc, inputValid, rewriter.getStringAttr(err)); }; validate(inChannels, "invalid: groups must divide input channel size evenly."); validate(weightBatch, "invalid: groups must divide weight batch size evenly."); SmallVector dilationIntValues = getAsConstantIntValues(rewriter, loc, dilationInts); SmallVector strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); // Pad the input tensor according to padding. SmallVector outDims{inBatch, weightBatch}; Value paddedInput; if (transposed) { if (!isa(inputDTy) || !isa(weightDTy) || !isa(resultDTy)) return rewriter.notifyMatchFailure( op, "transpose does not support non-fp type yet"); Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); Value c2 = rewriter.create(loc, rewriter.getIndexAttr(2)); // Transpose and flip weight SmallVector weightInitDims = getTensorSizes(rewriter, loc, weight); std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1); outDims[1] = weightInitDims[0]; Value weightInitTensor = createZeroInitTensor(rewriter, loc, weightInitDims, weightDTy); SmallVector iteratorTypes( inRank, utils::IteratorType::parallel); SmallVector indexingMaps{ AffineMap::getMultiDimIdentityMap(inRank, context)}; weight = rewriter .create( loc, weightInitTensor.getType(), ValueRange{}, weightInitTensor, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { SmallVector indices; for (size_t i = 0; i < inRank; i++) indices.push_back(b.create(loc, i)); std::iter_swap(indices.begin(), indices.begin() + 1); // Flip only the spatial dimensions (from 2 to inRank) for (size_t flipDim = 2; flipDim < inRank; flipDim++) { indices[flipDim] = b.create( loc, b.create( loc, weightInitDims[flipDim], c1), indices[flipDim]); } Value res = b.create(loc, weight, indices) .getResult(); b.create(loc, res); }) .getResult(0); // Calculate padded input size, allocate tensor SmallVector outerSizes{inBatch, inChannels}; SmallVector innerSizes{inBatch, inChannels}; SmallVector offsets{c0, c0}; for (size_t i = 0; i < numSpatialDims; i++) { Value innerSize = rewriter.create(loc, inDims[i], c1); innerSize = rewriter.create( loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i])); innerSize = rewriter.create(loc, innerSize, c1); Value offset = rewriter.create(loc, weightDims[i], c1); offset = rewriter.create( loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i])); offset = rewriter.create( loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i])); Value outerSize = rewriter.create(loc, offset, c2); outerSize = rewriter.create(loc, outerSize, innerSize); outerSize = rewriter.create( loc, outerSize, castIntToIndex(rewriter, loc, outputPaddingIntValues[i])); outerSizes.push_back(outerSize); offsets.push_back(offset); } // Allocate padded input tensor Value initTensor = createZeroInitTensor(rewriter, loc, outerSizes, inputDTy); // Insert input into allocated tensor SmallVector strideIndexValues{c1, c1}; for (auto stride : strideIntValues) strideIndexValues.push_back(castIntToIndex(rewriter, loc, stride)); SmallVector insertSizes = getTensorSizes(rewriter, loc, input); paddedInput = rewriter.create( loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input), initTensor, offsets, insertSizes, strideIndexValues); // Calculate output dims for (size_t i = 0; i < numSpatialDims; i++) outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps( rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], castIndexToInt(weightDims[i]), strideIntValues[i], outputPaddingIntValues[i])); // Set stride to 1 strideInts.clear(); strideInts.append(numSpatialDims, 1); } else { Value pad = inputZp; if (!pad) { if (isa(inputDTy)) pad = rewriter.create( op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0)); if (isa(inputDTy)) pad = rewriter.create( op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0)); } if (pad.getType() != inputDTy) { if (isa(inputDTy)) pad = rewriter.create(op.getLoc(), inputDTy, pad); if (isa(inputDTy)) pad = rewriter.create(op.getLoc(), inputDTy, pad); } // Pad input paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad); // Calculate output dims for (size_t i = 0; i < numSpatialDims; i++) outDims.push_back(torch_to_linalg::getOutputDimForConvOps( rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], castIndexToInt(weightDims[i]), strideIntValues[i])); } Type accumulatorDType = getDefaultAccType(rewriter, resultDTy); Value initTensor = rewriter.create( loc, getAsOpFoldResult(outDims), accumulatorDType); Value outputTensor; if (accumulatorDType != resultDTy && !bias.getType().isa()) bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias, accumulatorDType); if (bias.getType().isa()) { Value c0; if (isa(accumulatorDType)) { c0 = rewriter.create( loc, FloatAttr::get(accumulatorDType, 0.0)); } else if (isa(accumulatorDType)) { c0 = rewriter.create( loc, IntegerAttr::get(accumulatorDType, 0)); } outputTensor = rewriter.create(loc, c0, initTensor).getResult(0); } else { auto biasType = cast(bias.getType()); if (biasType.getRank() != 1) return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); auto resultRank = cast(initTensor.getType()).getRank(); SmallVector indexingMaps = { // bias is used to initialize the channels - dimension 1 of output AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context), rewriter.getMultiDimIdentityMap(resultRank)}; SmallVector iteratorTypes( resultRank, utils::IteratorType::parallel); outputTensor = rewriter .create( loc, initTensor.getType(), bias, initTensor, indexingMaps, iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); } auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); Value inputStride = rewriter.create(loc, inChannels, groups); Value weightStride = rewriter.create(loc, weightBatch, groups); SmallVector zeroOffsets(inRank, rewriter.create( loc, rewriter.getIndexAttr(0))); SmallVector unitStrides(inRank, rewriter.create( loc, rewriter.getIndexAttr(1))); SmallVector outDimSlice(outDims); outDimSlice[1] = weightStride; SmallVector inputSliceSizes{inBatch, inputStride}; inputSliceSizes.append(inDims); SmallVector weightSliceSizes{weightStride, weightChannels}; weightSliceSizes.append(weightDims); Value conv; // the code so far is able to respect all numSpatialDims // the code below this point is numSpatialDims specific and groupSize // specific // TODO: factor out the above code into a helper function, and then separate // convolution into: // - grouped 1d-3d // - grouped 1d-3d (quantized) // - ungrouped 1d-3d if (groupSize == 1 && !inputZp && !weightZp) { switch (numSpatialDims) { case 1: conv = rewriter .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight}, outputTensor, stridesAttr, dilationAttr) .getResult(0); break; case 2: conv = rewriter .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight}, outputTensor, stridesAttr, dilationAttr) .getResult(0); break; case 3: conv = rewriter .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight}, outputTensor, stridesAttr, dilationAttr) .getResult(0); break; default: return rewriter.notifyMatchFailure( op, "unimplemented: only 1D, 2D, and 3D convolution supported"); }; Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = cast(newResultType).getElementType(); conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } if (groupSize == 1 && inputZp && weightZp) { // The quantized version uses a different channel ordering so we need to // permute the tensors in order to use the existing path. We should // eventually directly support this channel ordering. llvm::SmallVector inPerms, weightPerms; inPerms.push_back(0); // N stays at the front for input. // Then we expect the spatial dimensions for (size_t i = 0; i < numSpatialDims; ++i) { inPerms.push_back(i + 2); weightPerms.push_back(i + 2); } inPerms.push_back(1); weightPerms.append({1, 0}); paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter); outputTensor = transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); switch (numSpatialDims) { case 2: conv = rewriter .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight, inputZp, weightZp}, outputTensor, stridesAttr, dilationAttr) .getResult(0); break; case 3: conv = rewriter .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight, inputZp, weightZp}, outputTensor, stridesAttr, dilationAttr) .getResult(0); break; default: return rewriter.notifyMatchFailure( op, "unimplemented: only 1D, 2D, and 3D convolution supported"); }; llvm::SmallVector outPerms; outPerms.push_back(0); outPerms.push_back(inPerms.size() - 1); for (size_t i = 0; i < numSpatialDims; ++i) { outPerms.push_back(i + 1); } conv = transposeValue(op.getLoc(), conv, outPerms, rewriter); Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = cast(newResultType).getElementType(); conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } if (inputZp || weightZp) return rewriter.notifyMatchFailure( op, "unimplemented: quantized grouped convolutions"); if (numSpatialDims != 2) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D grouped convolution supported"); // Special depthwise case auto inShape = makeShapeTorchCompatible( cast(input.getType()).getShape()); auto weightShape = makeShapeTorchCompatible( cast(weight.getType()).getShape()); if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { // Collapse weight shape SmallVector collapsedDims = {{0, 1}, {2}, {3}}; SmallVector collapsedShape{ (weightShape[0] == kUnknownSize ? kUnknownSize : weightShape[0] * weightShape[1]), weightShape[2], weightShape[3]}; Type collapsedType = RankedTensorType::get( makeShapeLLVMCompatible(collapsedShape), weightDTy); Value collapsedWeight = rewriter.create( loc, collapsedType, weight, collapsedDims); conv = rewriter .create( loc, outputTensor.getType(), ValueRange{paddedInput, collapsedWeight}, outputTensor, stridesAttr, dilationAttr) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = cast(newResultType).getElementType(); conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { auto inType = cast(tensor.getType()); auto inShape = makeShapeTorchCompatible(inType.getShape()); SmallVector outShape; for (auto i = 0; i < (long)inShape.size(); i++) { if (i == 1) { outShape.push_back(groupSize); } if (i == (long)dim) { outShape.push_back(inShape[i] == kUnknownSize ? kUnknownSize : inShape[i] / groupSize); } else { outShape.push_back(inShape[i]); } } SmallVector indices; for (auto i = 0; i <= (long)inShape.size(); i++) { if (i == (long)dim) { indices.push_back({i, ++i}); continue; } indices.push_back({i}); } auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); return rewriter.create(loc, retType, tensor, indices); }; // expand F,C,H,W -> G,F/G,C,H,W auto expandWeight = [&](Value tensor) { auto inType = cast(tensor.getType()); auto inShape = makeShapeTorchCompatible(inType.getShape()); SmallVector outShape{ groupSize, (inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / groupSize)}; outShape.append(inShape.begin() + 1, inShape.end()); SmallVector indices{{0, 1}}; for (auto i = 2; i <= (long)inShape.size(); i++) indices.push_back({i}); auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); return rewriter.create(loc, retType, tensor, indices); }; Value paddedInputExpanded = expandGroups(paddedInput, 1); Value weightExpanded = expandWeight(weight); auto expandOutputTensor = expandGroups(outputTensor, 1); // TODO: add 1D and 3D case conv = rewriter .create( loc, expandOutputTensor.getResultType(), ValueRange{paddedInputExpanded, weightExpanded}, expandOutputTensor.getResult(), stridesAttr, dilationAttr) .getResult(0); conv = rewriter.create( loc, outputTensor.getType(), conv, expandOutputTensor.getReassociationIndices()); Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = cast(newResultType).getElementType(); conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } }; } // namespace void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); }