Add torch.aten.batch_norm Linalg lowering support

1. Added a simplified version of torch.aten.batch_norm that only handles
inference and assumes that weight, bias, running_mean, and running_var are
not None.

2. Removed the primitive-type check in verifyLinalgCompatibleTypes, since
we now have a proper type converter to handle Torch type conversions.
The check for RankedTensorType is kept because the type converter
doesn't guarantee that the converted builtin tensor type is ranked. A
separate verification pass that verifies the invariants expected by later
passes will need to be added before that check can be removed as well.
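
For reference, the inference-mode computation this lowering emits (the body of
the linalg.generic below) can be sketched in plain PyTorch and checked against
torch.nn.functional.batch_norm. This is an illustrative sketch, not part of the
change; batch_norm_inference is a hypothetical helper:

import torch

# Inference-only batch norm: ((input - mean) / sqrt(var + eps)) * weight + bias,
# with the rank-1 per-channel parameters broadcast along dim 1 (C) of the input.
def batch_norm_inference(x, weight, bias, running_mean, running_var, eps=1e-5):
    shape = [1, -1] + [1] * (x.dim() - 2)  # e.g. (1, C, 1) for an (N, C, L) input
    return ((x - running_mean.view(shape)) / torch.sqrt(running_var.view(shape) + eps)
            * weight.view(shape) + bias.view(shape))

x = torch.rand(10, 4, 3)
mean, var = torch.rand(4), torch.rand(4) + 1.0
weight, bias = torch.rand(4), torch.rand(4)
expected = torch.nn.functional.batch_norm(
    x, mean, var, weight=weight, bias=bias, training=False, eps=1e-5)
assert torch.allclose(batch_norm_inference(x, weight, bias, mean, var),
                      expected, atol=1e-6)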
pull/235/head
Yi Zhang 2021-06-16 21:07:04 +00:00 committed by Sean Silva
parent bbd749620e
commit 6dddb4d4fe
4 changed files with 193 additions and 16 deletions


@@ -0,0 +1,75 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch

from torch_mlir.torchscript.e2e_test.framework import TestUtils
from torch_mlir.torchscript.e2e_test.registry import register_test_case
from torch_mlir.torchscript.annotations import annotate_args, export

# ==============================================================================

class BatchNorm1DModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn1d = torch.nn.BatchNorm1d(4)
        self.bn1d.eval()
        self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
        self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
        self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
        self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))

    @export
    @annotate_args([
        None,
        ([10, 4, 3], torch.float32, True),
    ])
    def forward(self, x):
        return self.bn1d(x)

@register_test_case(module_factory=lambda: BatchNorm1DModule())
def BatchNorm1DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 4, 3))

# ==============================================================================

class BatchNorm2DModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn2d = torch.nn.BatchNorm2d(2)
        self.bn2d.eval()
        self.bn2d.running_mean = torch.tensor([0.5, 0.4])
        self.bn2d.running_var = torch.tensor([3.0, 2.0])
        self.bn2d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0]))
        self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4]))

    @export
    @annotate_args([
        None,
        ([10, 2, 3, 3], torch.float32, True),
    ])
    def forward(self, x):
        return self.bn2d(x)

@register_test_case(module_factory=lambda: BatchNorm2DModule())
def BatchNorm2DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(10, 2, 3, 3))

# ==============================================================================

class BatchNorm3DModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn3d = torch.nn.BatchNorm3d(5)
        self.bn3d.eval()
        self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])
        self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])
        self.bn3d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
        self.bn3d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))

    @export
    @annotate_args([
        None,
        ([2, 5, 3, 6, 4], torch.float32, True),
    ])
    def forward(self, x):
        return self.bn3d(x)

@register_test_case(module_factory=lambda: BatchNorm3DModule())
def BatchNorm3DModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 5, 3, 6, 4))
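
Note that each module above is put in eval() mode: the lowering asserts that
training is false, and in eval mode BatchNorm normalizes with the recorded
running statistics rather than the statistics of the current batch. A minimal
sketch of the difference (illustrative only, not part of the test suite):

import torch

bn = torch.nn.BatchNorm1d(4)
x = torch.rand(10, 4, 3)
bn.train()
y_train = bn(x)  # normalizes with the batch statistics (and updates running stats)
bn.eval()
y_eval = bn(x)   # normalizes with running_mean/running_var only
assert not torch.allclose(y_train, y_eval)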


@@ -25,6 +25,7 @@ from torch_mlir.torchscript.e2e_test.configs import (
 import basic
 import vision_models
 import mlp
+import batchnorm
 import quantized_models

 def _get_argparse():


@@ -64,7 +64,7 @@ class ValueReport:
     @property
     def failed(self):
-        return not torch.allclose(self.value, self.golden_value)
+        return not torch.allclose(self.value, self.golden_value, rtol=1e-04, atol=1e-08)

     def error_str(self):
         assert self.failed

@@ -42,23 +42,15 @@ using namespace mlir::NPCOMP::Torch;

 static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
                                                  PatternRewriter &rewriter) {
-  // For now, use a small allowlist of types we don't reject.
-  // The main culprit in practice is an unknown dtype
-  // when RefineTypes isn't smart enough to propagate it everywhere.
-  // For tensors, we consider the post-conversion tensor type (this pass is
-  // doing a type conversion).
+  // Check the value tensor is ranked as expected by Linalg.
+  // TODO: Remove this check but use a separate verification pass to verify the
+  // invariants expected by later passes.
   auto isValidLinalgType = [](Type type) {
-    if (auto tensor = type.dyn_cast<ValueTensorType>()) {
-      if (auto rankedTensor =
-              tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>()) {
-        if (BaseMemRefType::isValidElementType(rankedTensor.getElementType()))
-          return true;
-      }
-    }
-    if (type.isa<mlir::FloatType, IntegerType, IndexType>())
-      return true;
-    return false;
+    auto tensor = type.dyn_cast<ValueTensorType>();
+    return !tensor ||
+           tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
   };
   bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
                llvm::all_of(op->getResultTypes(), isValidLinalgType);
   if (!valid)
@@ -66,6 +58,113 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
   return success();
 }
+namespace {
+class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenBatchNormOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    AtenBatchNormOp::Adaptor adaptor(operands);
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
+    Value input = adaptor.input();
+    Value weight = adaptor.weight();
+    Value bias = adaptor.bias();
+    Value runningMean = adaptor.running_mean();
+    Value runningVar = adaptor.running_var();
+    Value training = adaptor.training();
+    Value eps = adaptor.eps();
+
+    // TODO: Handle the None cases for the optional parameters:
+    // weight, bias, running_mean, running_var.
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto weightType = weight.getType().cast<RankedTensorType>();
+    auto biasType = bias.getType().cast<RankedTensorType>();
+    auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
+    auto runningVarType = runningVar.getType().cast<RankedTensorType>();
+
+    auto inputRank = inputType.getRank();
+    if (inputRank <= 2)
+      return rewriter.notifyMatchFailure(
+          op, "input should have rank larger than 2");
+
+    if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
+        runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expect weight, bias, running_mean and running_var to be rank 1");
+    }
+
+    // TODO: Add support for training.
+    auto constFalse = rewriter.create<ConstantOp>(
+        loc, IntegerAttr::get(IntegerType::get(context, 1), 0));
+    auto trainingFalse =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, training, constFalse);
+    rewriter.create<AssertOp>(
+        loc, trainingFalse,
+        rewriter.getStringAttr("training is not supported for now"));
+
+    // num_features: C from an expected input of size (N, C, D, H, W, ...)
+    Value numFeatures = rewriter.create<memref::DimOp>(loc, input, 1);
+    auto contractingDim0EqualsNumFeatures = [&](Value v) {
+      auto dim0 = rewriter.create<memref::DimOp>(loc, v, 0);
+      auto dim0Equal =
+          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, numFeatures, dim0);
+      rewriter.create<AssertOp>(
+          loc, dim0Equal,
+          rewriter.getStringAttr(
+              "expect the size of dim 0 equal to the number of features"));
+    };
+    contractingDim0EqualsNumFeatures(weight);
+    contractingDim0EqualsNumFeatures(bias);
+    contractingDim0EqualsNumFeatures(runningMean);
+    contractingDim0EqualsNumFeatures(runningVar);
+
+    auto indexingMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context);
+    SmallVector<AffineMap> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        indexingMap,                                // weight
+        indexingMap,                                // bias
+        indexingMap,                                // runningMean
+        indexingMap,                                // runningVar
+        rewriter.getMultiDimIdentityMap(inputRank), // output
+    };
+    SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
+    Value batchNorm =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, input.getType(),
+                ValueRange{input, weight, bias, runningMean, runningVar}, input,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], weight = args[1], bias = args[2],
+                        mean = args[3], var = args[4];
+                  // ((input - mean) / sqrt(var + eps)) * weight + bias
+                  Value inputSubMean = b.create<SubFOp>(loc, input, mean);
+                  // The eps is always f64.
+                  Value truncatedEps =
+                      b.create<FPTruncOp>(loc, var.getType(), eps);
+                  Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
+                  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
+                  Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
+                  Value timesWeight = b.create<MulFOp>(loc, temp, weight);
+                  Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
+                  b.create<linalg::YieldOp>(loc, plusBias);
+                })
+            .getResult(0);
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, batchNorm);
+    return success();
+  }
+};
+} // namespace
 namespace {
 class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
 public:
@@ -344,6 +443,8 @@ public:
     patterns.add<ConvertAtenMmOp>(typeConverter, context);
     target.addIllegalOp<AtenLinearOp>();
     patterns.add<ConvertAtenLinearOp>(typeConverter, context);
+    target.addIllegalOp<AtenBatchNormOp>();
+    patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
     target.addIllegalOp<AtenTanhOp>();
     patterns.add<ConvertUnaryOp>(typeConverter, context);
     if (failed(applyPartialConversion(getOperation(), target,
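
As a reading aid for the indexing maps above: the input and output use the
identity map, while the four rank-1 operands are all read at affine dim d1,
the channel index. For a rank-3 input, the linalg.generic therefore computes
the equivalent of the following scalar loop (an illustrative sketch only;
batch_norm_generic is a hypothetical helper, not part of the commit):

import torch

def batch_norm_generic(x, weight, bias, mean, var, eps):
    # Mirrors the linalg.generic: the output at (n, c, l) reads the input at
    # the identity index (n, c, l) and every rank-1 operand at d1, i.e. at c.
    out = torch.empty_like(x)
    N, C, L = x.shape
    for n in range(N):
        for c in range(C):
            for l in range(L):
                out[n, c, l] = ((x[n, c, l] - mean[c]) / torch.sqrt(var[c] + eps)
                                * weight[c] + bias[c])
    return out

# Argument order matches the adaptor: input, weight, bias, running_mean, running_var.
out = batch_norm_generic(torch.rand(2, 4, 3), torch.rand(4), torch.rand(4),
                         torch.rand(4), torch.rand(4) + 1.0, eps=1e-5)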