From f18b2be9111aa1c199d42ab61e1e1c8d06bf5a02 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Thu, 19 May 2022 15:48:15 -0700 Subject: [PATCH] torch,linalg: add support for translating aten.linalg.vector_norm (#839) This patch adds support for the torch.linalg.vector_norm op to the torch dialect, including the necessary shape function. It also extends the conversion of reduction operators to support lowering of AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end tests to validate the lowering. There exist several opportunities to make this lowering optimal and robust. For instance, in its current form, the translation does not support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise each element to the power 1.0. Similarly, L2 norms could benefit from strength reduction. Since the canonicalization pass is not able to apply these optimizations, we should consider applying them during the linalg lowering itself. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 27 ++ lib/Conversion/TorchToLinalg/Reduction.cpp | 275 ++++++++++++++---- .../TorchToLinalg/Uncategorized.cpp | 114 +------- lib/Conversion/TorchToLinalg/Utils.cpp | 128 +++++++- lib/Conversion/TorchToLinalg/Utils.h | 25 +- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 8 + lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 23 ++ .../jit_ir/build_tools/shape_lib_gen.py | 5 + .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../test_suite/reduction.py | 108 +++++++ 10 files changed, 522 insertions(+), 192 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 351c15dae..60078d8d1 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -3625,6 +3625,33 @@ def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [ }]; } +def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$ord, + AnyTorchOptionalListOfTorchIntType:$dim, + Torch_BoolType:$keepdim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgVectorNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenLinalgVectorNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index e1c856f2f..cf2d2beee 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -67,7 +68,8 @@ public: bool keepDim = false; if (!matchPattern(maxDimOp.keepdim(), m_TorchConstantBool(&keepDim))) - return failure(); + return rewriter.notifyMatchFailure( + maxDimOp, "aten.max_dim requires boolean value for keepdim"); int64_t dim; if (!matchPattern(maxDimOp.dim(), m_TorchConstantInt(&dim))) @@ -173,9 +175,8 @@ public: }; } // namespace -static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc, - Operation *op, - Type elementType) { +static Value createInitElementForReduceOp(OpBuilder &b, Location loc, + Operation *op, Type elementType) { if (isa(op)) return b.create(loc, b.getZeroAttr(elementType)); @@ -195,14 +196,18 @@ static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc, elementType.getIntOrFloatBitWidth()))); } - op->emitError("unimplemented lowering in " - "createLinalgNeutralElementForReduceOp"); + if (isa(op)) + return b.create(loc, b.getZeroAttr(elementType)); + + op->emitError("unimplemented lowering in createInitElementForReduceOp"); return nullptr; } -static Value createLinalgPayloadCalculationForReduceOp( - OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op, - ArrayRef operands, Type resultElementType) { +static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, + ValueRange payloadArgs, + Operation *op, + ArrayRef operands, + Type resultElementType) { if (isa(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); @@ -228,14 +233,180 @@ static Value createLinalgPayloadCalculationForReduceOp( if (intType.isSigned()) return b.create(loc, self, result); } + } else if (isa(op)) { + // This creates payload for only the first of the two linalg.generic ops. + // TODO: Short-circuit operations if `ord` is zero or one. + Value elem = payloadArgs[0]; + Value result = payloadArgs[1]; + Value self = convertScalarToDtype(b, loc, elem, resultElementType); + auto abs = b.create(loc, self); + AtenLinalgVectorNormOp::Adaptor adaptor(operands); + Value ord = convertScalarToDtype(b, loc, adaptor.ord(), resultElementType); + auto pow = b.create(loc, abs, ord); + return b.create(loc, pow, result); } - op->emitError("unimplemented lowering in " - "createLinalgPayloadCalculationForReduceOp"); + op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp"); return nullptr; } namespace { class ConvertReductionOp : public ConversionPattern { +private: + /// Given a reduction operation that has the `keepdim` attribute and the + /// (optional) `dim` attribute, return the source tensor operand and the + /// literal values of the attributes or failure otherwise. + template + FailureOr + computeReductionOpInfoForDimVariantOp( + T op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; + typename T::Adaptor adaptor(operands); + opInfo.tensorOperand = adaptor.self(); + auto inputType = opInfo.tensorOperand.getType().cast(); + + if (!matchPattern(op.keepdim(), m_TorchConstantBool(&opInfo.keepDim))) + return rewriter.notifyMatchFailure(op, + "`keepdim` must be a constant bool"); + + SmallVector dimList; + if (matchPattern(op.dim(), m_TorchConstantIntList(dimList))) { + // Fix negative dimensions, if any, before adding to the list. + for (int64_t dim : dimList) { + dim = toPositiveDim(dim, inputType.getRank()); + // Drop invalid dimensions + if (isValidDim(dim, inputType.getRank())) + opInfo.dimSet.insert(dim); + } + } else if (op.dim().getType().template isa()) { + // If no dimensions were specified, reduce along all dimensions + for (int64_t i = 0; i < inputType.getRank(); i++) + opInfo.dimSet.insert(i); + } else { + return rewriter.notifyMatchFailure( + op, "`dim` argument must be a constant int list or None"); + } + + return opInfo; + } + + /// Given a reduction operation, return the source tensor operand and the + /// literal values of the `keepdim` and `dim` attributes, if any, or failure + /// otherwise. + FailureOr + computeReductionOpInfo(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; + + if (isa(op)) { + opInfo.tensorOperand = operands[0]; + auto inputType = opInfo.tensorOperand.getType().cast(); + + // `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the + // input tensor. + for (int64_t i = 0; i < inputType.getRank(); i++) + opInfo.dimSet.insert(i); + + return opInfo; + } + + if (auto sumOp = dyn_cast(op)) + return computeReductionOpInfoForDimVariantOp(sumOp, operands, rewriter); + + if (auto normOp = dyn_cast(op)) + return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter); + + return rewriter.notifyMatchFailure(op, "not a supported reduce op"); + } + + /// Generate a linalg.generic operation for pointwise exponentiation of each + /// element. + Value createElementwiseExp(Location loc, Type elemType, Value exponent, + Value inputTensor, + const torch_to_linalg::ReductionOpInfo &opInfo, + ConversionPatternRewriter &rewriter) const { + bool err = false; + auto powBodyBuilder = [&](OpBuilder &builder, Location loc, + ValueRange payloadArgs) { + Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], elemType); + auto result = builder.create(loc, elem, exponent); + if (result) + builder.create(loc, Value{result}); + err = !result; + }; + + Value powOp = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, {inputTensor}, elemType, powBodyBuilder); + return err ? Value{} : powOp; + } + + FailureOr createSecondReductionForVectorNormOp( + Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp, + Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo, + ConversionPatternRewriter &rewriter) const { + // Cast `ord` to float so that we can readily pass it math.powf. + Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType); + + // TODO: Add support for ord = {0, +inf, -inf}. + auto epsilon = 1e-5; + auto ordLiteral = 0.0; + if (matchPattern(ordValue, m_TorchConstantFloat(&ordLiteral)) && + fabs(ordLiteral) < epsilon) + return rewriter.notifyMatchFailure(op, "unimplemented: L0 norm"); + + if (std::isinf(ordLiteral)) + return rewriter.notifyMatchFailure(op, "unimplemented: ord = +/- inf"); + + // Raise each summed value to the inverse of the order of the norm. + Attribute oneAttr = rewriter.getFloatAttr(elemType, 1.0); + auto oneValue = rewriter.create(loc, oneAttr); + auto inverseOrdValue = + rewriter.create(loc, oneValue, ordValue); + + // Use the results of the first reduction operation from above to generate + // a second reduction operation. + Value reduceOp = createElementwiseExp(loc, elemType, inverseOrdValue, + firstReduction, opInfo, rewriter); + if (!reduceOp) + return rewriter.notifyMatchFailure( + op, "failed to create linalg.generic operation for element-wise " + "exponentiation"); + + return reduceOp; + } + + /// Generate a linalg.generic operation for a reduction. + Value createReductionOp(Location loc, Type elemType, Operation *op, + ArrayRef operands, + const torch_to_linalg::ReductionOpInfo &opInfo, + ConversionPatternRewriter &rewriter) const { + bool err = false; + auto reductionBodyBuilder = [&](OpBuilder &builder, Location loc, + ValueRange payloadArgs) { + Value result = createLinalgPayloadForReduceOp(builder, loc, payloadArgs, + op, operands, elemType); + if (result) + builder.create(loc, result); + err = !result; + }; + + Value initElem = createInitElementForReduceOp(rewriter, loc, op, elemType); + Value reduceOp = torch_to_linalg::createReductionLinalgGeneric( + rewriter, loc, opInfo, initElem, reductionBodyBuilder); + return err ? Value{} : reduceOp; + } + + /// Depending on the operation, check validity of the result's element type. + LogicalResult + validateReductionElementType(Operation *op, Type elemType, + ConversionPatternRewriter &rewriter) const { + if (isa(op) && !elemType.isa()) + return rewriter.notifyMatchFailure( + op, "only float types are valid for vector norm ops"); + // No checks for all other reduction operations + return success(); + } + public: ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context) : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, @@ -244,67 +415,42 @@ public: matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); + return rewriter.notifyMatchFailure( + op, "invalid operand or result types to use with linalg on tensors"); - // Every reduce operation must set a value for the `dimSet`, - // `tensorOperand`, and `keepDim` in accordance with their specification. - DenseSet dimSet; - Value tensorOperand; - bool keepDim = false; - if (isa(op) || isa(op)) { - tensorOperand = operands[0]; - auto inputType = tensorOperand.getType().cast(); - - // `AtenSumOp` and `AtenMaxOp` reduces along all the dimensions of the - // input tensor. - for (int64_t i = 0; i < inputType.getRank(); i++) - dimSet.insert(i); - } else if (auto sumDimIntListOp = dyn_cast(op)) { - tensorOperand = operands[0]; - auto inputType = tensorOperand.getType().cast(); - - if (!matchPattern(sumDimIntListOp.keepdim(), - m_TorchConstantBool(&keepDim))) - return failure(); - - SmallVector dimList; - if (!matchPattern(sumDimIntListOp.dim(), m_TorchConstantIntList(dimList))) - return failure(); - for (auto dim : dimList) { - // Torch allows for negative values in dimSet to go in reverse - // order in the dimensions of the input tensor. - dim = dim >= 0 ? dim : dim + inputType.getRank(); - // Drop invalid dimensions - if (dim < inputType.getRank()) - dimSet.insert(dim); - } - } else { - return rewriter.notifyMatchFailure(op, "not a supported reduce op"); - } + FailureOr opInfo = + computeReductionOpInfo(op, operands, rewriter); + if (failed(opInfo)) + return opInfo; Location loc = op->getLoc(); auto resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); - Value initElem = createLinalgNeutralElementForReduceOp( - rewriter, loc, op, resultType.getElementType()); + Type elemType = resultType.getElementType(); + LogicalResult elemTypeCheck = + validateReductionElementType(op, elemType, rewriter); + if (failed(elemTypeCheck)) + return elemTypeCheck; - bool hadErrorCreatingPayload = false; - Value generic = torch_to_linalg::createReductionLinalgGeneric( - rewriter, loc, tensorOperand, dimSet, keepDim, initElem, - [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { - Value result = createLinalgPayloadCalculationForReduceOp( - b, loc, payloadArgs, op, operands, resultType.getElementType()); - if (!result) { - hadErrorCreatingPayload = true; - return; - } - b.create(loc, result); - }); + Value reduceOp = + createReductionOp(loc, elemType, op, operands, *opInfo, rewriter); + if (!reduceOp) + return rewriter.notifyMatchFailure( + op, "failed to create linalg.generic operation for reduction"); - if (hadErrorCreatingPayload) - return failure(); - rewriter.replaceOpWithNewOp(op, resultType, generic); + // If this is aten.linalg_vector_norm op, then we need to generate another + // linalg.generic op that references the first linalg.generic op. + if (auto normOp = dyn_cast(op)) { + AtenLinalgVectorNormOp::Adaptor adaptor(operands); + FailureOr secondReduceOp = createSecondReductionForVectorNormOp( + loc, elemType, normOp, adaptor.ord(), reduceOp, *opInfo, rewriter); + if (failed(secondReduceOp)) + return secondReduceOp; + reduceOp = *secondReduceOp; + } + + rewriter.replaceOpWithNewOp(op, resultType, reduceOp); return success(); } }; @@ -319,5 +465,6 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index de6c03f27..e21c9c0c7 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -35,113 +35,6 @@ template static bool hasElementType(Value tensor) { return tensorElementType.isa(); } -static Value createElementwiseLinalgGeneric( - OpBuilder &b, Location loc, ValueRange tensorOperands, - Type resultElementType, - function_ref bodyBuild) { - // The overall error handling strategy here is best viewed by thinking about - // what happens for a single result dimension. This loop not structured that - // way because it is hard to create the affine maps for each operand unless - // we structure the loop to iterate over tensor operands as the outer loop - // instead of inner loop. This pseudocode gives better intuition: - // ``` - // for each result dimension: - // for each tensor operand: - // if it doesn't even have high enough rank relative to the result: - // continue - // if it is a static size-1 along this result dimension: - // continue - // if this is the first tensor operand that didn't continue above: - // take its dimension size as the size of the non-broadcasted - // traversal along this dimension (this may include a dynamic size-1, - // **non-broadcasted** traversal!) - // emit error check "if the size does not match the non-broadcasted - // traversal size along this dimension, error" - // ``` - SmallVector operandRanks; - operandRanks.resize(tensorOperands.size()); - llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) { - return tensor.getType().dyn_cast().getRank(); - }); - - auto resultRankIt = - std::max_element(operandRanks.begin(), operandRanks.end()); - assert(resultRankIt != operandRanks.end() && "Unable to get result rank."); - int64_t resultRank = *resultRankIt; - - // Initialize the resultShape to all 1's, as a fallback in case - // all sizes along that result dimension are statically 1. - auto c1 = b.create(loc, /*value=*/1); - SmallVector resultShape(resultRank, c1); - SmallVector indexingMaps; - for (Value tensorOperand : tensorOperands) { - SmallVector exprs; - auto type = tensorOperand.getType().cast(); - for (auto size : llvm::enumerate(type.getShape())) { - // If the size is statically known to be 1, we don't want any - // error guards to be spuriously emitted, since we are specifically - // allowing size-1 broadcasts in this case, as they correspond to a - // constant-0 indexing map. - if (size.value() == 1) { - exprs.push_back(b.getAffineConstantExpr(0)); - continue; - } - - // The rank of this operand might be smaller than the overall rank of - // the broadcast. Add an offset to correlate it to the correct - // dimension of the result. - auto resultDim = size.index() + (resultRank - type.getRank()); - - // The generated linalg op will now be iterating along the full size - // of this dimension. Record that fact. - exprs.push_back(b.getAffineDimExpr(resultDim)); - - // Now, we need to ensure that such iteration is not going to trigger - // undefined behavior, by doing appropriate checks against the current - // dimension size. - auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index()); - - // If the result size of this dimension has so far only hit the - // statically-known-to-be-1 case above (i.e., we have not yet assigned a - // new Value to `resultShape[resultDim]`), then we have no other dynamic - // values to check against, and merely need to record the current - // dimension size. - if (resultShape[resultDim] == c1) { - resultShape[resultDim] = currentDimSize; - continue; - } - - // We prohibit the size-1 dynamic broadcasting scenario, so just check - // for exact equality with the running result size. - // This is the check which protects against the undefined behavior of - // the generated linalg op in the case of iterating two operands with - // dimensions sizes that are expected to match. - auto equalToRunning = - b.create(loc, arith::CmpIPredicate::eq, - resultShape[resultDim], currentDimSize); - b.create(loc, equalToRunning, - "mismatched size for broadcast"); - } - indexingMaps.push_back(AffineMap::get( - /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext())); - } - - SmallVector iteratorTypes(resultRank, - getParallelIteratorTypeName()); - // Add the indexing map for the outs init tensor. - indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank)); - - Value initTensor = b.create( - loc, getAsOpFoldResult(resultShape), resultElementType); - return b - .create(loc, - /*resultTensorTypes=*/initTensor.getType(), - /*inputs=*/tensorOperands, - /*outputs=*/initTensor, indexingMaps, - iteratorTypes, bodyBuild) - .getResult(0); -} - template static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, @@ -964,7 +857,7 @@ public: ->convertType(op->getResult(0).getType()) .cast(); bool hadErrorCreatingPayload = false; - Value generic = createElementwiseLinalgGeneric( + Value generic = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, tensorOperands, resultType.getElementType(), [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { Value result = createLinalgPayloadCalculationForElementwiseOp( @@ -1031,7 +924,7 @@ public: Value zeroVal = rewriter.create( loc, rewriter.getZeroAttr(elementType)); - Value finalRes = createElementwiseLinalgGeneric( + Value finalRes = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, {target}, elementType, [&](OpBuilder &b, Location loc, ValueRange args) { Value targetVal = args[0]; @@ -1068,8 +961,9 @@ public: /*inclusive=*/false); DenseSet dimSet(dimsToReduce.begin(), dimsToReduce.end()); + auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet}; finalRes = torch_to_linalg::createReductionLinalgGeneric( - rewriter, loc, finalRes, dimSet, /*keepDim=*/false, + rewriter, loc, opInfo, /*initElem=*/zeroVal, [&](OpBuilder &b, Location loc, ValueRange args) { Value newVal = args[0]; diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index fcc26a68f..4d8492b45 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -98,23 +98,22 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, } Value torch_to_linalg::createReductionLinalgGeneric( - OpBuilder &b, Location loc, Value tensorOperand, - const DenseSet &dimSet, bool keepDim, Value initElem, + OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem, function_ref bodyBuild) { - auto inputType = tensorOperand.getType().cast(); + auto inputType = opInfo.tensorOperand.getType().cast(); // Get the result shape by obtaining the size of each // dimension in the input tensor that is not getting reduced. - // If `keepDim` is true, the rank of the output tensor + // If `opInfo.keepDim` is true, the rank of the output tensor // is kept the same as the rank of the input tensor, and the // reduced dimensions are set to have size 1. auto c1 = b.create(loc, /*value=*/1); SmallVector resultShape; for (int64_t i = 0; i < inputType.getRank(); i++) { - auto currentDimSize = b.create(loc, tensorOperand, i); - if (!dimSet.contains(i)) + auto currentDimSize = b.create(loc, opInfo.tensorOperand, i); + if (!opInfo.dimSet.contains(i)) resultShape.push_back(currentDimSize); - else if (keepDim) + else if (opInfo.keepDim) resultShape.push_back(c1); } @@ -127,11 +126,11 @@ Value torch_to_linalg::createReductionLinalgGeneric( for (auto size : llvm::enumerate(inputType.getShape())) { exprs.push_back(b.getAffineDimExpr(size.index())); - if (dimSet.contains(size.index())) { + if (opInfo.dimSet.contains(size.index())) { iteratorTypes.push_back(getReductionIteratorTypeName()); - // If `keepDim`, create affine map to the first element + // If `opInfo.keepDim`, create affine map to the first element // in the current dimension. - if (keepDim) + if (opInfo.keepDim) resultExprs.push_back(b.getAffineConstantExpr(0)); } else { iteratorTypes.push_back(getParallelIteratorTypeName()); @@ -146,7 +145,114 @@ Value torch_to_linalg::createReductionLinalgGeneric( return b .create( loc, /*resultTensorTypes=*/accumulator.getType(), - /*inputs=*/tensorOperand, + /*inputs=*/opInfo.tensorOperand, /*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild) .getResult(0); } + +Value torch_to_linalg::createElementwiseLinalgGeneric( + OpBuilder &b, Location loc, ValueRange tensorOperands, + Type resultElementType, + function_ref bodyBuild) { + // The overall error handling strategy here is best viewed by thinking about + // what happens for a single result dimension. This loop not structured that + // way because it is hard to create the affine maps for each operand unless + // we structure the loop to iterate over tensor operands as the outer loop + // instead of inner loop. This pseudocode gives better intuition: + // ``` + // for each result dimension: + // for each tensor operand: + // if it doesn't even have high enough rank relative to the result: + // continue + // if it is a static size-1 along this result dimension: + // continue + // if this is the first tensor operand that didn't continue above: + // take its dimension size as the size of the non-broadcasted + // traversal along this dimension (this may include a dynamic size-1, + // **non-broadcasted** traversal!) + // emit error check "if the size does not match the non-broadcasted + // traversal size along this dimension, error" + // ``` + SmallVector operandRanks; + operandRanks.resize(tensorOperands.size()); + llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) { + return tensor.getType().dyn_cast().getRank(); + }); + + auto resultRankIt = + std::max_element(operandRanks.begin(), operandRanks.end()); + assert(resultRankIt != operandRanks.end() && "Unable to get result rank."); + int64_t resultRank = *resultRankIt; + + // Initialize the resultShape to all 1's, as a fallback in case + // all sizes along that result dimension are statically 1. + auto c1 = b.create(loc, /*value=*/1); + SmallVector resultShape(resultRank, c1); + SmallVector indexingMaps; + for (Value tensorOperand : tensorOperands) { + SmallVector exprs; + auto type = tensorOperand.getType().cast(); + for (auto size : llvm::enumerate(type.getShape())) { + // If the size is statically known to be 1, we don't want any + // error guards to be spuriously emitted, since we are specifically + // allowing size-1 broadcasts in this case, as they correspond to a + // constant-0 indexing map. + if (size.value() == 1) { + exprs.push_back(b.getAffineConstantExpr(0)); + continue; + } + + // The rank of this operand might be smaller than the overall rank of + // the broadcast. Add an offset to correlate it to the correct + // dimension of the result. + auto resultDim = size.index() + (resultRank - type.getRank()); + + // The generated linalg op will now be iterating along the full size + // of this dimension. Record that fact. + exprs.push_back(b.getAffineDimExpr(resultDim)); + + // Now, we need to ensure that such iteration is not going to trigger + // undefined behavior, by doing appropriate checks against the current + // dimension size. + auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index()); + + // If the result size of this dimension has so far only hit the + // statically-known-to-be-1 case above (i.e., we have not yet assigned a + // new Value to `resultShape[resultDim]`), then we have no other dynamic + // values to check against, and merely need to record the current + // dimension size. + if (resultShape[resultDim] == c1) { + resultShape[resultDim] = currentDimSize; + continue; + } + + // We prohibit the size-1 dynamic broadcasting scenario, so just check + // for exact equality with the running result size. + // This is the check which protects against the undefined behavior of + // the generated linalg op in the case of iterating two operands with + // dimensions sizes that are expected to match. + auto equalToRunning = + b.create(loc, arith::CmpIPredicate::eq, + resultShape[resultDim], currentDimSize); + b.create(loc, equalToRunning, + "mismatched size for broadcast"); + } + indexingMaps.push_back(AffineMap::get( + /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext())); + } + + SmallVector iteratorTypes(resultRank, + getParallelIteratorTypeName()); + // Add the indexing map for the outs init tensor. + indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank)); + + Value initTensor = b.create( + loc, getAsOpFoldResult(resultShape), resultElementType); + return b + .create(loc, + /*resultTensorTypes=*/initTensor.getType(), + /*inputs=*/tensorOperands, + /*outputs=*/initTensor, indexingMaps, + iteratorTypes, bodyBuild) + .getResult(0); +} diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h index 02c7aa51b..eb16387e0 100644 --- a/lib/Conversion/TorchToLinalg/Utils.h +++ b/lib/Conversion/TorchToLinalg/Utils.h @@ -13,6 +13,12 @@ namespace mlir { namespace torch { namespace torch_to_linalg { +struct ReductionOpInfo { + bool keepDim; + Value tensorOperand; + DenseSet dimSet; +}; + // Helper function to get the padding tensor given the padding int values. Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &lowPaddingInts, @@ -33,16 +39,21 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in, Value kernelSizeInt, Value strideInt, bool ceilMode = false); -// Create a reduction of `tensorOperand`, reducing along the dimensions -// in `dimSet`. If `keepDim` is true, the output tensor is the same -// rank as the `tensorOperand` and reduced dimensions are set to size 1. -// `initElem` is the element used to initialize the output tensor -// where the reduction will be stored. +// Create a reduction of `opInfo.tensorOperand`, reducing along the dimensions +// in `opInfo.dimSet`. If `opInfo.keepDim` is true, the output tensor is the +// same rank as the `opInfo.tensorOperand` and reduced dimensions are set to +// size 1. `initElem` is the element used to initialize the output tensor where +// the reduction will be stored. Value createReductionLinalgGeneric( - OpBuilder &b, Location loc, Value tensorOperand, - const DenseSet &dimSet, bool keepDim, Value initElem, + OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem, function_ref bodyBuild); +// Create a pointwise operation that uses values in `tensorOperands`, such that +// the element type of the resulting tensor is `resultElementType`. +Value createElementwiseLinalgGeneric( + OpBuilder &b, Location loc, ValueRange tensorOperands, + Type resultElementType, + function_ref bodyBuild); } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index aac10203e..715944967 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -988,6 +988,14 @@ ChangeResult TypeAnalyzer::visitOperation( if (auto scalarImplicit = dyn_cast(op)) return visitAtenScalarImplicitOp(scalarImplicit, operands); + if (auto vectorNorm = dyn_cast(op)) { + Type defaultDtype = operands[0]->getValue().dtype; + Type dtype = getDtypeOrDefault(vectorNorm.getContext(), vectorNorm.dtype(), + defaultDtype); + return visitReductionAlongDimIntListOp( + vectorNorm, vectorNorm.dim(), vectorNorm.keepdim(), dtype, operands); + } + // Otherwise, this is an unknown operation. Just mark all results as // having reached a pessimistic fixpoint. return markAllPessimisticFixpoint(op->getResults()); diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index cd3e8064a..a02745bdd 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -3078,6 +3078,29 @@ module { %none = torch.constant.none return %none : !torch.none } + func.func @"__torch_mlir_shape_fn.aten.linalg_vector_norm"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list { + %none = torch.constant.none + %true = torch.constant.bool true + %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool + %1 = torch.prim.If %0 -> (!torch.list) { + %2 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list + %3 = torch.derefine %arg4 : !torch.optional to !torch.any + %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + torch.prim.If.yield %4 : !torch.list + } else { + %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %3 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %2, %true, init() { + ^bb0(%arg5: !torch.int): + %6 = torch.aten.append.t %3, %arg5 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + %4 = torch.derefine %arg4 : !torch.optional to !torch.any + %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.mean_dim(%arg0, %3, %arg3, %4) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + torch.prim.If.yield %5 : !torch.list + } + return %1 : !torch.list + } } )mlir"); #pragma clang diagnostic pop diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 479535fe2..3f3e8d8e6 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -946,6 +946,11 @@ def hacky_get_unknown_dimension_size(): def aten〇bincount(self: List[int], weights: Optional[List[int]] = None, minlength: int = 0) -> List[int]: return [hacky_get_unknown_dimension_size()] +def aten〇linalg_vector_norm(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: + if dim is None: + dim = list(range(len(self))) + return upstream_shape_helpers.mean_dim(self, dim, keepdim, dtype) + # ============================================================================== # Shape library generator main(). # ============================================================================== diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 15b31de32..0d502222d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -365,6 +365,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") + emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)") # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index 35711201c..be1ed5078 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -383,3 +383,111 @@ class ReduceMaxUnsignedIntModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceMaxUnsignedIntModule()) def ReduceMaxUnsignedIntModule_basic(module, tu: TestUtils): module.forward(torch.randint(100, (3, 4, 5))) + +# ============================================================================== + +class ReduceL1NormModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.linalg.vector_norm(a, dim=0, ord=1) + +@register_test_case(module_factory=lambda: ReduceL1NormModule()) +def ReduceL1NormModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + +class ReduceL1NormWithDTypeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.linalg.vector_norm(a, dim=0, ord=1, dtype=torch.float64) + +@register_test_case(module_factory=lambda: ReduceL1NormWithDTypeModule()) +def ReduceL1NormWithDTypeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float32)) + +# ============================================================================== + +class ReduceL2NormModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.linalg.vector_norm(a, dim=0) + +@register_test_case(module_factory=lambda: ReduceL2NormModule()) +def ReduceL2NormModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + +class ReduceLN3NormModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.linalg.vector_norm(a, dim=0, ord=-3) + +@register_test_case(module_factory=lambda: ReduceLN3NormModule()) +def ReduceLN3NormModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + +class ReduceL3NormAllDimsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.linalg.vector_norm(a, dim=None, ord=3) + +@register_test_case(module_factory=lambda: ReduceL3NormAllDimsModule()) +def ReduceL3NormAllDimsModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + +class ReduceL3NormKeepDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.linalg.vector_norm(a, keepdim=True, ord=3) + +@register_test_case(module_factory=lambda: ReduceL3NormKeepDimModule()) +def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5))