mirror of https://github.com/llvm/torch-mlir
Add RestructureNonConstantAxes pass to address reduce op tests failing on non constant axes (#3600)
parent
638ef14512
commit
eb7bf78a9c
|
@ -149,6 +149,12 @@ StringRef getAbstractInterpLibrary();
|
|||
|
||||
static const char kTorchOpPrefix[] = R"(torch.)";
|
||||
|
||||
void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
|
||||
MLIRContext *context);
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createRestructureNonConstantAxesPass();
|
||||
|
||||
} // namespace Torch
|
||||
|
||||
/// Registers all Torch transformation passes.
|
||||
|
|
|
@ -431,4 +431,24 @@ def VerifyBackendContractNoDecompositions
|
|||
}];
|
||||
}
|
||||
|
||||
def RestructureNonConstantAxes
|
||||
: Pass<"torch-restructure-non-constant-axes", "func::FuncOp"> {
|
||||
let summary = "Ensure that every Reduction.cpp op has a constant reduction axis.";
|
||||
let constructor = [{
|
||||
mlir::torch::Torch::createRestructureNonConstantAxesPass()
|
||||
}];
|
||||
let description = [{
|
||||
This pass ensures that every Reduction.cpp op has a constant reduction axis.
|
||||
|
||||
It does so using reshapes. For example, a <1,2,3,4,5> tensor will be reshaped to a <?,?,?> tensor
|
||||
and reduced on axis 1 to produce a <?,1,?> tensor. The resulting tensor will be reshaped back to the original shape.
|
||||
|
||||
Then when the axis is supplied at runtime (say axis = -2), the shapes will be computed as so:
|
||||
<?,?,?> becomes <6,4,5>
|
||||
which gets reduced to <6,1,5>
|
||||
and reshaped back to the original reduction op's output shape,
|
||||
<1,2,3,1,5>
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TORCHMLIR_TORCH_PASSES
|
||||
|
|
|
@ -17,6 +17,7 @@ add_mlir_library(TorchMLIRTorchPasses
|
|||
ReifyShapeCalculations.cpp
|
||||
ReifyDtypeCalculations.cpp
|
||||
ReifyAbstractInterpCalculationsUtils.cpp
|
||||
RestructureNonConstantAxes.cpp
|
||||
ScalarizeShapes.cpp
|
||||
AbstractInterpLibrary.cpp
|
||||
SimplifyShapeCalculations.cpp
|
||||
|
|
|
@ -0,0 +1,277 @@
|
|||
//===- RestructureNonConstantAxes.cpp --------------------------------*-
|
||||
// C++-*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
// Debug tag for LLVM_DEBUG / -debug-only filtering. This must name *this*
// pass; the previous value was copy-pasted from the lower-to-backend-contract
// pass, so its debug output was filed under the wrong tag.
#define DEBUG_TYPE "torch-restructure-non-constant-axes"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename SrcOp>
|
||||
class ConstantifyDimArgument : public OpRewritePattern<SrcOp> {
|
||||
public:
|
||||
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
||||
|
||||
bool isDimConstant(SrcOp op) const {
|
||||
SmallVector<int64_t> dimList;
|
||||
int64_t dim;
|
||||
return matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList)) ||
|
||||
matchPattern(op.getDim(), m_TorchConstantInt(&dim));
|
||||
}
|
||||
|
||||
/*
|
||||
This function renders the reduction dim constant by reshaping the input tensor
|
||||
such that the dim argument is the middle dimension.
|
||||
|
||||
For example, if the input tensor has shape [3,4,5,6,7] and the dim argument is
|
||||
-2, the input tensor is reshaped to [3,4,5,6,7] -> [12,5,42], the reduction
|
||||
operation is applied, and the result is reshaped back to [3,4,1,6,7].
|
||||
|
||||
Since we don't know the dim argument at compile time, we need to compute the
|
||||
arguments to the reshape op at runtime. We do this by computing the new shape
|
||||
of the tensor by multiplying the shapes of the tensor before and after the dim
|
||||
argument, and then reshaping the tensor to this new shape.
|
||||
*/
|
||||
LogicalResult matchAndRewrite(SrcOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
|
||||
Value self = op.getSelf();
|
||||
Value dim = op.getDim();
|
||||
|
||||
if (isDimConstant(op)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"dim argument is already constant");
|
||||
}
|
||||
|
||||
if (isa<Torch::NoneType>(dim.getType())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "RestructureNonConstantAxes does not support None dim");
|
||||
}
|
||||
|
||||
// when keepdim is not constant, check the ranks of the input and output
|
||||
// tensors
|
||||
ValueTensorType selfTy =
|
||||
llvm::cast<ValueTensorType>(op.getSelf().getType());
|
||||
ValueTensorType resultTy =
|
||||
llvm::cast<ValueTensorType>(op.getResult().getType());
|
||||
if (selfTy.hasSizes() && resultTy.hasSizes() &&
|
||||
selfTy.getSizes().size() != resultTy.getSizes().size()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"RestructureNonConstantAxes does not yet support keepdim=false, but "
|
||||
"the input and output tensors have different ranks");
|
||||
}
|
||||
|
||||
Type intType = rewriter.getType<Torch::IntType>();
|
||||
Type boolType = rewriter.getType<Torch::BoolType>();
|
||||
auto createInt = [&](int value) {
|
||||
return rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, intType,
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), value));
|
||||
};
|
||||
Value zero = createInt(0);
|
||||
Value one = createInt(1);
|
||||
|
||||
// handle when dim is a single element list
|
||||
bool oldDimIsList = isa<Torch::ListType>(dim.getType());
|
||||
if (oldDimIsList) {
|
||||
Value len = rewriter.create<Torch::AtenLenTOp>(loc, intType, dim);
|
||||
Value dimListIsLengthOne =
|
||||
rewriter.create<Torch::AtenEqIntOp>(loc, boolType, len, one);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
loc, dimListIsLengthOne,
|
||||
rewriter.getStringAttr("RestructureNonConstantAxes does not support "
|
||||
"dim lists with more than one element"));
|
||||
dim = rewriter.create<Torch::Aten__Getitem__TOp>(loc, intType, dim, zero);
|
||||
}
|
||||
|
||||
// Normalize negative dim
|
||||
Value rank = rewriter.create<Torch::AtenDimOp>(loc, intType, self);
|
||||
Value isNegative = rewriter.create<Torch::AtenLtIntOp>(loc, dim, zero);
|
||||
Value rankOffset = rewriter.create<Torch::AtenMulIntOp>(
|
||||
loc, intType,
|
||||
rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isNegative), rank);
|
||||
dim = rewriter.create<Torch::AtenAddIntOp>(loc, intType, dim, rankOffset);
|
||||
|
||||
auto createConditionalMult = [&](Value self, Value multiplier,
|
||||
Value condition) {
|
||||
// compute:
|
||||
// result = codition ? (self * multiplier) : self
|
||||
// via
|
||||
// result = self * (1 + (multiplier - 1) * condition)
|
||||
// which translates to:
|
||||
|
||||
// result = multiplier - 1
|
||||
Value result = rewriter.create<Torch::AtenSubIntOp>(
|
||||
loc, intType, multiplier, createInt(1));
|
||||
// result = result * condition
|
||||
result =
|
||||
rewriter.create<Torch::AtenMulIntOp>(loc, intType, result, condition);
|
||||
// result = result + 1
|
||||
result = rewriter.create<Torch::AtenAddIntOp>(loc, intType, result,
|
||||
createInt(1));
|
||||
// result = self * result
|
||||
result = rewriter.create<Torch::AtenMulIntOp>(loc, intType, self, result);
|
||||
return result;
|
||||
};
|
||||
|
||||
// new shape = [beforeDim, dimSize, afterDim]
|
||||
Value beforeProd = createInt(1);
|
||||
Value afterProd = createInt(1);
|
||||
Value dimSize = createInt(1);
|
||||
|
||||
for (size_t i = 0; i < selfTy.getSizes().size(); ++i) {
|
||||
Value idx = createInt(i);
|
||||
Value size =
|
||||
rewriter.create<Torch::AtenSizeIntOp>(loc, intType, self, idx);
|
||||
|
||||
Value isBeforeDim =
|
||||
rewriter.create<Torch::AtenLtIntOp>(loc, boolType, idx, dim);
|
||||
isBeforeDim =
|
||||
rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isBeforeDim);
|
||||
Value isAfterDim =
|
||||
rewriter.create<Torch::AtenGtIntOp>(loc, boolType, idx, dim);
|
||||
isAfterDim =
|
||||
rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isAfterDim);
|
||||
|
||||
Value isEqualToDim =
|
||||
rewriter.create<Torch::AtenEqIntOp>(loc, boolType, idx, dim);
|
||||
isEqualToDim =
|
||||
rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isEqualToDim);
|
||||
dimSize = createConditionalMult(dimSize, size, isEqualToDim);
|
||||
|
||||
beforeProd = createConditionalMult(beforeProd, size, isBeforeDim);
|
||||
afterProd = createConditionalMult(afterProd, size, isAfterDim);
|
||||
}
|
||||
|
||||
Value newShape = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, rewriter.getType<Torch::ListType>(intType),
|
||||
ValueRange{beforeProd, dimSize, afterProd});
|
||||
|
||||
// Reshape input
|
||||
auto newSelfTy = selfTy.getWithSizesAndDtype(
|
||||
SmallVector<int64_t>{Torch::kUnknownSize, Torch::kUnknownSize,
|
||||
Torch::kUnknownSize},
|
||||
selfTy.getDtype());
|
||||
Value reshapedSelf =
|
||||
rewriter.create<Torch::AtenViewOp>(loc, newSelfTy, self, newShape);
|
||||
|
||||
// construct new operange range where self is replaced with reshapedSelf
|
||||
// tensor, and dim is replaced with 1
|
||||
Value newDim;
|
||||
if (oldDimIsList) {
|
||||
newDim = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, rewriter.getType<Torch::ListType>(intType), ValueRange{one});
|
||||
} else {
|
||||
newDim = one;
|
||||
}
|
||||
ValueRange oldOperands = op->getOperands();
|
||||
SmallVector<Value> newOperandsVect;
|
||||
for (size_t i = 0; i < oldOperands.size(); ++i) {
|
||||
if (oldOperands[i] == op.getSelf()) {
|
||||
newOperandsVect.push_back(reshapedSelf);
|
||||
} else if (oldOperands[i] == op.getDim()) {
|
||||
newOperandsVect.push_back(newDim);
|
||||
} else {
|
||||
newOperandsVect.push_back(oldOperands[i]);
|
||||
}
|
||||
}
|
||||
ValueRange newOperands = ValueRange(newOperandsVect);
|
||||
|
||||
// construct new reduction op result type
|
||||
ValueTensorType newResultTy =
|
||||
cast<ValueTensorType>(resultTy.getWithSizesAndDtype(
|
||||
SmallVector<int64_t>{Torch::kUnknownSize, 1, Torch::kUnknownSize},
|
||||
resultTy.getDtype()));
|
||||
|
||||
Value newReductionOp =
|
||||
rewriter.create<SrcOp>(loc, newResultTy, newOperands, op->getAttrs());
|
||||
|
||||
// Reshape the result back to original shape
|
||||
ValueTensorType oldResultTy =
|
||||
cast<ValueTensorType>(op.getResult().getType());
|
||||
SmallVector<Value> shapeValues;
|
||||
for (auto dim : oldResultTy.getSizes()) {
|
||||
shapeValues.push_back(createInt(dim));
|
||||
}
|
||||
Value originalShape = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, rewriter.getType<Torch::ListType>(intType), shapeValues);
|
||||
Value result = rewriter.create<Torch::AtenViewOp>(
|
||||
loc, op->getResult(0).getType(), newReductionOp, originalShape);
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
};
|
||||
};
|
||||
|
||||
template <typename... OpTypes>
|
||||
void addConstantifyDimArgumentPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
// simple variadic template to sugar up adding the patterns
|
||||
(patterns.add<ConstantifyDimArgument<OpTypes>>(context), ...);
|
||||
}
|
||||
|
||||
/// Populates `patterns` with rewrites that render the reduction axis of the
/// supported reduction ops constant.
void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // These are the reduction ops that carry a `dim` argument. AtenMaxDimOp and
  // AtenMinDimOp are deliberately omitted: they produce multiple results,
  // which the rewrite does not support.
  addConstantifyDimArgumentPatterns<AtenSumDimIntListOp, AtenAllDimOp,
                                    AtenLinalgVectorNormOp,
                                    AtenFrobeniusNormDimOp>(patterns, context);
}
|
||||
|
||||
class RestructureNonConstantAxesPass
|
||||
: public RestructureNonConstantAxesBase<RestructureNonConstantAxesPass> {
|
||||
public:
|
||||
RestructureNonConstantAxesPass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
populateRestructureNonConstantAxesPattern(patterns, context);
|
||||
|
||||
// TODO: Debug visitation order to make this more efficient.
|
||||
// A single linear scan should suffice.
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
config.maxIterations = GreedyRewriteConfig::kNoLimit;
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
||||
config))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::Torch::createRestructureNonConstantAxesPass() {
|
||||
return std::make_unique<RestructureNonConstantAxesPass>();
|
||||
}
|
|
@ -64,6 +64,10 @@ void mlir::torch::registerTorchConversionPasses() {
|
|||
|
||||
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||
OpPassManager &pm) {
|
||||
// Fix non constant dims passed to reduction ops
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
torch::Torch::createRestructureNonConstantAxesPass());
|
||||
|
||||
// We want to fuse quantized operations together before lowering to linalg.
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
|
||||
pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());
|
||||
|
|
Loading…
Reference in New Issue