[MLIR][TORCH] Add E2E support for `aten.native_layer_norm`. (#470)

This commit adds support for the aten.native_layer_norm operation. The
existing lowering for aten.layer_norm is tweaked a little to also
produce the mean and variance values along with the layer norm result.
This commit also adds a decomposition of aten.layer_norm into
aten.native_layer_norm, which was previously getting lowered directly
to linalg.
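
For reference, a minimal PyTorch-level sketch (not part of this patch) of the relationship the decomposition relies on: aten.native_layer_norm returns the normalized tensor together with the statistics it computed, while aten.layer_norm exposes only the first result.

import torch

x = torch.rand(2, 5, 2, 2, 3)
weight = torch.rand(2, 2, 3)
bias = torch.rand(2, 2, 3)
normalized_shape = [2, 2, 3]
eps = 0.5

# native_layer_norm returns three tensors: the layer norm result plus
# the statistics computed over the normalized dimensions.
out, mean, var = torch.ops.aten.native_layer_norm(
    x, normalized_shape, weight, bias, eps)

# layer_norm only exposes the normalized tensor, so it can be rewritten
# as native_layer_norm followed by taking result 0.
ref = torch.ops.aten.layer_norm(x, normalized_shape, weight, bias, eps)
assert torch.allclose(ref, out)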

Signed-off-by: Prateek Gupta <prateek@nod-labs.com>
Prateek Gupta 2021-12-10 19:06:19 +05:30 committed by GitHub
parent 5a47f92390
commit cfc8de36f8
6 changed files with 137 additions and 7 deletions


@@ -87,8 +87,33 @@ class BatchNorm3DModule(torch.nn.Module):
def BatchNorm3DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 3, 6, 4))

# ==============================================================================

class NativeLayerNormModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 5, 2, 2, 3], torch.float32, True),
        ([2, 2, 3], torch.float32, True),
        ([2, 2, 3], torch.float32, True),
    ])
    def forward(self, x, weight, bias):
        list = [2, 2, 3]
        return torch.ops.aten.native_layer_norm(
            x, list, weight, bias, eps=0.5)[0]


@register_test_case(module_factory=lambda: NativeLayerNormModule())
def NativeLayerNormModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))

# ==============================================================================

class LayerNormModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -138,6 +163,8 @@ def LayerNormLastDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 2, 2, 3))

# ==============================================================================

class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()


@@ -1336,6 +1336,26 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps `,` $cudnn_enable attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `,` type($cudnn_enable) `->` type($result)";
}

def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    TorchIntListType:$normalized_shape,
    AnyTorchOptionalTensorType:$weight,
    AnyTorchOptionalTensorType:$bias,
    Torch_FloatType:$eps
  );
  let results = (outs
    AnyTorchTensorType:$layer_norm,
    AnyTorchTensorType:$mean,
    AnyTorchTensorType:$variance
  );
  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($layer_norm) `,` type($mean) `,` type($variance)";
}

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
    AllowsTypeRefinement,
    HasValueSemantics


@@ -693,11 +693,12 @@ public:
// Step 4. Get var.
// Step 5. Get layernorm.
namespace {
class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
class ConvertAtenNativeLayerNormOp
    : public OpConversionPattern<AtenNativeLayerNormOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenLayerNormOp op, OpAdaptor adaptor,
  matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *context = op->getContext();
    Location loc = op->getLoc();
@@ -889,9 +890,14 @@ public:
                  b.create<linalg::YieldOp>(loc, result);
                })
            .getResult(0);
    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, layerNorm);
    Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
    Type meanResultType = getTypeConverter()->convertType(op.getType(1));
    Type varResultType = getTypeConverter()->convertType(op.getType(2));
    Value layerNorm_ =
        rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
    Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
    Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
    rewriter.replaceOp(op, {layerNorm_, mean_, var_});
    return success();
  }
};
@@ -3659,8 +3665,8 @@ public:
    patterns.add<ConvertAtenCatOp>(typeConverter, context);
    target.addIllegalOp<AtenGatherOp>();
    patterns.add<ConvertAtenGatherOp>(typeConverter, context);
    target.addIllegalOp<AtenLayerNormOp>();
    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
    target.addIllegalOp<AtenNativeLayerNormOp>();
    patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
    target.addIllegalOp<AtenBroadcastToOp>();
    patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
    target.addIllegalOp<AtenArgmaxOp>();


@@ -477,6 +477,33 @@ class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
    return success();
  }
};

class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
  using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLayerNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    auto input = op.input().getType().cast<BaseTensorType>();
    if (!input.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "input tensor should have known sizes.");
    int64_t inputRank = input.getSizes().size();
    Value normalizedShape = op.normalized_shape();
    SmallVector<Value> normalizedShapeSizesTorchInt;
    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
    std::vector<int64_t> meanVarSizes;
    for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
      meanVarSizes.push_back(input.getSizes()[i]);
    auto meanVarType = input.getWithSizesAndDtype(
        llvm::makeArrayRef(meanVarSizes), input.getDtype());
    auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
        loc, op.getType(), meanVarType, meanVarType, op.input(),
        op.normalized_shape(), op.weight(), op.bias(), op.eps());
    rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
    return success();
  }
};
} // namespace

namespace {
@@ -522,6 +549,9 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenAddcmulOp>();
    patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(context);
    target.addIllegalOp<AtenAddcdivOp>();
    target.addIllegalOp<AtenLayerNormOp>();
    patterns.add<DecomposeAtenLayerNormOp>(context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();


@@ -473,6 +473,8 @@ public:
      return visitBinaryScalarOp(scalarOp);
    } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
      return visitAtenNllLossForwardOp(nllForwardOp, operands);
    } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
      return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
    }

    // Otherwise, this is an unknown operation. Just mark all results as
@@ -609,6 +611,9 @@ private:
  ChangeResult
  visitAtenNllLossForwardOp(AtenNllLossForwardOp op,
                            ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitAtenNativeLayerNormOp(
      AtenNativeLayerNormOp op,
      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace

@@ -1605,6 +1610,45 @@ ChangeResult TypeAnalyzer::visitAtenAddCLikeOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenNativeLayerNormOp(
    AtenNativeLayerNormOp op,
    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();

  auto layerNormKnowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  auto meanKnowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  auto varKnowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());

  layerNormKnowledge.hasSizes = input.hasSizes;
  layerNormKnowledge.sizes = input.sizes;
  layerNormKnowledge.dtype = input.dtype;

  int64_t layerNormSize = input.sizes.size();
  Value normalizedShape = op.normalized_shape();
  SmallVector<Value> normalizedShapeSizesTorchInt;
  getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
  std::vector<int64_t> meanVarSizes;
  if (input.hasSizes) {
    for (int i = normalizedShapeSizesTorchInt.size(); i < layerNormSize; i++)
      meanVarSizes.push_back(input.sizes[i]);
  }
  meanKnowledge.hasSizes = input.hasSizes;
  meanKnowledge.sizes = meanVarSizes;
  meanKnowledge.dtype = input.dtype;
  varKnowledge.hasSizes = input.hasSizes;
  varKnowledge.sizes = meanVarSizes;
  varKnowledge.dtype = input.dtype;

  auto resultLattice =
      getLatticeElement(op.getResult(0)).join(layerNormKnowledge);
  resultLattice |= getLatticeElement(op.getResult(1)).join(meanKnowledge);
  resultLattice |= getLatticeElement(op.getResult(2)).join(varKnowledge);
  return resultLattice;
}
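
As a shape-only illustration of the transfer function above (hypothetical helper name; it simply mirrors the loop in visitAtenNativeLayerNormOp rather than adding anything new), the sizes propagated to the three results can be written as:

def native_layer_norm_result_sizes(input_sizes, normalized_shape):
    # Result 0 keeps the input sizes; mean and variance keep the input
    # dimensions starting at index len(normalized_shape), exactly as in
    # the C++ loop above.
    mean_var_sizes = list(input_sizes[len(normalized_shape):])
    return list(input_sizes), mean_var_sizes, mean_var_sizes

# With the shapes from the NativeLayerNormModule e2e test:
# input [2, 5, 2, 2, 3] normalized over [2, 2, 3].
print(native_layer_norm_result_sizes([2, 5, 2, 2, 3], [2, 2, 3]))
# -> ([2, 5, 2, 2, 3], [2, 3], [2, 3])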
// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------


@@ -502,6 +502,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit(
        "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
    )
    emit(
        "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
    )
    emit(
        "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
    )