mirror of https://github.com/llvm/torch-mlir
Delete ConvertAtenNativeLayerNormOp from TorchToLinalg (#1336)
The ConvertAtenNativeLayerNormOp pattern is deleted because a decomposition for aten.native_layer_norm already exists; see https://github.com/llvm/torch-mlir/pull/1332
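For context, the deleted pattern lowered aten.native_layer_norm straight to linalg ops that compute the mean, the reciprocal standard deviation (rSTD = 1 / sqrt(var + eps)), and the normalized output over the trailing normalized_shape dimensions; with the decomposition in place, the op is rewritten into simpler Torch ops before TorchToLinalg runs, so the handwritten lowering below becomes redundant. The snippet here is only a minimal standalone sketch of that math for a single flattened normalized slice, in plain C++ with illustrative names, not torch-mlir or MLIR APIs.

#include <cmath>
#include <cstddef>
#include <vector>

struct LayerNormResult {
  std::vector<double> output; // normalized, scaled, and shifted input
  double mean;                // mean over the normalized elements
  double rstd;                // reciprocal standard deviation: 1 / sqrt(var + eps)
};

// Mirrors the math of the deleted lowering: biased variance (divide by N),
// then output = (x - mean) * rstd * weight + bias.
LayerNormResult layerNormSlice(const std::vector<double> &input,
                               const std::vector<double> &weight,
                               const std::vector<double> &bias, double eps) {
  const std::size_t n = input.size();
  double sum = 0.0;
  for (double x : input)
    sum += x;
  const double mean = sum / static_cast<double>(n);

  double squareSum = 0.0;
  for (double x : input)
    squareSum += (x - mean) * (x - mean);
  const double rstd =
      1.0 / std::sqrt(squareSum / static_cast<double>(n) + eps);

  std::vector<double> output(n);
  for (std::size_t i = 0; i < n; ++i)
    output[i] = (input[i] - mean) * rstd * weight[i] + bias[i];
  return {output, mean, rstd};
}

As in the deleted pattern, mean and rstd are returned alongside the normalized output, matching the three results of aten.native_layer_norm.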
parent e6528f701a
commit 37f57a9828
@@ -1257,271 +1257,6 @@ public:
 };
 } // namespace
 
-// For layernorm, the mean and standard-deviation are calculated separately over
-// the last certain number dimensions which have to be of the shape specified by
-// normalized_shape.
-//
-// The shapes of different parts are as the following:
-// +-------------------+--------------------+
-// | meanAndVarShape   | normalizedShape    |
-// +-------------------+---------------------
-// <------------+ inputShape +-------------->
-// There are the following steps:
-// Step 1. Check if all the arguments meet the requirements.
-// Step 2. Common parts to be used for getting mean and var.
-//         This includes elements count, affineMap and iteratorTypes.
-// Step 3. Get mean.
-// Step 4. Get rSTD.
-// Step 5. Get layernorm.
-namespace {
-class ConvertAtenNativeLayerNormOp
-    : public OpConversionPattern<AtenNativeLayerNormOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *context = op->getContext();
-    Location loc = op->getLoc();
-    Value input = adaptor.input();
-    Value weight = adaptor.weight();
-    Value bias = adaptor.bias();
-    Value eps = adaptor.eps();
-    Value normalizedShape = op.normalized_shape();
-
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-      return failure();
-
-    // TODO: Handle the None cases for the optional parameters:
-    // weight, bias.
-    if (failed(checkNotNone(rewriter, op, weight)) ||
-        failed(checkNotNone(rewriter, op, bias)))
-      return failure();
-
-    auto inputType = input.getType().cast<RankedTensorType>();
-    auto weightType = weight.getType().cast<RankedTensorType>();
-    auto biasType = bias.getType().cast<RankedTensorType>();
-    int64_t inputRank = inputType.getRank();
-    Type elemTy = inputType.getElementType();
-
-    // Step 1. Check if all the arguments meet the requirements.
-    SmallVector<Value> normalizedShapeSizesTorchInt;
-    if (!getListConstructElements(normalizedShape,
-                                  normalizedShapeSizesTorchInt)) {
-      return rewriter.notifyMatchFailure(op,
-                                         "Unimplemented normalized_shape not"
-                                         "constructed from ListConstruct");
-    }
-    SmallVector<Value> normalizedShapeSizesInt = getTypeConvertedValues(
-        rewriter, loc, getTypeConverter(), normalizedShapeSizesTorchInt);
-    int64_t normalizedShapeRank = normalizedShapeSizesInt.size();
-    if (weightType.getRank() != normalizedShapeRank ||
-        biasType.getRank() != normalizedShapeRank ||
-        inputRank < normalizedShapeRank || normalizedShapeRank < 1)
-      return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
-                                             "normalized shape not compatible");
-
-    // Check all the dimensions match the normalized_shape
-    int64_t meanAndVarShapeRank = inputRank - normalizedShapeSizesInt.size();
-    for (auto en : enumerate((normalizedShapeSizesInt))) {
-      auto index = en.index();
-      auto inputDim =
-          getDimOp(rewriter, loc, input, index + meanAndVarShapeRank);
-      auto weightDim = getDimOp(rewriter, loc, weight, index);
-      auto biasDim = getDimOp(rewriter, loc, bias, index);
-
-      auto expectedSize = en.value();
-      checkDimEqualHelper(rewriter, loc, inputDim, expectedSize);
-      checkDimEqualHelper(rewriter, loc, weightDim, expectedSize);
-      checkDimEqualHelper(rewriter, loc, biasDim, expectedSize);
-    }
-
-    // Get iterator types for input shape.
-    SmallVector<StringRef> normalizedShapeIteratorTypes(
-        normalizedShapeRank, getReductionIteratorTypeName());
-    SmallVector<StringRef> meanAndVarIterationTypes(
-        meanAndVarShapeRank, getParallelIteratorTypeName());
-    SmallVector<StringRef> inputShapeIteratorTypes = meanAndVarIterationTypes;
-    inputShapeIteratorTypes.append(normalizedShapeIteratorTypes);
-
-    // Step 2. Common parts to be used for getting mean and var.
-
-    // Get sizes and affineMaps needed for mean and var.
-    AffineMap inputShapeAffineMap = rewriter.getMultiDimIdentityMap(inputRank);
-    SmallVector<AffineExpr> meanAndVarShapeExprs;
-    for (int i = 0; i < meanAndVarShapeRank; i++)
-      meanAndVarShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
-    auto meanAndVarShapeAffineMap = AffineMap::get(
-        /*dimCount=*/inputRank,
-        /*symbolCount=*/0, meanAndVarShapeExprs, context);
-    SmallVector<Value> meanAndVarShapeSizes =
-        getTensorSizesUntilDim(rewriter, loc, input, meanAndVarShapeRank - 1);
-
-    // Get number of elements to be used for calculating mean and var.
-    Value elemCnts = normalizedShapeSizesInt[0];
-    for (int i = 1; i < normalizedShapeRank; i++) {
-      elemCnts = rewriter.create<arith::MulIOp>(loc, elemCnts,
-                                                normalizedShapeSizesInt[i]);
-    }
-    Value elemCntsFloat =
-        rewriter.create<arith::SIToFPOp>(loc, elemTy, elemCnts);
-
-    // Helper to calculate mean and var.
-    auto genMeanOrVarCalculation = [&](Value sumOrSquareSum) {
-      SmallVector<AffineMap> indexingMaps(
-          2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
-      Value initShapeTensor = rewriter.create<linalg::InitTensorOp>(
-          loc, meanAndVarShapeSizes, elemTy);
-      return rewriter
-          .create<linalg::GenericOp>(
-              loc, initShapeTensor.getType(), sumOrSquareSum, initShapeTensor,
-              /*indexingMaps=*/indexingMaps,
-              /*iteratorTypes=*/meanAndVarIterationTypes,
-              [&](OpBuilder &b, Location loc, ValueRange args) {
-                Value sumOrSqureSum = args[0];
-                Value result =
-                    b.create<arith::DivFOp>(loc, sumOrSqureSum, elemCntsFloat);
-                b.create<linalg::YieldOp>(loc, result);
-              })
-          .getResult(0);
-    };
-
-    // Step 3. Get mean.
-
-    // Get sum to be used for calculating mean.
-    SmallVector<AffineMap, 2> sumIndexingMaps = {
-        inputShapeAffineMap,      // input
-        meanAndVarShapeAffineMap, // output
-    };
-    auto initSumTensor =
-        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
-    Value sum = rewriter
-                    .create<linalg::GenericOp>(
-                        loc, initSumTensor.getType(), input, initSumTensor,
-                        /*indexingMaps=*/sumIndexingMaps,
-                        /*iteratorTypes=*/inputShapeIteratorTypes,
-                        [&](OpBuilder &b, Location loc, ValueRange args) {
-                          Value input = args[0], sum = args[1];
-                          Value result =
-                              rewriter.create<arith::AddFOp>(loc, sum, input);
-                          b.create<linalg::YieldOp>(loc, result);
-                        })
-                    .getResult(0);
-    Value mean = genMeanOrVarCalculation(sum);
-
-    // Step 4. Get rSTD.
-
-    // Calculate squareSum for the layer.
-    SmallVector<AffineMap> squareSumIndexingMaps{
-        inputShapeAffineMap,
-        meanAndVarShapeAffineMap,
-        meanAndVarShapeAffineMap,
-    };
-    auto initSquareSumTensor =
-        createZeroInitTensor(rewriter, loc, meanAndVarShapeSizes, elemTy);
-    Value squareSum =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initSquareSumTensor.getType(), ValueRange{input, mean},
-                initSquareSumTensor,
-                /*indexingMaps=*/squareSumIndexingMaps,
-                /*iteratorTypes=*/inputShapeIteratorTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value input = args[0], mean = args[1], squareSum = args[2];
-                  Value sub = rewriter.create<arith::SubFOp>(loc, input, mean);
-                  Value square = rewriter.create<arith::MulFOp>(loc, sub, sub);
-                  Value result =
-                      rewriter.create<arith::AddFOp>(loc, squareSum, square);
-                  b.create<linalg::YieldOp>(loc, result);
-                })
-            .getResult(0);
-    Value var = genMeanOrVarCalculation(squareSum);
-    Value rSTDTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, meanAndVarShapeSizes, elemTy);
-    SmallVector<AffineMap> rSTDIndexingMap(
-        2, rewriter.getMultiDimIdentityMap(meanAndVarShapeRank));
-
-    Value rSTD = rewriter
-                     .create<linalg::GenericOp>(
-                         loc, rSTDTensor.getType(), var, rSTDTensor,
-                         rSTDIndexingMap, meanAndVarIterationTypes,
-                         [&](OpBuilder &b, Location loc, ValueRange args) {
-                           Value result =
-                               calculateRSTD(b, loc, elemTy, eps, args[0]);
-                           b.create<linalg::YieldOp>(loc, result);
-                         })
-                     .getResult(0);
-
-    // Step 5. Get layernorm.
-
-    // Get affineMap for normalized shape.
-    SmallVector<AffineExpr> normalizedShapeExprs;
-    for (int i = meanAndVarShapeRank; i < inputRank; i++)
-      normalizedShapeExprs.push_back(mlir::getAffineDimExpr(i, context));
-    auto normalizedShapeAffineMap = AffineMap::get(
-        /*dimCount=*/inputRank,
-        /*symbolCount=*/0, normalizedShapeExprs, context);
-    auto inputSizes = getTensorSizes(rewriter, loc, input);
-    Value initLayerNormTensor =
-        rewriter.create<linalg::InitTensorOp>(loc, inputSizes, elemTy);
-    SmallVector<AffineMap> indexingMaps(1, inputShapeAffineMap);
-    indexingMaps.resize(3, meanAndVarShapeAffineMap);
-    indexingMaps.resize(5, normalizedShapeAffineMap);
-    indexingMaps.push_back(inputShapeAffineMap);
-    SmallVector<StringRef> layerNormIterationTypes(
-        inputRank, getParallelIteratorTypeName());
-    Value layerNorm =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, initLayerNormTensor.getType(),
-                ValueRange{input, mean, rSTD, weight, bias},
-                initLayerNormTensor,
-                /*indexingMaps=*/indexingMaps,
-                /*iteratorTypes=*/layerNormIterationTypes,
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value input = args[0], mean = args[1], rSTD = args[2],
-                        weight = args[3], bias = args[4];
-                  Value result =
-                      createLinalgPayloadCalculationForNormOpsWithRSTD(
-                          b, loc, elemTy, input, mean, rSTD, eps, weight, bias);
-                  b.create<linalg::YieldOp>(loc, result);
-                })
-            .getResult(0);
-    SmallVector<int64_t> expandShape(inputRank, 1);
-    for (int i = 0; i < meanAndVarShapeRank; i++) {
-      // `mean` and `rstd` are not yet casted, so they will be having dynamic
-      // shape. Hence to match them, for each dimension corresponding to `mean`
-      // or `rstd` assign -1.
-      expandShape[i] = -1;
-    }
-    auto expandShapeType = RankedTensorType::get(expandShape, elemTy);
-    SmallVector<ReassociationIndices> reassociation(meanAndVarShapeRank);
-    for (auto i : llvm::seq<int64_t>(0, meanAndVarShapeRank)) {
-      reassociation[i].push_back(i);
-      if (i == meanAndVarShapeRank - 1) {
-        for (auto j : llvm::seq<int64_t>(0, normalizedShapeRank))
-          reassociation[i].push_back(i + j + 1);
-      }
-    }
-    Value meanResult = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandShapeType, mean, reassociation);
-    Value rSTDResult = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandShapeType, rSTD, reassociation);
-    Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
-    Type meanResultType = getTypeConverter()->convertType(op.getType(1));
-    Type rSTDResultType = getTypeConverter()->convertType(op.getType(2));
-    Value layerNorm_ =
-        rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
-    Value mean_ =
-        rewriter.create<tensor::CastOp>(loc, meanResultType, meanResult);
-    Value var_ =
-        rewriter.create<tensor::CastOp>(loc, rSTDResultType, rSTDResult);
-    rewriter.replaceOp(op, {layerNorm_, mean_, var_});
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertAtenNllLossBackwardOp
     : public OpConversionPattern<AtenNllLossBackwardOp> {
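A subtle detail near the end of the deleted pattern is the tensor.expand_shape step: mean and rSTD only carry the meanAndVarShape dimensions, so they are expanded back to the full input rank, keeping the dynamically sized (hence -1) mean/var dimensions and appending size-1 dimensions for the normalized ones, with all normalized dimensions folded into the last mean/var dimension by the reassociation indices. Below is a standalone sketch of that index arithmetic, assuming for illustration an inputRank of 4 split into two mean/var and two normalized dimensions (plain C++, not MLIR APIs).

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Assumed split for illustration: inputRank = 4,
  // meanAndVarShapeRank = 2, normalizedShapeRank = 2.
  const int64_t meanAndVarShapeRank = 2;
  const int64_t normalizedShapeRank = 2;

  // Same index arithmetic as the reassociation loop in the deleted pattern.
  std::vector<std::vector<int64_t>> reassociation(meanAndVarShapeRank);
  for (int64_t i = 0; i < meanAndVarShapeRank; ++i) {
    reassociation[i].push_back(i);
    if (i == meanAndVarShapeRank - 1)
      for (int64_t j = 0; j < normalizedShapeRank; ++j)
        reassociation[i].push_back(i + j + 1);
  }

  // Prints "0" then "1 2 3": mean/rSTD of shape [d0, d1] are expanded to
  // [d0, d1, 1, 1], with both normalized dims grouped into the last source dim.
  for (const auto &group : reassociation) {
    for (int64_t idx : group)
      std::cout << idx << ' ';
    std::cout << '\n';
  }
  return 0;
}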
@@ -1728,8 +1463,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
   target.addIllegalOp<AtenBatchNormOp>();
   patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
-  target.addIllegalOp<AtenNativeLayerNormOp>();
-  patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossBackwardOp>();
   patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
   patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);