//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "Utils.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { class ConvertAtenMmOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenMmOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value lhs = adaptor.self(); Value rhs = adaptor.mat2(); // A user can write an errorneous program where `aten.mm` is in fact called // with operands of invalid rank or dtype. We cannot convert to linalg in // this case or we will get a verifier error, which corresponds to breaking // of *internal* compiler invariants, and for a user manifests as a compiler // crash in the worst case (such as we try to canonicalize/fold/print the // invalid op before the verifier gets to see it -- also release builds of a // mature compiler usually have the verifier turned off for compile time // reasons). // // The compiler cannot crash even if the user wrote an erroneous program! if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); if (lhs.getType().cast().getRank() != 2 || rhs.getType().cast().getRank() != 2) { return rewriter.notifyMatchFailure( op, "expected both operands to aten.mm to be rank 2"); } Value lhsDim0 = rewriter.create(loc, lhs, 0); Value lhsDim1 = rewriter.create(loc, lhs, 1); Value rhsDim0 = rewriter.create(loc, rhs, 0); Value rhsDim1 = rewriter.create(loc, rhs, 1); Value contractingDimEqual = rewriter.create( loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); rewriter.create( loc, contractingDimEqual, rewriter.getStringAttr( "mismatching contracting dimension for torch.aten.mm")); Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = newResultType.cast().getElementType(); Value initTensor = rewriter.create( loc, ValueRange{lhsDim0, rhsDim1}, elementType); Value c0 = rewriter.create( loc, FloatAttr::get(elementType, 0.0)); Value zeroFill = rewriter.create(loc, c0, initTensor).getResult(0); Value matmul = rewriter .create(loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill) .getResult(0); // When constructed with just dynamic sizes, InitTensorOp will have a result // type which has all `?`'s for dimensions, which might not be the result // type of `op`. The constraints on later linalg ops means that the result // of the MatmulOp will have this type too. So cast it to the desired type // so that in the end we have the original result type. rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } }; } // namespace namespace { class ConvertAtenMatmulOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenMatmulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value lhs = adaptor.self(); Value rhs = adaptor.other(); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); unsigned lhsRank = lhs.getType().cast().getRank(); unsigned rhsRank = rhs.getType().cast().getRank(); Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = newResultType.cast().getElementType(); // The different cases of torch_matmul op is mentioned here: // https://pytorch.org/docs/stable/generated/torch.matmul.html // First Case: Dot Product. if (lhsRank == 1 && rhsRank == 1) { Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0); Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType); Value dotProd = rewriter .create(loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, dotProd); return success(); } // Second Case: Vec-Mat Multiplication. if (lhsRank == 1 && rhsRank == 2) { Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1); checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0); Value zeroTensor = createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType); Value matmul = rewriter .create(loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } // Third Case: Matrix-Vec Multiplication. if (lhsRank == 2 && rhsRank == 1) { Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0); Value zeroTensor = createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType); Value matmul = rewriter .create(loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } // Fourth Case: Batch-Matrix Multiplication. // TODO: Broadcasting of batch dimension is remaining. if (lhsRank >= 3 && rhsRank >= 3 && lhsRank == rhsRank) { unsigned batchRank = lhsRank - 2; SmallVector resultShape; SmallVector lhsExpr; SmallVector rhsExpr; SmallVector outExpr; SmallVector iteratorTypes; // Since broadcasting is a TODO, check whether the lhs and rhs batch // dimension match. for (unsigned i = 0; i < batchRank; i++) { Value lhsBatch = getDimOp(rewriter, loc, lhs, i); Value rhsBatch = getDimOp(rewriter, loc, rhs, i); resultShape.push_back(lhsBatch); lhsExpr.push_back(rewriter.getAffineDimExpr(i)); rhsExpr.push_back(rewriter.getAffineDimExpr(i)); outExpr.push_back(rewriter.getAffineDimExpr(i)); iteratorTypes.push_back(getParallelIteratorTypeName()); checkDimEqualHelper(rewriter, loc, lhsBatch, rhsBatch); } Value lhsDim0 = getDimOp(rewriter, loc, lhs, batchRank); Value lhsDim1 = getDimOp(rewriter, loc, lhs, batchRank + 1); Value rhsDim0 = getDimOp(rewriter, loc, rhs, batchRank); Value rhsDim1 = getDimOp(rewriter, loc, rhs, batchRank + 1); checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0); // Push the final matrix dimension. resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1}); lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank), rewriter.getAffineDimExpr(batchRank + 1)}); rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1), rewriter.getAffineDimExpr(batchRank + 2)}); outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank), rewriter.getAffineDimExpr(batchRank + 2)}); Value initTensor0 = createZeroInitTensor(rewriter, loc, resultShape, elementType); auto indexingMaps = AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr}); iteratorTypes.insert(iteratorTypes.end(), {"parallel", "reduction", "parallel"}); Value finalRes = rewriter .create( loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value l = args[0], r = args[1], res = args[2]; Value mul = b.create(loc, l, r); Value add = b.create(loc, mul, res); b.create(loc, add); }) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, finalRes); return success(); } return failure(); } }; } // namespace namespace { class ConvertAtenBmmOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenBmmOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op->getLoc(); Value lhs = adaptor.self(); Value rhs = adaptor.mat2(); RankedTensorType lhsType = lhs.getType().cast(); RankedTensorType rhsType = rhs.getType().cast(); if (lhsType.getRank() != 3 || rhsType.getRank() != 3) { return rewriter.notifyMatchFailure( op, "expected both operands to aten.bmm to be rank 3"); } if (!lhsType.getElementType().isa() || lhsType.getElementType() != rhsType.getElementType()) return op.emitError( "unimplemented: non floating point operands or operands of " "different types"); Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); Value lhsDim2 = getDimOp(rewriter, loc, lhs, 2); Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1); Value rhsDim2 = getDimOp(rewriter, loc, rhs, 2); // Check the batch numbers are equal. checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0); // Check the matrixs shapes are valid for mulplication. checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1); Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = newResultType.cast().getElementType(); Value initTensor0 = createZeroInitTensor( rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, elementType); Value bmm = rewriter .create(loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, bmm); return success(); } }; } // namespace namespace { // See comments at in convertMmOp and the heading for this section for general // considerations. This function needs to be auto-generated. class ConvertAtenLinearOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenLinearOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { MLIRContext *context = op->getContext(); Location loc = op->getLoc(); Value input = adaptor.input(); Value weight = adaptor.weight(); Value bias = adaptor.bias(); // TODO: Handle the case of bias being None (bias is optional). if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); auto inputType = input.getType().cast(); auto weightType = weight.getType().cast(); auto biasType = bias.getType().cast(); if (inputType.getRank() != 2 && inputType.getRank() != 3) { return rewriter.notifyMatchFailure( op, "expected input to be rank 2 or rank 3"); } // Only handle the case of rank 2 `weight` for now. // TODO: Insert the appropriate reshape to collapse any leading dimensions. if (weightType.getRank() != 2 || biasType.getRank() != 1) { return rewriter.notifyMatchFailure( op, "expected weight to be rank 2 and bias to be rank 1"); } // TODO: Handle type promotion. What are ATen's promotion rules? if (inputType.getElementType() != weightType.getElementType() || inputType.getElementType() != biasType.getElementType()) { return rewriter.notifyMatchFailure(op, "unimplemented: type promotion"); } // TODO: We can handle a static size 1 here at some complexity cost, but the // dynamic case is not representable in linalg. We don't handle either for // now. Biases are generally statically shaped for most models (since for // inference they are constants, and for training they don't change shape // typically), so this is not too constraining. auto biasSize = bias.getType().cast().getShape()[0]; if (biasSize == 1 || biasSize == ShapedType::kDynamicSize) return rewriter.notifyMatchFailure( op, "unimplemented: size-1 broadcasting for aten::LinearOp"); Value batchDim = nullptr; int restDim = 0; if (inputType.getRank() == 3) { batchDim = getDimOp(rewriter, loc, input, 0); restDim = 1; } Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0); Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1); Value weightDim0 = getDimOp(rewriter, loc, weight, 0); Value weightDim1 = getDimOp(rewriter, loc, weight, 1); Value biasDim0 = getDimOp(rewriter, loc, bias, 0); Value contractingDimEqual = rewriter.create( loc, arith::CmpIPredicate::eq, inputDim1, weightDim1); rewriter.create( loc, contractingDimEqual, rewriter.getStringAttr( "mismatching contracting dimension for aten.linear")); // Here we take advantage of ruling out the size-1 case above. // In the static-size-1 case, we will not emit this check at all. Value biasSizeCorrect = rewriter.create( loc, arith::CmpIPredicate::eq, weightDim0, biasDim0); rewriter.create( loc, biasSizeCorrect, rewriter.getStringAttr("mismatching bias size for aten.linear")); Value initTensor; SmallVector broadcastIndexingMaps; Value transposedWeightInitTensor; if (inputType.getRank() > 2) { initTensor = rewriter.create( loc, ValueRange{batchDim, inputDim0, weightDim0}, inputType.getElementType()); transposedWeightInitTensor = rewriter.create( loc, ValueRange{batchDim, weightDim1, weightDim0}, weightType.getElementType()); broadcastIndexingMaps = { AffineMap::get( /*dimCount=*/inputType.getRank(), /*symbolCount=*/0, {rewriter.getAffineDimExpr(1 + restDim)}, context), rewriter.getMultiDimIdentityMap(inputType.getRank())}; } else { initTensor = rewriter.create( loc, ValueRange{inputDim0, weightDim0}, inputType.getElementType()); transposedWeightInitTensor = rewriter.create( loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType()); broadcastIndexingMaps = { AffineMap::get( /*dimCount=*/inputType.getRank(), /*symbolCount=*/0, {rewriter.getAffineDimExpr(1)}, context), rewriter.getMultiDimIdentityMap(inputType.getRank())}; } SmallVector iteratorTypes(inputType.getRank(), "parallel"); Value broadcasted = rewriter .create( loc, initTensor.getType(), bias, initTensor, /*indexingMaps=*/broadcastIndexingMaps, /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); // We need a matmul with dimension ordering (N, K) * (M, K), so transpose // the weights to fit into linalg::MatmulOp which is (N, K) * (K, M). // TODO: This whole aten.linear lowering should eventually be generated from // a single linalg ODS generator statement. Both the bias and matmul part. SmallVector transposeIndexingMaps = { AffineMap::get( /*dimCount=*/inputType.getRank(), /*symbolCount=*/0, {rewriter.getAffineDimExpr(1 + restDim), rewriter.getAffineDimExpr(0 + restDim)}, context), rewriter.getMultiDimIdentityMap(inputType.getRank())}; Value transposedWeights = rewriter .create( loc, transposedWeightInitTensor.getType(), weight, transposedWeightInitTensor, /*indexingMaps=*/transposeIndexingMaps, /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); Value matmul; if (batchDim) matmul = rewriter .create( loc, broadcasted.getType(), ValueRange{input, transposedWeights}, broadcasted) .getResult(0); else matmul = rewriter .create( loc, broadcasted.getType(), ValueRange{input, transposedWeights}, broadcasted) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } }; } // namespace namespace { class ConvertAtenConvolutionOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); Value input = adaptor.input(); /* in form of N*C*H*W */ Value weight = adaptor.weight(); /* in form of F*C*H*W */ Type elementType = input.getType().cast().getElementType(); if (!elementType.isa()) return op.emitError("unimplemented: non-floating point type"); size_t inRank = input.getType().cast().getRank(); if (inRank != 4) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D convolution currently supported"); Type intType = IntegerType::get(context, 64); auto castIndexToInt = [&](Value v) { return rewriter.create(loc, intType, v); }; SmallVector paddingInts; if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) { return rewriter.notifyMatchFailure( op, "only support constant padding values"); } SmallVector strideInts; if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts))) return rewriter.notifyMatchFailure(op, "only support constant int strides"); SmallVector dilationInts; if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); Value N = getDimOp(rewriter, loc, input, 0); SmallVector inDims; for (size_t i = 2; i < inRank; i++) inDims.push_back(getDimOp(rewriter, loc, input, i)); Value F = getDimOp(rewriter, loc, weight, 0); SmallVector weightDims; for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); // Guard unused values (transposed, groups) int64_t group_size; if (!matchPattern(op.groups(), m_TorchConstantInt(&group_size)) || group_size != 1) return rewriter.notifyMatchFailure( op, "unimplemented: only group size of 1 supported"); bool transposed = true; if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)) || transposed) return rewriter.notifyMatchFailure( op, "unimplemented: only non-transposed convolution supported"); // Pad the input tensor according to padding. SmallVector paddingIncludingNC = {0, 0}; paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(), paddingInts.end()); Value paddedInput = torch_to_linalg::getZeroPaddedTensor( op, rewriter, input, paddingIncludingNC); SmallVector paddingIntValues = getAsConstantIntValues(rewriter, loc, paddingInts); SmallVector dilationIntValues = getAsConstantIntValues(rewriter, loc, dilationInts); SmallVector strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); SmallVector outDims{N, F}; for (size_t i = 0; i < inRank - 2; i++) outDims.push_back(torch_to_linalg::getOutputDimForConvOps( rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], castIndexToInt(weightDims[i]), strideIntValues[i])); Value initTensor = rewriter.create(loc, outDims, elementType); Value bias = adaptor.bias(); Value biasInitTensor; if (bias.getType().isa()) { Value c0float = rewriter.create( loc, FloatAttr::get(elementType, 0.0)); biasInitTensor = rewriter.create(loc, c0float, initTensor) .getResult(0); } else { auto biasType = bias.getType().cast(); if (biasType.getRank() != 1) return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); if (elementType != biasType.getElementType()) return rewriter.notifyMatchFailure(op, "unimplemented: type promotion"); auto resultRank = initTensor.getType().cast().getRank(); SmallVector indexingMaps = { // bias is used to initialize the channels - dimension 1 of output AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context), rewriter.getMultiDimIdentityMap(resultRank)}; SmallVector iteratorTypes(resultRank, "parallel"); biasInitTensor = rewriter .create( loc, initTensor.getType(), bias, initTensor, indexingMaps, iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { b.create(loc, args[0]); }) .getResult(0); } auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); // TODO: add 1D and 3D case Value conv = rewriter .create( loc, biasInitTensor.getType(), ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr, dilationAttr) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } }; } // namespace void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); }