mirror of https://github.com/llvm/torch-mlir
Add decomposition to aten.native_layer_norm (#1332)
* Add decomposition to aten.native_layer_norm
* fix ci error

pull/1335/head

parent 57d8ec151f
commit 512f2d9c23
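For reference (this summary is editorial, read off the hunk that follows): `aten.native_layer_norm` returns three values — the normalized output, the per-slice mean, and the reciprocal standard deviation — and the decomposition below builds all three from primitive ops, reducing over the trailing `normalized_shape` dimensions:

\[
\mathrm{out} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot w + b,
\qquad
\mathrm{Var}[x] = \mathrm{E}\!\left[(x - \mathrm{E}[x])^{2}\right]
\]

with \(\mathrm{E}[x]\) and \(1/\sqrt{\mathrm{Var}[x] + \epsilon}\) returned as the second and third results.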
@@ -1690,6 +1690,88 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenNativeLayerNormOp
+    : public OpRewritePattern<AtenNativeLayerNormOp> {
+  using OpRewritePattern<AtenNativeLayerNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeLayerNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+
+    auto inputTy = op.input().getType().cast<BaseTensorType>();
+    if (!inputTy.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "input tensor should have known sizes.");
+    int64_t inputRank = inputTy.getSizes().size();
+    Value normalizedShape = op.normalized_shape();
+    SmallVector<Value> normalizedShapeSizesTorchInt;
+    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
+    int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
+    auto reduceDimInts = llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
+    auto reducedTy = op.getResult(1).getType();
+    auto sizeListType = ListType::get(IntType::get(context));
+
+    // build reduce dims
+    SmallVector<Value> reduceDimVals;
+    reduceDimVals.reserve(reduceDimInts.size());
+    std::transform(reduceDimInts.begin(), reduceDimInts.end(),
+                   std::back_inserter(reduceDimVals), [&](int64_t d) {
+                     return rewriter.create<Torch::ConstantIntOp>(
+                         loc, rewriter.getI64IntegerAttr(d));
+                   });
+    Value reduceDimList =
+        rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
+    Value one = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+    // mean(x)
+    Value inputMean = rewriter.create<AtenMeanDimOp>(
+        loc, reducedTy, op.input(), reduceDimList, cstTrue, none);
+
+    // x - mean(x)
+    Value inputMeanExpanded =
+        rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.input());
+    Value inputZeroMean = rewriter.create<AtenSubTensorOp>(
+        loc, inputTy, op.input(), inputMeanExpanded, one);
+    // var(x) = mean((x - mean(x))^2)
+    Value inputZeroMeanSquare = rewriter.create<AtenMulTensorOp>(
+        loc, inputTy, inputZeroMean, inputZeroMean);
+    Value inputVar = rewriter.create<AtenMeanDimOp>(
+        loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none);
+
+    // rsqrt(var(x) + eps)
+    Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
+        loc, reducedTy, inputVar, op.eps(), one);
+    Value inputRsqrtVar =
+        rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);
+
+    // (x - mean(x)) * rsqrt(var(x) + eps)
+    Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
+        loc, inputTy, inputRsqrtVar, op.input());
+    Value inputNormalized = rewriter.create<AtenMulTensorOp>(
+        loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
+    Value out = rewriter.create<TensorStaticInfoCastOp>(
+        loc, op.getResult(0).getType(), inputNormalized);
+
+    Value weight = op.weight();
+    Value bias = op.bias();
+    if (!weight.getType().isa<Torch::NoneType>()) {
+      out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
+    }
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      out =
+          rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
+    }
+    rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar});
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
 class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
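The rewrite above can be sanity-checked numerically by replaying the same op sequence in eager PyTorch. This is an editorial sketch (the helper name `decomposed_native_layer_norm` is made up, not part of torch-mlir); it mirrors the pattern step by step and compares the first result against `torch.nn.functional.layer_norm`:

import torch
import torch.nn.functional as F

def decomposed_native_layer_norm(x, normalized_shape, weight, bias, eps):
    # axis = inputRank - len(normalized_shape), as in the pattern above.
    axis = x.dim() - len(normalized_shape)
    reduce_dims = list(range(axis, x.dim()))
    mean = x.mean(reduce_dims, keepdim=True)           # mean(x)
    zero_mean = x - mean                               # x - mean(x)
    var = (zero_mean * zero_mean).mean(reduce_dims, keepdim=True)
    rsqrt_var = torch.rsqrt(var + eps)                 # rsqrt(var(x) + eps)
    out = zero_mean * rsqrt_var
    if weight is not None:                             # optional affine part
        out = out * weight
    if bias is not None:
        out = out + bias
    return out, mean, rsqrt_var

x = torch.randn(3, 4, 5)
w, b = torch.randn(4, 5), torch.randn(4, 5)
out, mean, rstd = decomposed_native_layer_norm(x, (4, 5), w, b, 1e-5)
assert torch.allclose(out, F.layer_norm(x, (4, 5), w, b, 1e-5), atol=1e-6)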
@@ -2696,6 +2778,9 @@ public:
     target.addIllegalOp<AtenAddcdivOp>();
     target.addIllegalOp<AtenLayerNormOp>();
     patterns.add<DecomposeAtenLayerNormOp>(context);
+    target.addIllegalOp<AtenNativeLayerNormOp>();
+    patterns.add<DecomposeAtenNativeLayerNormOp>(context);
+
     target.addIllegalOp<AtenNativeBatchNormOp>();
     patterns.add<DecomposeAtenNativeBatchNormOp>(context);
     target.addIllegalOp<AtenConvolutionOverrideableOp>();
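Marking `AtenNativeLayerNormOp` illegal makes the decompose-complex-ops pass apply the new pattern to every occurrence. A hedged usage sketch (the model is illustrative; `torch_mlir.compile` and `OutputType` are the entry points defined in the Python file touched by the next hunk):

import torch
import torch_mlir

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(8)
    def forward(self, x):
        return self.norm(x)

# With the pattern registered, the layer norm is decomposed on the way down
# to linalg-on-tensors instead of surfacing as an unhandled op.
module = torch_mlir.compile(Model(), torch.ones(2, 8),
                            output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
print(module)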
@@ -126,7 +126,7 @@ class TensorPlaceholder:
 # ops in the backend contract, and move these lists somewhere deeper in the
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: ['torch.aten.flatten.using_ints',],
+    OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm'],
     OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
     OutputType.MHLO: [],
 }
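The BACKEND_LEGAL_OPS change is the flip side of the decomposition: for TOSA, `torch.aten.native_layer_norm` is declared backend-legal, so the pass leaves it intact for the TOSA path to handle directly rather than decomposing it. Continuing the hedged sketch above:

# For TOSA the op stays backend-legal, so the new decomposition is skipped
# and the TOSA backend sees torch.aten.native_layer_norm unchanged.
tosa_module = torch_mlir.compile(Model(), torch.ones(2, 8),
                                 output_type=torch_mlir.OutputType.TOSA)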