diff --git a/e2e_testing/torchscript/batchnorm.py b/e2e_testing/torchscript/batchnorm.py
index 44298f6c5..39071b99f 100644
--- a/e2e_testing/torchscript/batchnorm.py
+++ b/e2e_testing/torchscript/batchnorm.py
@@ -87,8 +87,33 @@ class BatchNorm3DModule(torch.nn.Module):
 def BatchNorm3DModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 3, 6, 4))
 
+# ==============================================================================
+
+
+class NativeLayerNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5, 2, 2, 3], torch.float32, True),
+        ([2, 2, 3], torch.float32, True),
+        ([2, 2, 3], torch.float32, True),
+    ])
+    def forward(self, x, weight, bias):
+        list = [2, 2, 3]
+        return torch.ops.aten.native_layer_norm(
+            x, list, weight, bias, eps=0.5)[0]
+
+
+@register_test_case(module_factory=lambda: NativeLayerNormModule())
+def NativeLayerNormModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 2, 2, 3), tu.rand(2, 2, 3), tu.rand(2, 2, 3))
 
 # ==============================================================================
+
+
 class LayerNormModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -138,6 +163,8 @@ def LayerNormLastDimModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 5, 2, 2, 3))
 
 # ==============================================================================
+
+
 class LayerNormNormalizeOverAllDimsModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 0feac941a..145a26056 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1336,6 +1336,26 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
   let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps `,` $cudnn_enable attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `,` type($cudnn_enable) `->` type($result)";
 }
 
+def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    TorchIntListType:$normalized_shape,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps
+  );
+  let results = (outs
+    AnyTorchTensorType:$layer_norm,
+    AnyTorchTensorType:$mean,
+    AnyTorchTensorType:$variance
+  );
+  let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($layer_norm) `,` type($mean) `,` type($variance)";
+}
+
 def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 4b81a9147..46531cfde 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -693,11 +693,12 @@ public:
 // Step 4. Get var.
 // Step 5. Get layernorm.
 namespace {
-class ConvertAtenLayerNormOp : public OpConversionPattern<AtenLayerNormOp> {
+class ConvertAtenNativeLayerNormOp
+    : public OpConversionPattern<AtenNativeLayerNormOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenLayerNormOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenNativeLayerNormOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *context = op->getContext();
     Location loc = op->getLoc();
@@ -889,9 +890,14 @@ public:
                   b.create<linalg::YieldOp>(loc, result);
                 })
                 .getResult(0);
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, layerNorm);
+    Type layerNormResultType = getTypeConverter()->convertType(op.getType(0));
+    Type meanResultType = getTypeConverter()->convertType(op.getType(1));
+    Type varResultType = getTypeConverter()->convertType(op.getType(2));
+    Value layerNorm_ =
+        rewriter.create<tensor::CastOp>(loc, layerNormResultType, layerNorm);
+    Value mean_ = rewriter.create<tensor::CastOp>(loc, meanResultType, mean);
+    Value var_ = rewriter.create<tensor::CastOp>(loc, varResultType, var);
+    rewriter.replaceOp(op, {layerNorm_, mean_, var_});
     return success();
   }
 };
@@ -3659,8 +3665,8 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
-    target.addIllegalOp<AtenLayerNormOp>();
-    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
+    target.addIllegalOp<AtenNativeLayerNormOp>();
+    patterns.add<ConvertAtenNativeLayerNormOp>(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
     target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 0300fb4f0..7651fd955 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -477,6 +477,33 @@ class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLayerNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    auto input = op.input().getType().cast<BaseTensorType>();
+    if (!input.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "input tensor should have known sizes.");
+    int64_t inputRank = input.getSizes().size();
+    Value normalizedShape = op.normalized_shape();
+    SmallVector<Value> normalizedShapeSizesTorchInt;
+    getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
+    std::vector<int64_t> meanVarSizes;
+    for (int i = normalizedShapeSizesTorchInt.size(); i < inputRank; i++)
+      meanVarSizes.push_back(input.getSizes()[i]);
+    auto meanVarType = input.getWithSizesAndDtype(
+        llvm::makeArrayRef(meanVarSizes), input.getDtype());
+    auto nativeLayerNorm = rewriter.create<AtenNativeLayerNormOp>(
+        loc, op.getType(), meanVarType, meanVarType, op.input(),
+        op.normalized_shape(), op.weight(), op.bias(), op.eps());
+    rewriter.replaceOp(op, nativeLayerNorm.getResult(0));
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -522,6 +549,9 @@ class DecomposeComplexOpsPass
     target.addIllegalOp();
     patterns.add>(context);
     target.addIllegalOp();
+    target.addIllegalOp<AtenLayerNormOp>();
+    patterns.add<DecomposeAtenLayerNormOp>(context);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
       return signalPassFailure();
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 8491ff79a..5c57a240c 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -473,6 +473,8 @@ public:
       return visitBinaryScalarOp(scalarOp);
     } else if (auto nllForwardOp = dyn_cast<AtenNllLossForwardOp>(op)) {
       return visitAtenNllLossForwardOp(nllForwardOp, operands);
+    } else if (auto nativeLayerNormOp = dyn_cast<AtenNativeLayerNormOp>(op)) {
+      return visitAtenNativeLayerNormOp(nativeLayerNormOp, operands);
     }
 
     // Otherwise, this is an unknown operation. Just mark all results as
@@ -609,6 +611,9 @@ private:
   ChangeResult
   visitAtenNllLossForwardOp(AtenNllLossForwardOp op,
                             ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult visitAtenNativeLayerNormOp(
+      AtenNativeLayerNormOp op,
+      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace
 
@@ -1605,6 +1610,45 @@ ChangeResult TypeAnalyzer::visitAtenAddCLikeOp(
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenNativeLayerNormOp(
+    AtenNativeLayerNormOp op,
+    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+
+  auto layerNormKnowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto meanKnowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto varKnowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+
+  layerNormKnowledge.hasSizes = input.hasSizes;
+  layerNormKnowledge.sizes = input.sizes;
+  layerNormKnowledge.dtype = input.dtype;
+
+  int64_t layerNormSize = input.sizes.size();
+  Value normalizedShape = op.normalized_shape();
+  SmallVector<Value> normalizedShapeSizesTorchInt;
+  getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
+  std::vector<int64_t> meanVarSizes;
+  if (input.hasSizes) {
+    for (int i = normalizedShapeSizesTorchInt.size(); i < layerNormSize; i++)
+      meanVarSizes.push_back(input.sizes[i]);
+  }
+  meanKnowledge.hasSizes = input.hasSizes;
+  meanKnowledge.sizes = meanVarSizes;
+  meanKnowledge.dtype = input.dtype;
+  varKnowledge.hasSizes = input.hasSizes;
+  varKnowledge.sizes = meanVarSizes;
+  varKnowledge.dtype = input.dtype;
+
+  auto resultLattice =
+      getLatticeElement(op.getResult(0)).join(layerNormKnowledge);
+  resultLattice |= getLatticeElement(op.getResult(1)).join(meanKnowledge);
+  resultLattice |= getLatticeElement(op.getResult(2)).join(varKnowledge);
+
+  return resultLattice;
+}
 // -----------------------------------------------------------------------------
 // Transforms.
 // -----------------------------------------------------------------------------
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 98ba5c72b..163aefc6b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -502,6 +502,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )
+    emit(
+        "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
+    )
     emit(
         "aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)"
    )