mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for `aten.native_layer_norm`. (#470)
This commit adds support for the aten.native_layer_norm operation. The previous lowering for aten.layer_norm is reworked slightly so that it also produces the mean and variance values alongside the layer-norm value. The commit also adds a decomposition of aten.layer_norm into aten.native_layer_norm; previously, aten.layer_norm was lowered directly to linalg.

Signed-off-by: Prateek Gupta <prateek@nod-labs.com>
parent 5a47f92390
commit cfc8de36f8
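The decomposition added here relies on the fact that `aten::layer_norm` returns exactly the first result of `aten::native_layer_norm`; the extra results expose the mean and variance values that the lowering already computes. A minimal eager-mode sketch of that relationship (illustration only, not part of the commit):

import torch

x = torch.rand(2, 5, 2, 2, 3)
weight, bias = torch.rand(2, 2, 3), torch.rand(2, 2, 3)

# aten::layer_norm returns only the normalized tensor ...
out = torch.ops.aten.layer_norm(x, [2, 2, 3], weight, bias, 0.5, False)
# ... while aten::native_layer_norm also returns the reduction statistics;
# its first result is the layer_norm value, which is what the decomposition
# substitutes for the original op's single result.
native = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, 0.5)
assert torch.allclose(out, native[0])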
@@ -87,8 +87,33 @@ class BatchNorm3DModule(torch.nn.Module):
def BatchNorm3DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 3, 6, 4))


# ==============================================================================


class NativeLayerNormModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 5, 2, 2, 3], torch.float32, True),
        ([2, 2, 3], torch.float32, True),
        ([2, 2, 3], torch.float32, True),
    ])
    def forward(self, x, weight, bias):
        list = [2, 2, 3]
        return torch.ops.aten.native_layer_norm(
            x, list, weight, bias, eps=0.5)[0]


@register_test_case(module_factory=lambda: NativeLayerNormModule())
def NativeLayerNormModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))


# ==============================================================================


class LayerNormModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

@@ -138,6 +163,8 @@ def LayerNormLastDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 2, 2, 3))


# ==============================================================================


class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -1336,6 +1336,26 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps `,` $cudnn_enable attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `,` type($cudnn_enable) `->` type($result)";
}

def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    TorchIntListType:$normalized_shape,
    AnyTorchOptionalTensorType:$weight,
    AnyTorchOptionalTensorType:$bias,
    Torch_FloatType:$eps
  );
  let results = (outs
    AnyTorchTensorType:$layer_norm,
    AnyTorchTensorType:$mean,
    AnyTorchTensorType:$variance
  );
  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($layer_norm) `,` type($mean) `,` type($variance)";
}

def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
    AllowsTypeRefinement,
    HasValueSemantics
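The Torch_AtenNativeLayerNormOp definition above records five operands (input, normalized_shape, optional weight and bias, and eps) and three results, which it names `layer_norm`, `mean`, and `variance`. A small eager-level sketch of that signature (an illustration, not part of the commit; the variable names are assumptions):

import torch

x = torch.rand(2, 5, 2, 2, 3)
weight, bias = torch.rand(2, 2, 3), torch.rand(2, 2, 3)

# aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float)
#                           -> (Tensor, Tensor, Tensor)
# stat_a and stat_b are the statistics results, which the op definition
# above names mean and variance.
out, stat_a, stat_b = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, 0.5)
assert out.shape == x.shape  # the first result keeps the input shape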
@@ -693,11 +693,12 @@ public:
   // Step 4. Get var.
   // Step 5. Get layernorm.
 namespace {
-class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
+class ConvertAtenNativeLayerNormOp
+    : public OpConversionPattern<AtenNativeLayerNormOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenLayerNormOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *context = op->getContext();
     Location loc = op->getLoc();

@@ -889,9 +890,14 @@ public:
           b.create<linalg::YieldOp>(loc, result);
         })
         .getResult(0);

-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, layerNorm);
+    Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
+    Type meanResultType = getTypeConverter()->convertType(op.getType(1));
+    Type varResultType = getTypeConverter()->convertType(op.getType(2));
+    Value layerNorm_ =
+        rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
+    Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
+    Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
+    rewriter.replaceOp(op, {layerNorm_, mean_, var_});
     return success();
   }
 };

@@ -3659,8 +3665,8 @@ public:
     patterns.add<ConvertAtenCatOp>(typeConverter, context);
     target.addIllegalOp<AtenGatherOp>();
     patterns.add<ConvertAtenGatherOp>(typeConverter, context);
-    target.addIllegalOp<AtenLayerNormOp>();
-    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
+    target.addIllegalOp<AtenNativeLayerNormOp>();
+    patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
     target.addIllegalOp<AtenBroadcastToOp>();
     patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
     target.addIllegalOp<AtenArgmaxOp>();
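The body of the ConvertAtenNativeLayerNormOp pattern above is otherwise unchanged: the existing steps ("Step 4. Get var.", "Step 5. Get layernorm.") still build the mean, variance, and layer-norm values with linalg ops; the edit only changes what is returned, casting all three tensors to the converted result types instead of just the normalized value. As a rough eager-level sketch of the computation the lowering performs (assuming biased variance over the trailing normalized dimensions; the helper name is illustrative, not code from the commit):

import torch

def native_layer_norm_sketch(x, normalized_shape, weight, bias, eps):
    # Reduce over the trailing dims named by normalized_shape.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)  # biased variance
    layer_norm = (x - mean) / torch.sqrt(var + eps) * weight + bias
    return layer_norm, mean, var

x = torch.rand(2, 5, 2, 2, 3)
weight, bias = torch.rand(2, 2, 3), torch.rand(2, 2, 3)
out, mean, var = native_layer_norm_sketch(x, [2, 2, 3], weight, bias, 0.5)
golden = torch.ops.aten.native_layer_norm(x, [2, 2, 3], weight, bias, 0.5)[0]
assert torch.allclose(out, golden, atol=1e-5)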
@@ -477,6 +477,33 @@ class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
    return success();
  }
};

class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
  using OpRewritePattern<AtenLayerNormOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenLayerNormOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    auto input = op.input().getType().cast<BaseTensorType>();
    if (!input.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "input tensor should have known sizes.");
    int64_t inputRank = input.getSizes().size();
    Value normalizedShape = op.normalized_shape();
    SmallVector<Value> normalizedShapeSizesTorchInt;
    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
    std::vector<int64_t> meanVarSizes;
    for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
      meanVarSizes.push_back(input.getSizes()[i]);
    auto meanVarType = input.getWithSizesAndDtype(
        llvm::makeArrayRef(meanVarSizes), input.getDtype());
    auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
        loc, op.getType(), meanVarType, meanVarType, op.input(),
        op.normalized_shape(), op.weight(), op.bias(), op.eps());
    rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
    return success();
  }
};
} // namespace

namespace {

@@ -522,6 +549,9 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenAddcmulOp>();
    patterns.add<DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(context);
    target.addIllegalOp<AtenAddcdivOp>();
    target.addIllegalOp<AtenLayerNormOp>();
    patterns.add<DecomposeAtenLayerNormOp>(context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();
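The only shape computation in DecomposeAtenLayerNormOp above is the meanVarSizes loop, which the pattern uses to build the static type it assigns to the mean and variance results of the created aten.native_layer_norm. A line-for-line Python rendering of that loop (an illustration; the helper name and standalone form are assumptions, not code from the commit):

def mean_var_sizes(input_sizes, normalized_shape):
    # Mirrors the C++ loop: collect the input dims at indices
    # [len(normalized_shape), rank) to use as the sizes of the
    # mean/variance result types.
    rank = len(input_sizes)
    return [input_sizes[i] for i in range(len(normalized_shape), rank)]

# For the e2e test above: a rank-5 input normalized over three trailing dims.
print(mean_var_sizes([2, 5, 2, 2, 3], [2, 2, 3]))  # -> [2, 3]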
@@ -473,6 +473,8 @@ public:
      return visitBinaryScalarOp(scalarOp);
    } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
      return visitAtenNllLossForwardOp(nllForwardOp, operands);
    } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
      return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
    }

    // Otherwise, this is an unknown operation. Just mark all results as

@@ -609,6 +611,9 @@ private:
  ChangeResult
  visitAtenNllLossForwardOp(AtenNllLossForwardOp op,
                            ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitAtenNativeLayerNormOp(
      AtenNativeLayerNormOp op,
      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace


@@ -1605,6 +1610,45 @@ ChangeResult TypeAnalyzer::visitAtenAddCLikeOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenNativeLayerNormOp(
    AtenNativeLayerNormOp op,
    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();

  auto layerNormKnowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  auto meanKnowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  auto varKnowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());

  layerNormKnowledge.hasSizes = input.hasSizes;
  layerNormKnowledge.sizes = input.sizes;
  layerNormKnowledge.dtype = input.dtype;

  int64_t layerNormSize = input.sizes.size();
  Value normalizedShape = op.normalized_shape();
  SmallVector<Value> normalizedShapeSizesTorchInt;
  getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
  std::vector<int64_t> meanVarSizes;
  if (input.hasSizes) {
    for (int i = normalizedShapeSizesTorchInt.size(); i < layerNormSize; i++)
      meanVarSizes.push_back(input.sizes[i]);
  }
  meanKnowledge.hasSizes = input.hasSizes;
  meanKnowledge.sizes = meanVarSizes;
  meanKnowledge.dtype = input.dtype;
  varKnowledge.hasSizes = input.hasSizes;
  varKnowledge.sizes = meanVarSizes;
  varKnowledge.dtype = input.dtype;

  auto resultLattice =
      getLatticeElement(op.getResult(0)).join(layerNormKnowledge);
  resultLattice |= getLatticeElement(op.getResult(1)).join(meanKnowledge);
  resultLattice |= getLatticeElement(op.getResult(2)).join(varKnowledge);

  return resultLattice;
}

// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------
@@ -502,6 +502,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit(
        "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
    )
    emit(
        "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
    )
    emit(
        "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
    )