Delete ConvertAtenNativeLayerNormOp from TorchToLinalg (#1336)

The ConvertAtenNativeLayerNormOp is delete because we have decomposition already
see https://github.com/llvm/torch-mlir/pull/1332
pull/862/head snapshot-20220905.587
Tanyo Kwok 2022-09-05 10:19:20 +08:00 committed by GitHub
parent e6528f701a
commit 37f57a9828
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 0 additions and 267 deletions

View File

@ -1257,271 +1257,6 @@ public:
};
} // namespace
// For layernorm, the mean and standard-deviation are calculated separately over
// the last certain number dimensions which have to be of the shape specified by
// normalized_shape.
//
// The shapes of different parts are as the following:
// +-------------------+--------------------+
// | meanAndVarShape | normalizedShape |
// +-------------------+---------------------
// <------------+ inputShape +-------------->
// There are the following steps:
// Step 1. Check if all the arguments meet the requirements.
// Step 2. Common parts to be used for getting mean and var.
// This includes elements count, affineMap and iteratorTypes.
// Step 3. Get mean.
// Step 4. Get rSTD.
// Step 5. Get layernorm.
namespace {
class ConvertAtenNativeLayerNormOp
: public OpConversionPattern<AtenNativeLayerNormOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
Location loc = op->getLoc();
Value input = adaptor.input();
Value weight = adaptor.weight();
Value bias = adaptor.bias();
Value eps = adaptor.eps();
Value normalizedShape = op.normalized_shape();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
// TODO: Handle the None cases for the optional parameters:
// weight, bias.
if (failed(checkNotNone(rewriter, op, weight)) ||
failed(checkNotNone(rewriter, op, bias)))
return failure();
auto inputType = input.getType().cast<RankedTensorType>();
auto weightType = weight.getType().cast<RankedTensorType>();
auto biasType = bias.getType().cast<RankedTensorType>();
int64_t inputRank = inputType.getRank();
Type elemTy = inputType.getElementType();
// Step 1. Check if all the arguments meet the requirements.
SmallVector<Value> normalizedShapeSizesTorchInt;
if (!getListConstructElements(normalizedShape,
normalizedShapeSizesTorchInt)) {
return rewriter.notifyMatchFailure(op,
"Unimplemented normalized_shape not"
"constructed from ListConstruct");
}
SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
if (weightType.getRank() != normalizedShapeRank ||
biasType.getRank() != normalizedShapeRank ||
inputRank < normalizedShapeRank || normalizedShapeRank < 1)
return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
"normalized shape not compatible");
// Check all the dimensions match the normalized_shape
int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
for (auto en : enumerate((normalizedShapeSizesInt))) {
auto index = en.index();
auto inputDim =
getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
auto weightDim = getDimOp(rewriter, loc, weight, index);
auto biasDim = getDimOp(rewriter, loc, bias, index);
auto expectedSize = en.value();
checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
}
// Get iterator types for input shape.
SmallVector<StringRef> normalizedShapeIteratorTypes(
normalizedShapeRank, getReductionIteratorTypeName());
SmallVector<StringRef> meanAndVarIterationTypes(
meanAndVarShapeRank, getParallelIteratorTypeName());
SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);
// Step 2. Common parts to be used for getting mean and var.
// Get sizes and affineMaps needed for mean and var.
AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
SmallVector<AffineExpr> meanAndVarShapeExprs;
for (int i = 0; i < meanAndVarShapeRank; i++)
meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
auto meanAndVarShapeAffineMap = AffineMap::get(
/*dimCount=*/inputRank,
/*symbolCount=*/0, meanAndVarShapeExprs, context);
SmallVector<Value> meanAndVarShapeSizes =
getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);
// Get number of elements to be used for calculating mean and var.
Value elemCnts = normalizedShapeSizesInt[0];
for (int i = 1; i < normalizedShapeRank; i++) {
elemCnts = rewriter.create<arith::MulIOp>(loc, elemCnts,
normalizedShapeSizesInt[i]);
}
Value elemCntsFloat =
rewriter.create<arith::SIToFPOp>(loc, elemTy, elemCnts);
// Helper to calculate mean and var.
auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
SmallVector<AffineMap> indexingMaps(
2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
loc, meanAndVarShapeSizes, elemTy);
return rewriter
.create<linalg::GenericOp>(
loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/meanAndVarIterationTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value sumOrSqureSum = args[0];
Value result =
b.create<arith::DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
};
// Step 3. Get mean.
// Get sum to be used for calculating mean.
SmallVector<AffineMap, 2> sumIndexingMaps = {
inputShapeAffineMap, // input
meanAndVarShapeAffineMap, // output
};
auto initSumTensor =
createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
Value sum = rewriter
.create<linalg::GenericOp>(
loc, initSumTensor.getType(), input, initSumTensor,
/*indexingMaps=*/sumIndexingMaps,
/*iteratorTypes=*/inputShapeIteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], sum = args[1];
Value result =
rewriter.create<arith::AddFOp>(loc, sum, input);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
Value mean = genMeanOrVarCalculation(sum);
// Step 4. Get rSTD.
// Calculate squareSum for the layer.
SmallVector<AffineMap> squareSumIndexingMaps{
inputShapeAffineMap,
meanAndVarShapeAffineMap,
meanAndVarShapeAffineMap,
};
auto initSquareSumTensor =
createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
Value squareSum =
rewriter
.create<linalg::GenericOp>(
loc, initSquareSumTensor.getType(), ValueRange{input, mean},
initSquareSumTensor,
/*indexingMaps=*/squareSumIndexingMaps,
/*iteratorTypes=*/inputShapeIteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], mean = args[1], squareSum = args[2];
Value sub = rewriter.create<arith::SubFOp>(loc, input, mean);
Value square = rewriter.create<arith::MulFOp>(loc, sub, sub);
Value result =
rewriter.create<arith::AddFOp>(loc, squareSum, square);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
Value var = genMeanOrVarCalculation(squareSum);
Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
loc, meanAndVarShapeSizes, elemTy);
SmallVector<AffineMap> rSTDIndexingMap(
2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
Value rSTD = rewriter
.create<linalg::GenericOp>(
loc, rSTDTensor.getType(), var, rSTDTensor,
rSTDIndexingMap, meanAndVarIterationTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value result =
calculateRSTD(b, loc, elemTy, eps, args[0]);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
// Step 5. Get layernorm.
// Get affineMap for normalized shape.
SmallVector<AffineExpr> normalizedShapeExprs;
for (int i = meanAndVarShapeRank; i < inputRank; i++)
normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
auto normalizedShapeAffineMap = AffineMap::get(
/*dimCount=*/inputRank,
/*symbolCount=*/0, normalizedShapeExprs, context);
auto inputSizes = getTensorSizes(rewriter, loc, input);
Value initLayerNormTensor =
rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
indexingMaps.resize(3, meanAndVarShapeAffineMap);
indexingMaps.resize(5, normalizedShapeAffineMap);
indexingMaps.push_back(inputShapeAffineMap);
SmallVector<StringRef> layerNormIterationTypes(
inputRank, getParallelIteratorTypeName());
Value layerNorm =
rewriter
.create<linalg::GenericOp>(
loc, initLayerNormTensor.getType(),
ValueRange{input, mean, rSTD, weight, bias},
initLayerNormTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/layerNormIterationTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value input = args[0], mean = args[1], rSTD = args[2],
weight = args[3], bias = args[4];
Value result =
createLinalgPayloadCalculationForNormOpsWithRSTD(
b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
SmallVector<int64_t> expandShape(inputRank, 1);
for (int i = 0; i < meanAndVarShapeRank; i++) {
// `mean` and `rstd` are not yet casted, so they will be having dynamic
// shape. Hence to match them, for each dimension corresponding to `mean`
// or `rstd` assign -1.
expandShape[i] = -1;
}
auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
reassociation[i].push_back(i);
if (i == meanAndVarShapeRank - 1) {
for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
reassociation[i].push_back(i + j + 1);
}
}
Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeType, mean, reassociation);
Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeType, rSTD, reassociation);
Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
Type meanResultType = getTypeConverter()->convertType(op.getType(1));
Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
Value layerNorm_ =
rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
Value mean_ =
rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
Value var_ =
rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
rewriter.replaceOp(op, {layerNorm_, mean_, var_});
return success();
}
};
} // namespace
namespace {
class ConvertAtenNllLossBackwardOp
: public OpConversionPattern<AtenNllLossBackwardOp> {
@ -1728,8 +1463,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
target.addIllegalOp<AtenNativeLayerNormOp>();
patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossBackwardOp>();
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);