Fix typo in `inputRank` check of `AtenBatchNormOp` (#1046)

The original conversion pattern for `AtenBatchNormOp` required that
the input rank be greater than 2; however, the only
expectation in the conversion pattern and in PyTorch is that the input
rank be greater than 1, since the second dimension of the input must
match the size of the `weight`, `bias`, `runningMean`, and
`runningVar` inputs. This commit fixes the `inputRank` check.
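
For illustration (not part of the commit), a minimal PyTorch snippet showing why
rank 2 must be accepted: `torch.nn.BatchNorm1d` takes both `(N, C)` and
`(N, C, L)` inputs and normalizes along the `C` dimension, so rank 2 is the
true lower bound.

import torch

# BatchNorm1d over C=4 features; weight, bias, and running stats all have size 4.
bn = torch.nn.BatchNorm1d(4)
bn.eval()

print(bn(torch.rand(10, 4)).shape)     # rank-2 (N, C) input -> torch.Size([10, 4])
print(bn(torch.rand(10, 4, 7)).shape)  # rank-3 (N, C, L) input -> torch.Size([10, 4, 7])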
Ramiro Leal-Cavazos 2022-07-15 16:35:59 +00:00 committed by GitHub
parent 3589134d31
commit afdaa60dd4
3 changed files with 29 additions and 2 deletions


@@ -93,6 +93,7 @@ TOSA_PASS_SET = {
     "TypePromotionAlphaWiderModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_basic",
     "BatchNorm1DModule_basic",
+    "BatchNorm1DWith2DInputModule_basic",
     "BatchNorm2DModule_basic",
     "BatchNorm3DModule_basic",
     "FlattenStaticModule_basic",


@@ -1134,9 +1134,9 @@ public:
     auto runningVarType = runningVar.getType().cast<RankedTensorType>();
     auto inputRank = inputType.getRank();
-    if (inputRank <= 2)
+    if (inputRank < 2)
       return rewriter.notifyMatchFailure(
-          op, "input should have rank larger than 2");
+          op, "input should have rank larger than 1");
     if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
         runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
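
For reference (also not part of the diff), the corrected `inputRank < 2` bound
mirrors the check PyTorch itself performs: rank-0 and rank-1 inputs are
rejected at the framework level. A quick sketch:

import torch

bn = torch.nn.BatchNorm1d(4)
bn.eval()
try:
    bn(torch.rand(4))  # rank-1 input: below the rank-2 minimum
except ValueError as e:
    print(e)  # roughly: "expected 2D or 3D input (got 1D input)"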


@@ -37,6 +37,32 @@ def BatchNorm1DModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class BatchNorm1DWith2DInputModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bn1d = torch.nn.BatchNorm1d(4)
+        self.bn1d.eval()
+        self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
+        self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
+        self.bn1d.weight = torch.nn.Parameter(
+            torch.tensor([3.0, 2.0, 4.0, 5.0]))
+        self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.bn1d(x)
+
+
+@register_test_case(module_factory=lambda: BatchNorm1DWith2DInputModule())
+def BatchNorm1DWith2DInputModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 4))
+
+
+# ==============================================================================
 class BatchNorm2DModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
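
As a sanity check on what the new test exercises (not part of the commit),
eval-mode batch norm on a `(N, C)` input is just the elementwise affine map
`(x - running_mean) / sqrt(running_var + eps) * weight + bias`, broadcast over
the batch dimension. A small sketch reproducing the test's parameters by hand:

import torch

bn = torch.nn.BatchNorm1d(4)
bn.eval()
bn.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
bn.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
bn.weight = torch.nn.Parameter(torch.tensor([3.0, 2.0, 4.0, 5.0]))
bn.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))

x = torch.rand(10, 4)
# Eval mode uses the running stats, then applies the learned scale and shift.
expected = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) \
    * bn.weight + bn.bias
with torch.no_grad():
    assert torch.allclose(bn(x), expected)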