mirror of https://github.com/llvm/torch-mlir
Delete ConvertAtenNativeLayerNormOp from TorchToLinalg (#1336)
The ConvertAtenNativeLayerNormOp is deleted because we already have a decomposition for it; see https://github.com/llvm/torch-mlir/pull/1332 (branch: pull/862/head, tag: snapshot-20220905.587)
parent
e6528f701a
commit
37f57a9828
|
@ -1257,271 +1257,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// For layernorm, the mean and standard-deviation are calculated separately over
|
||||
// the last certain number dimensions which have to be of the shape specified by
|
||||
// normalized_shape.
|
||||
//
|
||||
// The shapes of different parts are as the following:
|
||||
// +-------------------+--------------------+
|
||||
// | meanAndVarShape | normalizedShape |
|
||||
// +-------------------+---------------------
|
||||
// <------------+ inputShape +-------------->
|
||||
// There are the following steps:
|
||||
// Step 1. Check if all the arguments meet the requirements.
|
||||
// Step 2. Common parts to be used for getting mean and var.
|
||||
// This includes elements count, affineMap and iteratorTypes.
|
||||
// Step 3. Get mean.
|
||||
// Step 4. Get rSTD.
|
||||
// Step 5. Get layernorm.
|
||||
namespace {
|
||||
class ConvertAtenNativeLayerNormOp
|
||||
: public OpConversionPattern<AtenNativeLayerNormOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MLIRContext *context = op->getContext();
|
||||
Location loc = op->getLoc();
|
||||
Value input = adaptor.input();
|
||||
Value weight = adaptor.weight();
|
||||
Value bias = adaptor.bias();
|
||||
Value eps = adaptor.eps();
|
||||
Value normalizedShape = op.normalized_shape();
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
// TODO: Handle the None cases for the optional parameters:
|
||||
// weight, bias.
|
||||
if (failed(checkNotNone(rewriter, op, weight)) ||
|
||||
failed(checkNotNone(rewriter, op, bias)))
|
||||
return failure();
|
||||
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto weightType = weight.getType().cast<RankedTensorType>();
|
||||
auto biasType = bias.getType().cast<RankedTensorType>();
|
||||
int64_t inputRank = inputType.getRank();
|
||||
Type elemTy = inputType.getElementType();
|
||||
|
||||
// Step 1. Check if all the arguments meet the requirements.
|
||||
SmallVector<Value> normalizedShapeSizesTorchInt;
|
||||
if (!getListConstructElements(normalizedShape,
|
||||
normalizedShapeSizesTorchInt)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Unimplemented normalized_shape not"
|
||||
"constructed from ListConstruct");
|
||||
}
|
||||
SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
|
||||
rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
|
||||
int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
|
||||
if (weightType.getRank() != normalizedShapeRank ||
|
||||
biasType.getRank() != normalizedShapeRank ||
|
||||
inputRank < normalizedShapeRank || normalizedShapeRank < 1)
|
||||
return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
|
||||
"normalized shape not compatible");
|
||||
|
||||
// Check all the dimensions match the normalized_shape
|
||||
int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
|
||||
for (auto en : enumerate((normalizedShapeSizesInt))) {
|
||||
auto index = en.index();
|
||||
auto inputDim =
|
||||
getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
|
||||
auto weightDim = getDimOp(rewriter, loc, weight, index);
|
||||
auto biasDim = getDimOp(rewriter, loc, bias, index);
|
||||
|
||||
auto expectedSize = en.value();
|
||||
checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
|
||||
checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
|
||||
checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
|
||||
}
|
||||
|
||||
// Get iterator types for input shape.
|
||||
SmallVector<StringRef> normalizedShapeIteratorTypes(
|
||||
normalizedShapeRank, getReductionIteratorTypeName());
|
||||
SmallVector<StringRef> meanAndVarIterationTypes(
|
||||
meanAndVarShapeRank, getParallelIteratorTypeName());
|
||||
SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
|
||||
inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);
|
||||
|
||||
// Step 2. Common parts to be used for getting mean and var.
|
||||
|
||||
// Get sizes and affineMaps needed for mean and var.
|
||||
AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
|
||||
SmallVector<AffineExpr> meanAndVarShapeExprs;
|
||||
for (int i = 0; i < meanAndVarShapeRank; i++)
|
||||
meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
|
||||
auto meanAndVarShapeAffineMap = AffineMap::get(
|
||||
/*dimCount=*/inputRank,
|
||||
/*symbolCount=*/0, meanAndVarShapeExprs, context);
|
||||
SmallVector<Value> meanAndVarShapeSizes =
|
||||
getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);
|
||||
|
||||
// Get number of elements to be used for calculating mean and var.
|
||||
Value elemCnts = normalizedShapeSizesInt[0];
|
||||
for (int i = 1; i < normalizedShapeRank; i++) {
|
||||
elemCnts = rewriter.create<arith::MulIOp>(loc, elemCnts,
|
||||
normalizedShapeSizesInt[i]);
|
||||
}
|
||||
Value elemCntsFloat =
|
||||
rewriter.create<arith::SIToFPOp>(loc, elemTy, elemCnts);
|
||||
|
||||
// Helper to calculate mean and var.
|
||||
auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
|
||||
SmallVector<AffineMap> indexingMaps(
|
||||
2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
|
||||
Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, meanAndVarShapeSizes, elemTy);
|
||||
return rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/meanAndVarIterationTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value sumOrSqureSum = args[0];
|
||||
Value result =
|
||||
b.create<arith::DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
};
|
||||
|
||||
// Step 3. Get mean.
|
||||
|
||||
// Get sum to be used for calculating mean.
|
||||
SmallVector<AffineMap, 2> sumIndexingMaps = {
|
||||
inputShapeAffineMap, // input
|
||||
meanAndVarShapeAffineMap, // output
|
||||
};
|
||||
auto initSumTensor =
|
||||
createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
|
||||
Value sum = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initSumTensor.getType(), input, initSumTensor,
|
||||
/*indexingMaps=*/sumIndexingMaps,
|
||||
/*iteratorTypes=*/inputShapeIteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value input = args[0], sum = args[1];
|
||||
Value result =
|
||||
rewriter.create<arith::AddFOp>(loc, sum, input);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
Value mean = genMeanOrVarCalculation(sum);
|
||||
|
||||
// Step 4. Get rSTD.
|
||||
|
||||
// Calculate squareSum for the layer.
|
||||
SmallVector<AffineMap> squareSumIndexingMaps{
|
||||
inputShapeAffineMap,
|
||||
meanAndVarShapeAffineMap,
|
||||
meanAndVarShapeAffineMap,
|
||||
};
|
||||
auto initSquareSumTensor =
|
||||
createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
|
||||
Value squareSum =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initSquareSumTensor.getType(), ValueRange{input, mean},
|
||||
initSquareSumTensor,
|
||||
/*indexingMaps=*/squareSumIndexingMaps,
|
||||
/*iteratorTypes=*/inputShapeIteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value input = args[0], mean = args[1], squareSum = args[2];
|
||||
Value sub = rewriter.create<arith::SubFOp>(loc, input, mean);
|
||||
Value square = rewriter.create<arith::MulFOp>(loc, sub, sub);
|
||||
Value result =
|
||||
rewriter.create<arith::AddFOp>(loc, squareSum, square);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
Value var = genMeanOrVarCalculation(squareSum);
|
||||
Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, meanAndVarShapeSizes, elemTy);
|
||||
SmallVector<AffineMap> rSTDIndexingMap(
|
||||
2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
|
||||
|
||||
Value rSTD = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, rSTDTensor.getType(), var, rSTDTensor,
|
||||
rSTDIndexingMap, meanAndVarIterationTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value result =
|
||||
calculateRSTD(b, loc, elemTy, eps, args[0]);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
// Step 5. Get layernorm.
|
||||
|
||||
// Get affineMap for normalized shape.
|
||||
SmallVector<AffineExpr> normalizedShapeExprs;
|
||||
for (int i = meanAndVarShapeRank; i < inputRank; i++)
|
||||
normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
|
||||
auto normalizedShapeAffineMap = AffineMap::get(
|
||||
/*dimCount=*/inputRank,
|
||||
/*symbolCount=*/0, normalizedShapeExprs, context);
|
||||
auto inputSizes = getTensorSizes(rewriter, loc, input);
|
||||
Value initLayerNormTensor =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
|
||||
SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
|
||||
indexingMaps.resize(3, meanAndVarShapeAffineMap);
|
||||
indexingMaps.resize(5, normalizedShapeAffineMap);
|
||||
indexingMaps.push_back(inputShapeAffineMap);
|
||||
SmallVector<StringRef> layerNormIterationTypes(
|
||||
inputRank, getParallelIteratorTypeName());
|
||||
Value layerNorm =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initLayerNormTensor.getType(),
|
||||
ValueRange{input, mean, rSTD, weight, bias},
|
||||
initLayerNormTensor,
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/layerNormIterationTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value input = args[0], mean = args[1], rSTD = args[2],
|
||||
weight = args[3], bias = args[4];
|
||||
Value result =
|
||||
createLinalgPayloadCalculationForNormOpsWithRSTD(
|
||||
b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
SmallVector<int64_t> expandShape(inputRank, 1);
|
||||
for (int i = 0; i < meanAndVarShapeRank; i++) {
|
||||
// `mean` and `rstd` are not yet casted, so they will be having dynamic
|
||||
// shape. Hence to match them, for each dimension corresponding to `mean`
|
||||
// or `rstd` assign -1.
|
||||
expandShape[i] = -1;
|
||||
}
|
||||
auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
|
||||
SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
|
||||
for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
|
||||
reassociation[i].push_back(i);
|
||||
if (i == meanAndVarShapeRank - 1) {
|
||||
for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
|
||||
reassociation[i].push_back(i + j + 1);
|
||||
}
|
||||
}
|
||||
Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, expandShapeType, mean, reassociation);
|
||||
Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, expandShapeType, rSTD, reassociation);
|
||||
Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
|
||||
Type meanResultType = getTypeConverter()->convertType(op.getType(1));
|
||||
Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
|
||||
Value layerNorm_ =
|
||||
rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
|
||||
Value mean_ =
|
||||
rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
|
||||
Value var_ =
|
||||
rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
|
||||
rewriter.replaceOp(op, {layerNorm_, mean_, var_});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenNllLossBackwardOp
|
||||
: public OpConversionPattern<AtenNllLossBackwardOp> {
|
||||
|
@ -1728,8 +1463,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenBatchNormOp>();
|
||||
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNativeLayerNormOp>();
|
||||
patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossBackwardOp>();
|
||||
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
|
||||
patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
|
||||
|
|
Loading…
Reference in New Issue