mirror of https://github.com/llvm/torch-mlir
Refine static shapes for conv2d and maxpool2d
parent
4486de5ef3
commit
ccfdfd1b80
|
@ -104,3 +104,32 @@ class Conv2dWithPaddingDilationStrideModule(torch.nn.Module):
|
|||
def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):
|
||||
t = tu.rand(5, 2, 10, 20)
|
||||
module.forward(t)
|
||||
|
||||
|
||||
class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
self.conv = torch.nn.Conv2d(in_channels=2,
|
||||
out_channels=10,
|
||||
kernel_size=3,
|
||||
padding=3,
|
||||
stride=2,
|
||||
dilation=3,
|
||||
bias=False)
|
||||
self.train(False)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, 2, 10, 20], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Conv2dWithPaddingDilationStrideStaticModule())
|
||||
def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
|
||||
t = tu.rand(5, 2, 10, 20)
|
||||
module.forward(t)
|
||||
|
|
|
@ -824,16 +824,61 @@ ChangeResult TypeAnalyzer::visitAtenLinearOp(
|
|||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
static int64_t getOutputDimForOpWithKernel(int64_t dimIn, int64_t padding,
|
||||
int64_t dilation, int64_t kernelSize,
|
||||
int64_t stride) {
|
||||
return ((dimIn + 2 * padding - dilation * (kernelSize - 1) - 1) / stride) + 1;
|
||||
}
|
||||
|
||||
template <class Op>
|
||||
std::vector<int64_t>
|
||||
computeOpWithKernelOutputShape(Op op, const ValueKnowledge &ifm,
|
||||
int64_t features, int64_t kernelHeight,
|
||||
int64_t kernelWidth) {
|
||||
std::vector<int64_t> result = {ifm.sizes[0], // N
|
||||
features, // F
|
||||
kUnknownSize, kUnknownSize};
|
||||
|
||||
SmallVector<int64_t> padding;
|
||||
if (!matchPattern(op.padding(), m_TorchConstantIntList(padding)))
|
||||
return result;
|
||||
SmallVector<int64_t, 2> stride;
|
||||
if (!matchPattern(op.stride(), m_TorchConstantIntList(stride)))
|
||||
return result;
|
||||
SmallVector<int64_t, 2> dilation;
|
||||
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation)))
|
||||
return result;
|
||||
|
||||
int64_t ifmHeight = ifm.sizes[2];
|
||||
if (ifmHeight != kUnknownSize && kernelHeight != kUnknownSize)
|
||||
result[2] = getOutputDimForOpWithKernel(ifmHeight, padding[0], dilation[0],
|
||||
kernelHeight, stride[0]);
|
||||
|
||||
int64_t ifmWidth = ifm.sizes[3];
|
||||
if (ifmWidth != kUnknownSize && kernelWidth != kUnknownSize)
|
||||
result[3] = getOutputDimForOpWithKernel(ifmWidth, padding[1], dilation[1],
|
||||
kernelWidth, stride[1]);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenConv2dOp(
|
||||
AtenConv2dOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
auto &ifm = operands[0]->getValue();
|
||||
auto &weights = operands[1]->getValue();
|
||||
if (weights.hasSizes && ifm.hasSizes)
|
||||
knowledge.sizes = computeOpWithKernelOutputShape(
|
||||
op, ifm, weights.sizes[0], weights.sizes[2], weights.sizes[3]);
|
||||
else
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
|
||||
// Running some experiments in PyTorch, the bias doesn't seem to
|
||||
// contribute to the final element type.
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(op->getContext(),
|
||||
{&ifm, &weights});
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
|
@ -842,7 +887,15 @@ ChangeResult TypeAnalyzer::visitAtenMaxPool2dOp(
|
|||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
auto &ifm = operands[0]->getValue();
|
||||
SmallVector<int64_t, 2> kernelSize;
|
||||
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSize)))
|
||||
kernelSize = SmallVector<int64_t, 2>{kUnknownSize, kUnknownSize};
|
||||
if (ifm.hasSizes)
|
||||
knowledge.sizes = computeOpWithKernelOutputShape(
|
||||
op, ifm, ifm.sizes[1], kernelSize[0], kernelSize[1]);
|
||||
else
|
||||
knowledge.sizes.resize(4, kUnknownSize);
|
||||
knowledge.dtype = operands[0]->getValue().dtype;
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
|
|
@ -93,6 +93,18 @@ builtin.func @g(%arg0:!torch.vtensor<*,f32>, %arg1:!torch.vtensor<*,f32>, %arg2:
|
|||
return %3 :!torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @h
|
||||
// CHECK: torch.aten.conv2d{{.*}} -> !torch.vtensor<[1,16,62,62],f32>
|
||||
builtin.func @h(%arg0:!torch.vtensor<[1,8,64,64],f32>, %arg1:!torch.vtensor<[16,8,3,3],f32>, %arg2:!torch.vtensor<*,f32>) ->!torch.vtensor {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%3 = torch.aten.conv2d %arg0, %arg1, %arg2, %stride, %padding, %dilation, %int1 : !torch.vtensor<[1,8,64,64],f32>, !torch.vtensor<[16,8,3,3],f32>, !torch.vtensor<*,f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.int ->!torch.vtensor
|
||||
return %3 :!torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
|
@ -110,6 +122,70 @@ builtin.func @f(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor {
|
|||
return %27 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @g
|
||||
builtin.func @g(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%bool_false = torch.constant.bool false
|
||||
%krnl = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,32,32],f32>
|
||||
%27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
|
||||
return %27 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @h
|
||||
builtin.func @h(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%bool_false = torch.constant.bool false
|
||||
%krnl = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,62,62],f32>
|
||||
%27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
|
||||
return %27 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @i
|
||||
builtin.func @i(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%bool_false = torch.constant.bool false
|
||||
%krnl = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%padding = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,66,66],f32>
|
||||
%27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
|
||||
return %27 : !torch.vtensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @j
|
||||
builtin.func @j(%arg0: !torch.vtensor<[1,8,64,64],f32>) -> !torch.vtensor {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%bool_false = torch.constant.bool false
|
||||
%krnl = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
%dilation = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: torch.aten.max_pool2d{{.*}} -> !torch.vtensor<[1,8,32,32],f32>
|
||||
%27 = torch.aten.max_pool2d %arg0, %krnl, %stride, %padding, %dilation, %bool_false : !torch.vtensor<[1,8,64,64],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor
|
||||
return %27 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @f
|
||||
|
|
Loading…
Reference in New Issue