mirror of https://github.com/llvm/torch-mlir
Add decomposition to aten.native_layer_norm (#1332)
* Add decomposition to aten.native_layer_norm
* Fix CI error

pull/1335/head
parent 57d8ec151f
commit 512f2d9c23
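Summary: this change adds a rewrite pattern to the DecomposeComplexOps pass that decomposes torch.aten.native_layer_norm into primitive Torch ops. It computes (x - mean(x)) * rsqrt(var(x) + eps) over the normalized dimensions, applies the optional weight and bias, and returns the normalized output together with the mean and rsqrt(var(x) + eps) tensors as the op's three results. It also marks the op as backend-legal for TOSA so that path keeps the aggregate op.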
@@ -1690,6 +1690,88 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
 };
 } // namespace

+namespace {
+class DecomposeAtenNativeLayerNormOp
+    : public OpRewritePattern<AtenNativeLayerNormOp> {
+  using OpRewritePattern<AtenNativeLayerNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeLayerNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+
+    auto inputTy = op.input().getType().cast<BaseTensorType>();
+    if (!inputTy.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "input tensor should have known sizes.");
+    int64_t inputRank = inputTy.getSizes().size();
+    Value normalizedShape = op.normalized_shape();
+    SmallVector<Value> normalizedShapeSizesTorchInt;
+    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
+    int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
+    auto reduceDimInts = llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
+    auto reducedTy = op.getResult(1).getType();
+    auto sizeListType = ListType::get(IntType::get(context));
+
+    // build reduce dims
+    SmallVector<Value> reduceDimVals;
+    reduceDimVals.reserve(reduceDimInts.size());
+    std::transform(reduceDimInts.begin(), reduceDimInts.end(),
+                   std::back_inserter(reduceDimVals), [&](int64_t d) {
+                     return rewriter.create<Torch::ConstantIntOp>(
+                         loc, rewriter.getI64IntegerAttr(d));
+                   });
+    Value reduceDimList =
+        rewriter.create<PrimListConstructOp>(loc, sizeListType, reduceDimVals);
+    Value one = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+    // mean(x)
+    Value inputMean = rewriter.create<AtenMeanDimOp>(
+        loc, reducedTy, op.input(), reduceDimList, cstTrue, none);
+
+    // x - mean(x)
+    Value inputMeanExpanded =
+        rewriter.create<AtenExpandAsOp>(loc, inputTy, inputMean, op.input());
+    Value inputZeroMean = rewriter.create<AtenSubTensorOp>(
+        loc, inputTy, op.input(), inputMeanExpanded, one);
+    // var(x) = mean((x - mean(x))^2)
+    Value inputZeroMeanSquare = rewriter.create<AtenMulTensorOp>(
+        loc, inputTy, inputZeroMean, inputZeroMean);
+    Value inputVar = rewriter.create<AtenMeanDimOp>(
+        loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none);
+
+    // rsqrt(var(x) + eps)
+    Value inputVarPlusEps = rewriter.create<AtenAddScalarOp>(
+        loc, reducedTy, inputVar, op.eps(), one);
+    Value inputRsqrtVar =
+        rewriter.create<AtenRsqrtOp>(loc, reducedTy, inputVarPlusEps);
+
+    // (x - mean(x)) * rsqrt(var(x) + eps)
+    Value inputRsqrtVarExpanded = rewriter.create<AtenExpandAsOp>(
+        loc, inputTy, inputRsqrtVar, op.input());
+    Value inputNormalized = rewriter.create<AtenMulTensorOp>(
+        loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
+    Value out = rewriter.create<TensorStaticInfoCastOp>(
+        loc, op.getResult(0).getType(), inputNormalized);
+
+    Value weight = op.weight();
+    Value bias = op.bias();
+    if (!weight.getType().isa<Torch::NoneType>()) {
+      out = rewriter.create<AtenMulTensorOp>(loc, out.getType(), out, weight);
+    }
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      out =
+          rewriter.create<AtenAddTensorOp>(loc, out.getType(), out, bias, one);
+    }
+    rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar});
+
+    return success();
+  }
+};
+} // namespace
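For reference, a minimal Python sketch of the same computation (the helper name decomposed_native_layer_norm is invented for illustration; it mirrors the pattern above step by step and checks the three results against torch.native_layer_norm):

import torch

def decomposed_native_layer_norm(x, normalized_shape, weight, bias, eps):
    # Reduce over the trailing dims covered by `normalized_shape`, matching
    # `axis = inputRank - normalizedShapeSizesTorchInt.size()` in the pattern.
    dims = list(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dims, keepdim=True)                       # mean(x)
    zero_mean = x - mean                                    # x - mean(x)
    var = (zero_mean * zero_mean).mean(dims, keepdim=True)  # biased variance
    rstd = torch.rsqrt(var + eps)                           # rsqrt(var(x) + eps)
    out = zero_mean * rstd
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    return out, mean, rstd

x = torch.randn(2, 3, 8)
weight, bias = torch.randn(8), torch.randn(8)
expected = torch.native_layer_norm(x, [8], weight, bias, 1e-5)
actual = decomposed_native_layer_norm(x, [8], weight, bias, 1e-5)
for e, a in zip(expected, actual):
    assert torch.allclose(e.reshape(a.shape), a, atol=1e-5)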
 namespace {
 // Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops.
 class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
@@ -2696,6 +2778,9 @@ public:
     target.addIllegalOp<AtenAddcdivOp>();
     target.addIllegalOp<AtenLayerNormOp>();
     patterns.add<DecomposeAtenLayerNormOp>(context);
+    target.addIllegalOp<AtenNativeLayerNormOp>();
+    patterns.add<DecomposeAtenNativeLayerNormOp>(context);
     target.addIllegalOp<AtenNativeBatchNormOp>();
     patterns.add<DecomposeAtenNativeBatchNormOp>(context);
     target.addIllegalOp<AtenConvolutionOverrideableOp>();
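Marking the op illegal and registering the pattern means DecomposeComplexOps must rewrite every aten.native_layer_norm it encounters; the escape hatch is the backend-legal-ops mechanism below, which lets a backend keep the aggregate op when it can lower it directly.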
@@ -126,7 +126,7 @@ class TensorPlaceholder:
 # ops in the backend contract, and move these lists somewhere deeper in the
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: ['torch.aten.flatten.using_ints',],
+    OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm'],
     OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
     OutputType.MHLO: [],
 }
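With torch.aten.native_layer_norm listed as backend-legal for TOSA, any instance of the op that reaches the backend contract survives on that path (presumably because the TOSA backend lowers it directly), while the linalg-on-tensors and MHLO paths receive the decomposed form. A hypothetical end-to-end sketch, assuming the torch_mlir.compile API from this revision (model and shapes invented for illustration; whether nn.LayerNorm reaches the backend as native_layer_norm depends on how the frontend and earlier decompositions emit it):

import torch
import torch_mlir

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(8)

    def forward(self, x):
        return self.norm(x)

# On the TOSA path, torch.aten.native_layer_norm stays backend-legal and is
# not decomposed; on the linalg path it is rewritten by the pattern above.
tosa_module = torch_mlir.compile(
    Model(), torch.ones(2, 8), output_type=torch_mlir.OutputType.TOSA)
linalg_module = torch_mlir.compile(
    Model(), torch.ones(2, 8),
    output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)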