Refine static shapes for conv2d and maxpool2d

pull/488/head
Liam Fitzpatrick 2021-12-27 15:57:19 +00:00 committed by Ramiro Leal-Cavazos
parent 4486de5ef3
commit ccfdfd1b80
3 changed files with 162 additions and 4 deletions

@@ -104,3 +104,32 @@ class Conv2dWithPaddingDilationStrideModule(torch.nn.Module):
def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):
    t = tu.rand(5, 2, 10, 20)
    module.forward(t)


class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.conv = torch.nn.Conv2d(in_channels=2,
                                    out_channels=10,
                                    kernel_size=3,
                                    padding=3,
                                    stride=2,
                                    dilation=3,
                                    bias=False)
        self.train(False)

    @export
    @annotate_args([
        None,
        ([5, 2, 10, 20], torch.float32, True),
    ])
    def forward(self, x):
        return self.conv(x)


@register_test_case(
    module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule())
def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
    t = tu.rand(5, 2, 10, 20)
    module.forward(t)
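
For reference, a quick eager-mode sanity check (illustrative, not part of the patch) of the static shape this module implies; out_dim is a hypothetical helper restating PyTorch's documented Conv2d output-size rule:

import torch

def out_dim(dim_in, padding, dilation, kernel, stride):
    # PyTorch's documented Conv2d output-size rule (floor division).
    return (dim_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

conv = torch.nn.Conv2d(in_channels=2, out_channels=10, kernel_size=3,
                       padding=3, stride=2, dilation=3, bias=False)
out = conv(torch.rand(5, 2, 10, 20))
assert out.shape == (5, 10, out_dim(10, 3, 3, 3, 2), out_dim(20, 3, 3, 3, 2))
print(out.shape)  # torch.Size([5, 10, 5, 10])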

@@ -824,16 +824,61 @@ ChangeResult TypeAnalyzer::visitAtenLinearOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

static int64_t getOutputDimForOpWithKernel(int64_t dimIn, int64_t padding,
                                           int64_t dilation, int64_t kernelSize,
                                           int64_t stride) {
  return ((dimIn + 2 * padding - dilation * (kernelSize - 1) - 1) / stride) + 1;
}

template <class Op>
std::vector<int64_t>
computeOpWithKernelOutputShape(Op op, const ValueKnowledge &ifm,
                               int64_t features, int64_t kernelHeight,
                               int64_t kernelWidth) {
  std::vector<int64_t> result = {ifm.sizes[0], // N
                                 features,     // F
                                 kUnknownSize, kUnknownSize};

  SmallVector<int64_t> padding;
  if (!matchPattern(op.padding(), m_TorchConstantIntList(padding)))
    return result;
  SmallVector<int64_t, 2> stride;
  if (!matchPattern(op.stride(), m_TorchConstantIntList(stride)))
    return result;
  SmallVector<int64_t, 2> dilation;
  if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))
    return result;

  int64_t ifmHeight = ifm.sizes[2];
  if (ifmHeight != kUnknownSize && kernelHeight != kUnknownSize)
    result[2] = getOutputDimForOpWithKernel(ifmHeight, padding[0], dilation[0],
                                            kernelHeight, stride[0]);

  int64_t ifmWidth = ifm.sizes[3];
  if (ifmWidth != kUnknownSize && kernelWidth != kUnknownSize)
    result[3] = getOutputDimForOpWithKernel(ifmWidth, padding[1], dilation[1],
                                            kernelWidth, stride[1]);

  return result;
}
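
As a quick numeric cross-check (illustrative only, in Python), the same formula reproduces the static shapes asserted by the lit tests further below:

def out_dim(dim_in, padding, dilation, kernel, stride):
    # Same rule as getOutputDimForOpWithKernel above.
    return (dim_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

assert out_dim(64, 0, 1, 3, 1) == 62  # conv2d test @h: [1,16,62,62]
assert out_dim(64, 0, 1, 2, 2) == 32  # max_pool2d tests @g/@j: [1,8,32,32]
assert out_dim(64, 2, 1, 3, 1) == 66  # max_pool2d test @i: [1,8,66,66]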

ChangeResult TypeAnalyzer::visitAtenConv2dOp(
    AtenConv2dOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  knowledge.hasSizes = true;
  auto &ifm = operands[0]->getValue();
  auto &weights = operands[1]->getValue();
  if (weights.hasSizes && ifm.hasSizes)
    knowledge.sizes = computeOpWithKernelOutputShape(
        op, ifm, weights.sizes[0], weights.sizes[2], weights.sizes[3]);
  else
    knowledge.sizes.resize(4, kUnknownSize);
  // Running some experiments in PyTorch, the bias doesn't seem to
  // contribute to the final element type.
  knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(op->getContext(),
                                                             {&ifm, &weights});
  return getLatticeElement(op->getResult(0)).join(knowledge);
}
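
The inferred conv2d shape can also be cross-checked in eager PyTorch; a minimal sketch mirroring lit test @h below (tensor values are arbitrary):

import torch
import torch.nn.functional as F

x = torch.rand(1, 8, 64, 64)
w = torch.rand(16, 8, 3, 3)
# Defaults stride=1, padding=0, dilation=1 match test @h.
print(F.conv2d(x, w).shape)  # torch.Size([1, 16, 62, 62])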
@@ -842,7 +887,15 @@ ChangeResult TypeAnalyzer::visitAtenMaxPool2dOp(
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  knowledge.hasSizes = true;
  auto &ifm = operands[0]->getValue();
  SmallVector<int64_t, 2> kernelSize;
  if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
    kernelSize = SmallVector<int64_t, 2>{kUnknownSize, kUnknownSize};
  if (ifm.hasSizes)
    knowledge.sizes = computeOpWithKernelOutputShape(
        op, ifm, ifm.sizes[1], kernelSize[0], kernelSize[1]);
  else
    knowledge.sizes.resize(4, kUnknownSize);
  knowledge.dtype = operands[0]->getValue().dtype;
  return getLatticeElement(op->getResult(0)).join(knowledge);
}
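
Similarly for max_pool2d, a sketch mirroring lit tests @g and @h below. Note that test @i uses padding larger than half the kernel size, which eager PyTorch rejects at runtime as far as I recall, while the analysis above still applies the formula and infers [1,8,66,66]:

import torch
import torch.nn.functional as F

x = torch.rand(1, 8, 64, 64)
print(F.max_pool2d(x, kernel_size=2, stride=2).shape)  # [1, 8, 32, 32] (@g, @j)
print(F.max_pool2d(x, kernel_size=3, stride=1).shape)  # [1, 8, 62, 62] (@h)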

@@ -93,6 +93,18 @@ builtin.func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:
  return %3 : !torch.vtensor
}

// CHECK-LABEL: func @h
// CHECK: torch.aten.conv2d{{.*}} -> !torch.vtensor<[1,16,62,62],f32>
builtin.func @h(%arg0: !torch.vtensor<[1,8,64,64],f32>, %arg1: !torch.vtensor<[16,8,3,3],f32>, %arg2: !torch.vtensor<*,f32>) -> !torch.vtensor {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %3 = torch.aten.conv2d %arg0, %arg1, %arg2, %stride, %padding, %dilation, %int1 : !torch.vtensor<[1,8,64,64],f32>, !torch.vtensor<[16,8,3,3],f32>, !torch.vtensor<*,f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.int -> !torch.vtensor
  return %3 : !torch.vtensor
}
// -----
// CHECK-LABEL: func @f
@@ -110,6 +122,70 @@ builtin.func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
  return %27 : !torch.vtensor
}

// CHECK-LABEL: func @g
builtin.func @g(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %bool_false = torch.constant.bool false
  %krnl = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,32,32],f32>
  %27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
  return %27 : !torch.vtensor
}

// CHECK-LABEL: func @h
builtin.func @h(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %bool_false = torch.constant.bool false
  %krnl = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,62,62],f32>
  %27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
  return %27 : !torch.vtensor
}

// CHECK-LABEL: func @i
builtin.func @i(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %bool_false = torch.constant.bool false
  %krnl = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %padding = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,66,66],f32>
  %27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
  return %27 : !torch.vtensor
}

// CHECK-LABEL: func @j
builtin.func @j(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %bool_false = torch.constant.bool false
  %krnl = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  %dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
  // CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,32,32],f32>
  %27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
  return %27 : !torch.vtensor
}

// -----
// CHECK-LABEL: func @f