mirror of https://github.com/llvm/torch-mlir
Fix typo in `inputRank` check of `AtenBatchNormOp` (#1046)
The original conversion pattern for `AtenBatchNormOp` required that the input rank be greater than 2; however, the only expectation in the conversion pattern and in Pytorch is that the input rank is greater than 1, since the second dimension of the input must match the size of the `weight`, `bias`, `runningMean`, and `runningVar` inputs. This commit fixes the `inputRank` check.pull/1063/head
parent
3589134d31
commit
afdaa60dd4
|
@ -93,6 +93,7 @@ TOSA_PASS_SET = {
|
||||||
"TypePromotionAlphaWiderModule_basic",
|
"TypePromotionAlphaWiderModule_basic",
|
||||||
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
"Conv2dWithPaddingDilationStrideStaticModule_basic",
|
||||||
"BatchNorm1DModule_basic",
|
"BatchNorm1DModule_basic",
|
||||||
|
"BatchNorm1DWith2DInputModule_basic",
|
||||||
"BatchNorm2DModule_basic",
|
"BatchNorm2DModule_basic",
|
||||||
"BatchNorm3DModule_basic",
|
"BatchNorm3DModule_basic",
|
||||||
"FlattenStaticModule_basic",
|
"FlattenStaticModule_basic",
|
||||||
|
|
|
@ -1134,9 +1134,9 @@ public:
|
||||||
auto runningVarType = runningVar.getType().cast<RankedTensorType>();
|
auto runningVarType = runningVar.getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
auto inputRank = inputType.getRank();
|
auto inputRank = inputType.getRank();
|
||||||
if (inputRank <= 2)
|
if (inputRank < 2)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "input should have rank larger than 2");
|
op, "input should have rank larger than 1");
|
||||||
|
|
||||||
if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
|
if (weightType.getRank() != 1 || biasType.getRank() != 1 ||
|
||||||
runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
|
runningMeanType.getRank() != 1 || runningVarType.getRank() != 1) {
|
||||||
|
|
|
@ -37,6 +37,32 @@ def BatchNorm1DModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class BatchNorm1DWith2DInputModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.bn1d = torch.nn.BatchNorm1d(4)
|
||||||
|
self.bn1d.eval()
|
||||||
|
self.bn1d.running_mean = torch.tensor([0.5, 0.4, 0.3, 0.6])
|
||||||
|
self.bn1d.running_var = torch.tensor([3.0, 2.0, 4.0, 5.0])
|
||||||
|
self.bn1d.weight = torch.nn.Parameter(
|
||||||
|
torch.tensor([3.0, 2.0, 4.0, 5.0]))
|
||||||
|
self.bn1d.bias = torch.nn.Parameter(torch.tensor([0.5, 0.4, 0.3, 0.6]))
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([10, 4], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
return self.bn1d(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: BatchNorm1DWith2DInputModule())
|
||||||
|
def BatchNorm1DWith2DInputModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(10, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class BatchNorm2DModule(torch.nn.Module):
|
class BatchNorm2DModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue