Add decomposition to aten.native_layer_norm (#1332)

* Add decomposition to aten.native_layer_norm

* Fix CI error
Tanyo Kwok 2022-09-02 09:29:22 +08:00 committed by GitHub
parent 57d8ec151f
commit 512f2d9c23
2 changed files with 86 additions and 1 deletion


@@ -1690,6 +1690,88 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
};
} // namespace
namespace {
class DecomposeAtenNativeLayerNormOp
    : public OpRewritePattern<AtenNativeLayerNormOp> {
  using OpRewritePattern<AtenNativeLayerNormOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNativeLayerNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto context = op.getContext();

    auto inputTy = op.input().getType().cast<BaseTensorType>();
    if (!inputTy.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "input tensor should have known sizes.");
    int64_t inputRank = inputTy.getSizes().size();
    Value normalizedShape = op.normalized_shape();
    SmallVector<Value> normalizedShapeSizesTorchInt;
    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
    int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
    auto reduceDimInts =
        llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
    auto reducedTy = op.getResult(1).getType();
    auto sizeListType = ListType::get(IntType::get(context));

    // Build the list of dimensions to reduce over (the normalized dims).
    SmallVector<Value> reduceDimVals;
    reduceDimVals.reserve(reduceDimInts.size());
    std::transform(reduceDimInts.begin(), reduceDimInts.end(),
                   std::back_inserter(reduceDimVals), [&](int64_t d) {
                     return rewriter.create<Torch::ConstantIntOp>(
                         loc, rewriter.getI64IntegerAttr(d));
                   });
    Value reduceDimList =
        rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
    Value one = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);

    // mean(x)
    Value inputMean = rewriter.create<AtenMeanDimOp>(
        loc, reducedTy, op.input(), reduceDimList, /*keepdim=*/cstTrue,
        /*dtype=*/none);

    // x - mean(x)
    Value inputMeanExpanded =
        rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.input());
    Value inputZeroMean = rewriter.create<AtenSubTensorOp>(
        loc, inputTy, op.input(), inputMeanExpanded, one);

    // var(x) = mean((x - mean(x))^2)
    Value inputZeroMeanSquare = rewriter.create<AtenMulTensorOp>(
        loc, inputTy, inputZeroMean, inputZeroMean);
    Value inputVar = rewriter.create<AtenMeanDimOp>(
        loc, reducedTy, inputZeroMeanSquare, reduceDimList,
        /*keepdim=*/cstTrue, /*dtype=*/none);

    // rsqrt(var(x) + eps)
    Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
        loc, reducedTy, inputVar, op.eps(), one);
    Value inputRsqrtVar =
        rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);

    // (x - mean(x)) * rsqrt(var(x) + eps)
    Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
        loc, inputTy, inputRsqrtVar, op.input());
    Value inputNormalized = rewriter.create<AtenMulTensorOp>(
        loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
    Value out = rewriter.create<TensorStaticInfoCastOp>(
        loc, op.getResult(0).getType(), inputNormalized);

    // Apply the optional affine transform: out * weight + bias.
    Value weight = op.weight();
    Value bias = op.bias();
    if (!weight.getType().isa<Torch::NoneType>()) {
      out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
    }
    if (!bias.getType().isa<Torch::NoneType>()) {
      out =
          rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
    }
    rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar});
    return success();
  }
};
} // namespace
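The pattern mirrors the eager-mode semantics of aten.native_layer_norm: normalize over the trailing `normalized_shape` dimensions using the biased variance, then apply the optional affine transform, returning (out, mean, rstd). A minimal sketch of the same computation in plain PyTorch, handy for checking the decomposition's math (the helper name and test shapes below are illustrative, not part of this patch):

    import torch

    def decomposed_native_layer_norm(x, normalized_shape, weight, bias, eps):
        # Reduce over the trailing len(normalized_shape) dims, keeping them
        # so the statistics broadcast back over the input.
        dims = list(range(x.dim() - len(normalized_shape), x.dim()))
        mean = x.mean(dims, keepdim=True)
        zero_mean = x - mean
        # Biased variance, as in the pattern: mean((x - mean(x))^2).
        var = (zero_mean * zero_mean).mean(dims, keepdim=True)
        rstd = torch.rsqrt(var + eps)
        out = zero_mean * rstd
        if weight is not None:
            out = out * weight
        if bias is not None:
            out = out + bias
        return out, mean, rstd

    x = torch.randn(2, 5, 2, 3)
    w, b = torch.ones(2, 3), torch.zeros(2, 3)
    out, mean, rstd = decomposed_native_layer_norm(x, [2, 3], w, b, 1e-5)
    ref_out, ref_mean, ref_rstd = torch.ops.aten.native_layer_norm(
        x, [2, 3], w, b, 1e-5)
    assert torch.allclose(out, ref_out, atol=1e-6)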
namespace {
// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
@@ -2696,6 +2778,9 @@ public:
    target.addIllegalOp<AtenAddcdivOp>();
    target.addIllegalOp<AtenLayerNormOp>();
    patterns.add<DecomposeAtenLayerNormOp>(context);
    target.addIllegalOp<AtenNativeLayerNormOp>();
    patterns.add<DecomposeAtenNativeLayerNormOp>(context);
    target.addIllegalOp<AtenNativeBatchNormOp>();
    patterns.add<DecomposeAtenNativeBatchNormOp>(context);
    target.addIllegalOp<AtenConvolutionOverrideableOp>();


@@ -126,7 +126,7 @@ class TensorPlaceholder:
# ops in the backend contract, and move these lists somewhere deeper in the
# compiler where each backend can "own" its set of legal ops.
BACKEND_LEGAL_OPS = {
    OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm'],
    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
    OutputType.MHLO: [],
}
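With torch.aten.native_layer_norm marked backend-legal for TOSA, the decomposition above is skipped on that path and the op reaches the TOSA lowering intact, while the LINALG_ON_TENSORS path still decomposes it. A hypothetical end-to-end use, assuming the torch_mlir.compile Python API of this era (the module and shapes are made up for illustration):

    import torch
    import torch_mlir

    class Norm(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.ln = torch.nn.LayerNorm([2, 3])

        def forward(self, x):
            return self.ln(x)

    # On the TOSA path, aten.native_layer_norm survives decomposition and
    # is lowered directly; other paths see the decomposed form.
    compiled = torch_mlir.compile(Norm(), torch.randn(2, 5, 2, 3),
                                  output_type=torch_mlir.OutputType.TOSA)
    print(compiled)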