From eb7bf78a9c1e250949cf0151628f35fb0ac06903 Mon Sep 17 00:00:00 2001
From: "Xida Ren (Cedar)"
Date: Mon, 26 Aug 2024 17:06:06 -0400
Subject: [PATCH] Add RestructureNonConstantAxes pass to address reduce op
 tests failing on non constant axes (#3600)

---
 .../Dialect/Torch/Transforms/Passes.h         |   6 +
 .../Dialect/Torch/Transforms/Passes.td        |  20 ++
 lib/Dialect/Torch/Transforms/CMakeLists.txt   |   1 +
 .../Transforms/RestructureNonConstantAxes.cpp | 277 ++++++++++++++++++
 .../TorchConversion/Transforms/Passes.cpp     |   4 +
 5 files changed, 308 insertions(+)
 create mode 100644 lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp

diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
index aef6baa5d..e825938ee 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h
@@ -149,6 +149,12 @@ StringRef getAbstractInterpLibrary();
 
 static const char kTorchOpPrefix[] = R"(torch.)";
 
+void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
+                                               MLIRContext *context);
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createRestructureNonConstantAxesPass();
+
 } // namespace Torch
 
 /// Registers all Torch transformation passes.
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
index 6439feb39..e6b19201e 100644
--- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
+++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td
@@ -431,4 +431,24 @@ def VerifyBackendContractNoDecompositions
   }];
 }
 
+def RestructureNonConstantAxes
+    : Pass<"torch-restructure-non-constant-axes", "func::FuncOp"> {
+  let summary = "Ensure that every Reduction.cpp op has a constant reduction axis.";
+  let constructor = [{
+    mlir::torch::Torch::createRestructureNonConstantAxesPass()
+  }];
+  let description = [{
+    This pass ensures that every Reduction.cpp op has a constant reduction axis.
+
+    It does so using reshapes. For example, a <1,2,3,4,5> tensor will be reshaped to a <before,dim,after> tensor
+    and reduced on axis 1 to produce a <before,1,after> tensor. The resulting tensor will be reshaped back to the original shape.
+
+    Then when the axis is supplied at runtime (say axis = -2), the shapes will be computed as follows:
+    the <1,2,3,4,5> tensor becomes <6,4,5>,
+    which gets reduced to <6,1,5>,
+    and reshaped back to the original reduction op's output shape,
+    <1,2,3,1,5>.
+  }];
+}
+
 #endif // TORCHMLIR_TORCH_PASSES
diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt
index ba6af02c8..1ce006fbe 100644
--- a/lib/Dialect/Torch/Transforms/CMakeLists.txt
+++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(TorchMLIRTorchPasses
   ReifyShapeCalculations.cpp
   ReifyDtypeCalculations.cpp
   ReifyAbstractInterpCalculationsUtils.cpp
+  RestructureNonConstantAxes.cpp
   ScalarizeShapes.cpp
   AbstractInterpLibrary.cpp
   SimplifyShapeCalculations.cpp
diff --git a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp
new file mode 100644
index 000000000..2e1b8e6d3
--- /dev/null
+++ b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp
@@ -0,0 +1,277 @@
+//===- RestructureNonConstantAxes.cpp -------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "torch-restructure-non-constant-axes"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+
+template <typename SrcOp>
+class ConstantifyDimArgument : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  bool isDimConstant(SrcOp op) const {
+    SmallVector<int64_t> dimList;
+    int64_t dim;
+    return matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList)) ||
+           matchPattern(op.getDim(), m_TorchConstantInt(&dim));
+  }
+
+  /*
+  This function renders the reduction dim constant by reshaping the input
+  tensor such that the dim argument is the middle dimension.
+
+  For example, if the input tensor has shape [3,4,5,6,7] and the dim argument
+  is -3, the input tensor is reshaped to [3,4,5,6,7] -> [12,5,42], the
+  reduction operation is applied, and the result is reshaped back to
+  [3,4,1,6,7].
+
+  Since we don't know the dim argument at compile time, we need to compute the
+  arguments to the reshape op at runtime. We do this by computing the new
+  shape of the tensor by multiplying the shapes of the tensor before and after
+  the dim argument, and then reshaping the tensor to this new shape.
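+
+  Concretely, for an input of rank r and a (normalized) runtime dim d, the IR
+  built below computes
+    beforeProd = shape[0] * ... * shape[d-1]
+    dimSize    = shape[d]
+    afterProd  = shape[d+1] * ... * shape[r-1]
+  and reshapes the input to [beforeProd, dimSize, afterProd], so the reduction
+  can always be applied on the constant axis 1.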
+ */ + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + Value self = op.getSelf(); + Value dim = op.getDim(); + + if (isDimConstant(op)) { + return rewriter.notifyMatchFailure(op, + "dim argument is already constant"); + } + + if (isa(dim.getType())) { + return rewriter.notifyMatchFailure( + op, "RestructureNonConstantAxes does not support None dim"); + } + + // when keepdim is not constant, check the ranks of the input and output + // tensors + ValueTensorType selfTy = + llvm::cast(op.getSelf().getType()); + ValueTensorType resultTy = + llvm::cast(op.getResult().getType()); + if (selfTy.hasSizes() && resultTy.hasSizes() && + selfTy.getSizes().size() != resultTy.getSizes().size()) { + return rewriter.notifyMatchFailure( + op, + "RestructureNonConstantAxes does not yet support keepdim=false, but " + "the input and output tensors have different ranks"); + } + + Type intType = rewriter.getType(); + Type boolType = rewriter.getType(); + auto createInt = [&](int value) { + return rewriter.create( + loc, intType, + rewriter.getIntegerAttr(rewriter.getIntegerType(64), value)); + }; + Value zero = createInt(0); + Value one = createInt(1); + + // handle when dim is a single element list + bool oldDimIsList = isa(dim.getType()); + if (oldDimIsList) { + Value len = rewriter.create(loc, intType, dim); + Value dimListIsLengthOne = + rewriter.create(loc, boolType, len, one); + rewriter.create( + loc, dimListIsLengthOne, + rewriter.getStringAttr("RestructureNonConstantAxes does not support " + "dim lists with more than one element")); + dim = rewriter.create(loc, intType, dim, zero); + } + + // Normalize negative dim + Value rank = rewriter.create(loc, intType, self); + Value isNegative = rewriter.create(loc, dim, zero); + Value rankOffset = rewriter.create( + loc, intType, + rewriter.create(loc, intType, isNegative), rank); + dim = rewriter.create(loc, intType, dim, rankOffset); + + auto createConditionalMult = [&](Value self, Value multiplier, + Value condition) { + // compute: + // result = codition ? 
+      // which translates to:
+
+      // result = multiplier - 1
+      Value result = rewriter.create<AtenSubIntOp>(
+          loc, intType, multiplier, createInt(1));
+      // result = result * condition
+      result =
+          rewriter.create<AtenMulIntOp>(loc, intType, result, condition);
+      // result = result + 1
+      result = rewriter.create<AtenAddIntOp>(loc, intType, result,
+                                             createInt(1));
+      // result = self * result
+      result = rewriter.create<AtenMulIntOp>(loc, intType, self, result);
+      return result;
+    };
+
+    // new shape = [beforeProd, dimSize, afterProd]
+    Value beforeProd = createInt(1);
+    Value afterProd = createInt(1);
+    Value dimSize = createInt(1);
+
+    for (size_t i = 0; i < selfTy.getSizes().size(); ++i) {
+      Value idx = createInt(i);
+      Value size = rewriter.create<AtenSizeIntOp>(loc, intType, self, idx);
+
+      Value isBeforeDim =
+          rewriter.create<AtenLtIntOp>(loc, boolType, idx, dim);
+      isBeforeDim =
+          rewriter.create<AtenIntBoolOp>(loc, intType, isBeforeDim);
+      Value isAfterDim =
+          rewriter.create<AtenGtIntOp>(loc, boolType, idx, dim);
+      isAfterDim =
+          rewriter.create<AtenIntBoolOp>(loc, intType, isAfterDim);
+
+      Value isEqualToDim =
+          rewriter.create<AtenEqIntOp>(loc, boolType, idx, dim);
+      isEqualToDim =
+          rewriter.create<AtenIntBoolOp>(loc, intType, isEqualToDim);
+      dimSize = createConditionalMult(dimSize, size, isEqualToDim);
+
+      beforeProd = createConditionalMult(beforeProd, size, isBeforeDim);
+      afterProd = createConditionalMult(afterProd, size, isAfterDim);
+    }
+
+    Value newShape = rewriter.create<PrimListConstructOp>(
+        loc, rewriter.getType<Torch::ListType>(intType),
+        ValueRange{beforeProd, dimSize, afterProd});
+
+    // Reshape input
+    auto newSelfTy = selfTy.getWithSizesAndDtype(
+        SmallVector<int64_t>{Torch::kUnknownSize, Torch::kUnknownSize,
+                             Torch::kUnknownSize},
+        selfTy.getDtype());
+    Value reshapedSelf =
+        rewriter.create<AtenReshapeOp>(loc, newSelfTy, self, newShape);
+
+    // construct a new operand range where self is replaced with the reshaped
+    // tensor, and dim is replaced with the constant 1
+    Value newDim;
+    if (oldDimIsList) {
+      newDim = rewriter.create<PrimListConstructOp>(
+          loc, rewriter.getType<Torch::ListType>(intType), ValueRange{one});
+    } else {
+      newDim = one;
+    }
+    ValueRange oldOperands = op->getOperands();
+    SmallVector<Value> newOperandsVect;
+    for (size_t i = 0; i < oldOperands.size(); ++i) {
+      if (oldOperands[i] == op.getSelf()) {
+        newOperandsVect.push_back(reshapedSelf);
+      } else if (oldOperands[i] == op.getDim()) {
+        newOperandsVect.push_back(newDim);
+      } else {
+        newOperandsVect.push_back(oldOperands[i]);
+      }
+    }
+    ValueRange newOperands = ValueRange(newOperandsVect);
+
+    // construct the new reduction op's result type
+    ValueTensorType newResultTy =
+        cast<ValueTensorType>(resultTy.getWithSizesAndDtype(
+            SmallVector<int64_t>{Torch::kUnknownSize, 1, Torch::kUnknownSize},
+            resultTy.getDtype()));
+
+    Value newReductionOp =
+        rewriter.create<SrcOp>(loc, newResultTy, newOperands, op->getAttrs());
+
+    // Reshape the result back to original shape
+    ValueTensorType oldResultTy =
+        cast<ValueTensorType>(op.getResult().getType());
+    SmallVector<Value> shapeValues;
+    for (auto dim : oldResultTy.getSizes()) {
+      shapeValues.push_back(createInt(dim));
+    }
+    Value originalShape = rewriter.create<PrimListConstructOp>(
+        loc, rewriter.getType<Torch::ListType>(intType), shapeValues);
+    Value result = rewriter.create<AtenReshapeOp>(
+        loc, op->getResult(0).getType(), newReductionOp, originalShape);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
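+
+// A rough sketch of the rewrite performed by ConstantifyDimArgument (the
+// shapes are symbolic here; the emitted IR computes them at runtime with the
+// torch int ops built above):
+//   %newShape = prim.ListConstruct %beforeProd, %dimSize, %afterProd
+//   %self3d   = aten.reshape %self, %newShape
+//   %reduced  = <the reduction op> %self3d, dim = 1
+//   %result   = aten.reshape %reduced, <original result shape>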
+
+template <typename... OpTypes>
+void addConstantifyDimArgumentPatterns(RewritePatternSet &patterns,
+                                       MLIRContext *context) {
+  // simple variadic template to sugar up adding the patterns
+  (patterns.add<ConstantifyDimArgument<OpTypes>>(context), ...);
+}
+
+void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  // these are the reduction ops with a dim argument
+  addConstantifyDimArgumentPatterns<
+      // not supported because they have multiple results
+      // AtenMaxDimOp,
+      // AtenMinDimOp,
+      AtenSumDimIntListOp, AtenAllDimOp, AtenLinalgVectorNormOp,
+      AtenFrobeniusNormDimOp>(patterns, context);
+}
+
+class RestructureNonConstantAxesPass
+    : public RestructureNonConstantAxesBase<RestructureNonConstantAxesPass> {
+public:
+  RestructureNonConstantAxesPass() = default;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+
+    RewritePatternSet patterns(context);
+
+    populateRestructureNonConstantAxesPattern(patterns, context);
+
+    // TODO: Debug visitation order to make this more efficient.
+    // A single linear scan should suffice.
+    GreedyRewriteConfig config;
+    config.useTopDownTraversal = true;
+    config.maxIterations = GreedyRewriteConfig::kNoLimit;
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns), config))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::Torch::createRestructureNonConstantAxesPass() {
+  return std::make_unique<RestructureNonConstantAxesPass>();
+}
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index 42ec495d9..40d7b629a 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -64,6 +64,10 @@ void mlir::torch::registerTorchConversionPasses() {
 
 void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
     OpPassManager &pm) {
+  // Fix non-constant dims passed to reduction ops
+  pm.addNestedPass<func::FuncOp>(
+      torch::Torch::createRestructureNonConstantAxesPass());
+
   // We want to fuse quantized operations together before lowering to linalg.
   pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
   pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());