mirror of https://github.com/llvm/torch-mlir
Add torch.aten.batch_norm Linalg lowering support

1. Added a simplified version of torch.aten.batch_norm that only handles
inference and assumes the weight, bias, running_mean, and running_var
arguments are not None.
2. Removed the primitive-types check from verifyLinalgCompatibleTypes, since we
now have a proper type converter to handle the conversion of torch types. The
check for RankedTensorType is kept because the type converter does not
guarantee that the converted builtin tensor type is ranked. A separate
verification pass that verifies the invariants expected by later passes will
need to be added before that check can be removed as well.

pull/235/head

parent bbd749620e
commit 6dddb4d4fe
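For background on point 1, a minimal PyTorch sketch (illustrative only, not
part of this commit): with training=False, aten.batch_norm is a pure
elementwise transform of the input using the supplied running statistics,
while training=True computes statistics from the batch itself (and updates
the running stats in place). The lowering below handles only the former case.

import torch
import torch.nn.functional as F

x = torch.rand(10, 4, 3)
mean = torch.zeros(4)
var = torch.ones(4)

# Inference mode: a pure elementwise transform using the given stats.
y_eval = F.batch_norm(x, mean, var, training=False)

# Training mode: the batch's own statistics are used instead (and the
# running stats are updated in place), so the result generally differs.
y_train = F.batch_norm(x, mean.clone(), var.clone(), training=True)

assert not torch.allclose(y_eval, y_train)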
@@ -0,0 +1,75 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+
+from torch_mlir.torchscript.e2e_test.framework import TestUtils
+from torch_mlir.torchscript.e2e_test.registry import register_test_case
+from torch_mlir.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+class BatchNorm1DModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bn1d = torch.nn.BatchNorm1d(4)
+        self.bn1d.eval()
+        self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
+        self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
+        self.bn1d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
+        self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 4, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.bn1d(x)
+
+@register_test_case(module_factory=lambda: BatchNorm1DModule())
+def BatchNorm1DModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 4, 3))
+
+# ==============================================================================
+
+class BatchNorm2DModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bn2d = torch.nn.BatchNorm2d(2)
+        self.bn2d.eval()
+        self.bn2d.running_mean = torch.tensor([0.5, 0.4])
+        self.bn2d.running_var = torch.tensor([3.0, 2.0])
+        self.bn2d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0]))
+        self.bn2d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4]))
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 2, 3, 3], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.bn2d(x)
+
+@register_test_case(module_factory=lambda: BatchNorm2DModule())
+def BatchNorm2DModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 2, 3, 3))
+
+# ==============================================================================
+
+class BatchNorm3DModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bn3d = torch.nn.BatchNorm3d(5)
+        self.bn3d.eval()
+        self.bn3d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4])
+        self.bn3d.running_var = torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0])
+        self.bn3d.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 2.0, 3.0]))
+        self.bn3d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.2, 0.4]))
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5, 3, 6, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.bn3d(x)
+
+@register_test_case(module_factory=lambda: BatchNorm3DModule())
+def BatchNorm3DModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5, 3, 6, 4))
@@ -25,6 +25,7 @@ from torch_mlir.torchscript.e2e_test.configs import (
 import basic
 import vision_models
 import mlp
+import batchnorm
 import quantized_models

 def _get_argparse():
@@ -64,7 +64,7 @@ class ValueReport:

     @property
     def failed(self):
-        return not torch.allclose(self.value, self.golden_value)
+        return not torch.allclose(self.value, self.golden_value, rtol=1e-04, atol=1e-08)

     def error_str(self):
         assert self.failed
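For context on the relaxed comparison (an illustrative sketch, not from the
commit): torch.allclose passes iff |value - golden| <= atol + rtol * |golden|
holds elementwise. The explicit rtol=1e-04 is ten times looser than allclose's
default of 1e-05, while atol=1e-08 matches the default.

import torch

golden = torch.tensor([100.0, 0.5])
value = golden + torch.tensor([5e-3, 1e-6])

# allclose passes iff |value - golden| <= atol + rtol * |golden| elementwise.
assert not torch.allclose(value, golden)                      # default rtol=1e-05: fails
assert torch.allclose(value, golden, rtol=1e-04, atol=1e-08)  # relaxed check: passes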
@@ -42,23 +42,15 @@ using namespace mlir::NPCOMP::Torch;

 static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
                                                  PatternRewriter &rewriter) {
-  // For now, use a small allowlist of types we don't reject.
-  // The main culprit in practice is an unknown dtype
-  // when RefineTypes isn't smart enough to propagate it everywhere.
-  // For tensors, we consider the post-conversion tensor type (this pass is
-  // doing a type conversion).
+  // Check the value tensor is ranked as expected by Linalg.
+  // TODO: Remove this check but use a separate verification pass to verify the
+  // invariants expected by later passes.
   auto isValidLinalgType = [](Type type) {
-    if (auto tensor = type.dyn_cast<ValueTensorType>()) {
-      if (auto rankedTensor =
-              tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>()) {
-        if (BaseMemRefType::isValidElementType(rankedTensor.getElementType()))
-          return true;
-      }
-    }
-    if (type.isa<mlir::FloatType, IntegerType, IndexType>())
-      return true;
-    return false;
+    auto tensor = type.dyn_cast<ValueTensorType>();
+    return !tensor ||
+           tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
   };
   bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
               llvm::all_of(op->getResultTypes(), isValidLinalgType);
   if (!valid)
@@ -66,6 +58,113 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
   return success();
 }

+namespace {
+class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenBatchNormOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    AtenBatchNormOp::Adaptor adaptor(operands);
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
+    Value input = adaptor.input();
+    Value weight = adaptor.weight();
+    Value bias = adaptor.bias();
+    Value runningMean = adaptor.running_mean();
+    Value runningVar = adaptor.running_var();
+    Value training = adaptor.training();
+    Value eps = adaptor.eps();
+
+    // TODO: Handle the None cases for the optional parameters:
+    // weight, bias, running_mean, running_var.
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    auto inputType = input.getType().cast<RankedTensorType>();
+    auto weightType = weight.getType().cast<RankedTensorType>();
+    auto biasType = bias.getType().cast<RankedTensorType>();
+    auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
+    auto runningVarType = runningVar.getType().cast<RankedTensorType>();
+
+    auto inputRank = inputType.getRank();
+    if (inputRank <= 2)
+      return rewriter.notifyMatchFailure(
+          op, "input should have rank larger than 2");
+
+    if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
+        runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "expect weight, bias, running_mean and running_var to be rank 1");
+    }
+
+    // TODO: Add support for training.
+    auto constFalse = rewriter.create<ConstantOp>(
+        loc, IntegerAttr::get(IntegerType::get(context, 1), 0));
+    auto trainingFalse =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, training, constFalse);
+    rewriter.create<AssertOp>(
+        loc, trainingFalse,
+        rewriter.getStringAttr("training is not supported for now"));
+
+    // num_features – C from an expected input of size (N,C,D,H,W ...)
+    Value numFeatures = rewriter.create<memref::DimOp>(loc, input, 1);
+    auto contractingDim0EqualsNumFeatures = [&](Value v) {
+      auto dim0 = rewriter.create<memref::DimOp>(loc, v, 0);
+      auto dim0Equal =
+          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, numFeatures, dim0);
+      rewriter.create<AssertOp>(
+          loc, dim0Equal,
+          rewriter.getStringAttr(
+              "expect the size of dim 0 equal to the number of features"));
+    };
+    contractingDim0EqualsNumFeatures(weight);
+    contractingDim0EqualsNumFeatures(bias);
+    contractingDim0EqualsNumFeatures(runningMean);
+    contractingDim0EqualsNumFeatures(runningVar);
+
+    auto indexingMap = AffineMap::get(
+        /*dimCount=*/inputRank,
+        /*symbolCount=*/0, rewriter.getAffineDimExpr(1), context);
+    SmallVector<AffineMap> indexingMaps = {
+        rewriter.getMultiDimIdentityMap(inputRank), // input
+        indexingMap,                                // weight
+        indexingMap,                                // bias
+        indexingMap,                                // runningMean
+        indexingMap,                                // runningVar
+        rewriter.getMultiDimIdentityMap(inputRank), // output
+    };
+    SmallVector<StringRef> iteratorTypes(inputRank, "parallel");
+    Value batchNorm =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, input.getType(),
+                ValueRange{input, weight, bias, runningMean, runningVar}, input,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value input = args[0], weight = args[1], bias = args[2],
+                        mean = args[3], var = args[4];
+                  // ((input - mean) / sqrt(var + eps)) * weight + bias
+                  Value inputSubMean = b.create<SubFOp>(loc, input, mean);
+                  // The eps is always f64.
+                  Value truncatedEps =
+                      b.create<FPTruncOp>(loc, var.getType(), eps);
+                  Value varPlusEps = b.create<AddFOp>(loc, var, truncatedEps);
+                  Value rSTD = b.create<math::RsqrtOp>(loc, varPlusEps);
+                  Value temp = b.create<MulFOp>(loc, inputSubMean, rSTD);
+                  Value timesWeight = b.create<MulFOp>(loc, temp, weight);
+                  Value plusBias = b.create<AddFOp>(loc, timesWeight, bias);
+                  b.create<linalg::YieldOp>(loc, plusBias);
+                })
+            .getResult(0);
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, batchNorm);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
 public:
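A rough PyTorch equivalent of the linalg.generic region above (a hedged
sketch, not part of the diff; the loose tolerances in the final assert are my
own choice to absorb float32 op-ordering differences): the identity maps index
input and output at (d0, d1, ...), while the four rank-1 operands are indexed
by d1 alone, i.e. broadcast along every dimension except the feature dimension.

import torch

def batch_norm_inference(x, weight, bias, mean, var, eps=1e-5):
    # Each rank-1 operand is read at index d1 only; view() reproduces that
    # broadcast, e.g. shape (1, C, 1, 1) against an NCHW input.
    shape = [1, -1] + [1] * (x.dim() - 2)
    rstd = torch.rsqrt(var.view(shape) + eps)  # rsqrt(var + eps), as in the region
    return (x - mean.view(shape)) * rstd * weight.view(shape) + bias.view(shape)

# Values taken from BatchNorm2DModule above.
x = torch.rand(10, 2, 3, 3)
weight, bias = torch.tensor([3.0, 2.0]), torch.tensor([0.5, 0.4])
mean, var = torch.tensor([0.5, 0.4]), torch.tensor([3.0, 2.0])

reference = torch.nn.functional.batch_norm(x, mean, var, weight, bias, training=False)
assert torch.allclose(batch_norm_inference(x, weight, bias, mean, var), reference,
                      rtol=1e-4, atol=1e-5)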
@@ -344,6 +443,8 @@ public:
   patterns.add<ConvertAtenMmOp>(typeConverter, context);
   target.addIllegalOp<AtenLinearOp>();
   patterns.add<ConvertAtenLinearOp>(typeConverter, context);
+  target.addIllegalOp<AtenBatchNormOp>();
+  patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
   target.addIllegalOp<AtenTanhOp>();
   patterns.add<ConvertUnaryOp>(typeConverter, context);
   if (failed(applyPartialConversion(getOperation(), target,