From 6dddb4d4fe206f7007a6cf693700d842a2094933 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 16 Jun 2021 21:07:04 +0000 Subject: [PATCH] Add torch.aten.batch_norm Linalg lowering support 1. Added a simplified version of torch.aten.batch_norm which only handles inference and assumes the weight, bias, running_mean, running_var are not None. 2. Removed the primitive types check in verifyLinalgCompatibleTypes check since now we have proper type converter to handle torch types conversion. The checks for RankedTensorType is kept because the type converter doesn't guarantee the converted builtin tensor type is ranked. A separate verification pass to verify the invariant expected by later passes will need to be added before those can be removed as well. --- .../e2e_testing/torchscript/batchnorm.py | 75 ++++++++++ .../pytorch/e2e_testing/torchscript/main.py | 1 + .../torchscript/e2e_test/reporting.py | 2 +- .../TorchToLinalg/TorchToLinalg.cpp | 131 ++++++++++++++++-- 4 files changed, 193 insertions(+), 16 deletions(-) create mode 100644 frontends/pytorch/e2e_testing/torchscript/batchnorm.py diff --git a/frontends/pytorch/e2e_testing/torchscript/batchnorm.py b/frontends/pytorch/e2e_testing/torchscript/batchnorm.py new file mode 100644 index 000000000..06215ba8c --- /dev/null +++ b/frontends/pytorch/e2e_testing/torchscript/batchnorm.py @@ -0,0 +1,75 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch + +from torch_mlir.torchscript.e2e_test.framework import TestUtils +from torch_mlir.torchscript.e2e_test.registry import register_test_case +from torch_mlir.torchscript.annotations import annotate_args, export + +# ============================================================================== +class BatchNorm1DModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn1d = torch.nn.BatchNorm1d(4) + self.bn1d.eval() + self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6]) + self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0]) + self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0])) + self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6])) + @export + @annotate_args([ + None, + ([10, 4, 3], torch.float32, True), + ]) + def forward(self, x): + return self.bn1d(x) + +@register_test_case(module_factory=lambda: BatchNorm1DModule()) +def BatchNorm1DModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 4, 3)) + +# ============================================================================== +class BatchNorm2DModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn2d = torch.nn.BatchNorm2d(2) + self.bn2d.eval() + self.bn2d.running_mean = torch.tensor([0.5, 0.4]) + self.bn2d.running_var = torch.tensor([3.0, 2.0]) + self.bn2d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0])) + self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4])) + @export + @annotate_args([ + None, + ([10, 2, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return self.bn2d(x) + +@register_test_case(module_factory=lambda: BatchNorm2DModule()) +def BatchNorm2DModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 2, 3, 3)) + +# ============================================================================== +class BatchNorm3DModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn3d = torch.nn.BatchNorm3d(5) + self.bn3d.eval() + self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]) + self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]) + self.bn3d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])) + self.bn3d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])) + @export + @annotate_args([ + None, + ([2, 5, 3, 6, 4], torch.float32, True), + ]) + def forward(self, x): + return self.bn3d(x) + +@register_test_case(module_factory=lambda: BatchNorm3DModule()) +def BatchNorm3DModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5, 3, 6, 4)) diff --git a/frontends/pytorch/e2e_testing/torchscript/main.py b/frontends/pytorch/e2e_testing/torchscript/main.py index 77c34e662..ca9e176d5 100644 --- a/frontends/pytorch/e2e_testing/torchscript/main.py +++ b/frontends/pytorch/e2e_testing/torchscript/main.py @@ -25,6 +25,7 @@ from torch_mlir.torchscript.e2e_test.configs import ( import basic import vision_models import mlp +import batchnorm import quantized_models def _get_argparse(): diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py index e739e3432..1301f33c7 100644 --- a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py +++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py @@ -64,7 +64,7 @@ class ValueReport: @property def failed(self): - return not torch.allclose(self.value, self.golden_value) + return not torch.allclose(self.value, self.golden_value, rtol=1e-04, atol=1e-08) def error_str(self): assert self.failed diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 49f5491a9..6da4931b3 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -42,23 +42,15 @@ using namespace mlir::NPCOMP::Torch; static LogicalResult verifyLinalgCompatibleTypes(Operation *op, PatternRewriter &rewriter) { - // For now, use a small allowlist of types we don't reject. - // The main culprit in practice is an unknown dtype - // when RefineTypes isn't smart enough to propagate it everywhere. - // For tensors, we consider the post-conversion tensor type (this pass is - // doing a type conversion). + // Check the value tensor is ranked as expected by Linalg. + // TODO: Remove this check but use a separate verification pass to verify the + // invariants expected by later passes. auto isValidLinalgType = [](Type type) { - if (auto tensor = type.dyn_cast()) { - if (auto rankedTensor = - tensor.toBuiltinTensor().dyn_cast_or_null()) { - if (BaseMemRefType::isValidElementType(rankedTensor.getElementType())) - return true; - } - } - if (type.isa()) - return true; - return false; + auto tensor = type.dyn_cast(); + return !tensor || + tensor.toBuiltinTensor().dyn_cast_or_null(); }; + bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) && llvm::all_of(op->getResultTypes(), isValidLinalgType); if (!valid) @@ -66,6 +58,113 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op, return success(); } +namespace { +class ConvertAtenBatchNormOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenBatchNormOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + AtenBatchNormOp::Adaptor adaptor(operands); + MLIRContext *context = op->getContext(); + Location loc = op->getLoc(); + Value input = adaptor.input(); + Value weight = adaptor.weight(); + Value bias = adaptor.bias(); + Value runningMean = adaptor.running_mean(); + Value runningVar = adaptor.running_var(); + Value training = adaptor.training(); + Value eps = adaptor.eps(); + + // TODO: Handle the None cases for the optional parameters: + // weight, bias, running_mean, running_var. + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + auto inputType = input.getType().cast(); + auto weightType = weight.getType().cast(); + auto biasType = bias.getType().cast(); + auto runningMeanType = runningMean.getType().cast(); + auto runningVarType = runningVar.getType().cast(); + + auto inputRank = inputType.getRank(); + if (inputRank <= 2) + return rewriter.notifyMatchFailure( + op, "input should have rank larger than 2"); + + if (weightType.getRank() != 1 || biasType.getRank() != 1 || + runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) { + return rewriter.notifyMatchFailure( + op, "expect weight, bias, running_mean and running_var to be rank 1"); + } + + // TODO: Add support for training. + auto constFalse = rewriter.create( + loc, IntegerAttr::get(IntegerType::get(context, 1), 0)); + auto trainingFalse = + rewriter.create(loc, CmpIPredicate::eq, training, constFalse); + rewriter.create( + loc, trainingFalse, + rewriter.getStringAttr("training is not supported for now")); + + // num_features – C from an expected input of size (N,C,D,H,W ...) + Value numFeatures = rewriter.create(loc, input, 1); + auto contractingDim0EqualsNumFeatures = [&](Value v) { + auto dim0 = rewriter.create(loc, v, 0); + auto dim0Equal = + rewriter.create(loc, CmpIPredicate::eq, numFeatures, dim0); + rewriter.create( + loc, dim0Equal, + rewriter.getStringAttr( + "expect the size of dim 0 equal to the number of features")); + }; + contractingDim0EqualsNumFeatures(weight); + contractingDim0EqualsNumFeatures(bias); + contractingDim0EqualsNumFeatures(runningMean); + contractingDim0EqualsNumFeatures(runningVar); + + auto indexingMap = AffineMap::get( + /*dimCount=*/inputRank, + /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context); + SmallVector indexingMaps = { + rewriter.getMultiDimIdentityMap(inputRank), // input + indexingMap, // weight + indexingMap, // bias + indexingMap, // runningMean + indexingMap, // runningVar + rewriter.getMultiDimIdentityMap(inputRank), // output + }; + SmallVector iteratorTypes(inputRank, "parallel"); + Value batchNorm = + rewriter + .create( + loc, input.getType(), + ValueRange{input, weight, bias, runningMean, runningVar}, input, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0], weight = args[1], bias = args[2], + mean = args[3], var = args[4]; + // ((input - mean) / sqrt(var + eps)) * weight + bias + Value inputSubMean = b.create(loc, input, mean); + // The eps is always f64. + Value truncatedEps = + b.create(loc, var.getType(), eps); + Value varPlusEps = b.create(loc, var, truncatedEps); + Value rSTD = b.create(loc, varPlusEps); + Value temp = b.create(loc, inputSubMean, rSTD); + Value timesWeight = b.create(loc, temp, weight); + Value plusBias = b.create(loc, timesWeight, bias); + b.create(loc, plusBias); + }) + .getResult(0); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, batchNorm); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenMmOp : public OpConversionPattern { public: @@ -344,6 +443,8 @@ public: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target,