Small fixes for slice edge cases (#476)

pull/487/head snapshot-20211215.146
Daniel Garvey 2021-12-15 15:54:41 -06:00 committed by GitHub
parent a6c3050dd0
commit 396ab35c9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 14 deletions

View File

@ -32,7 +32,6 @@ def SliceModule_basic(module, tu: TestUtils):
# ==============================================================================
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -43,8 +42,11 @@ class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[:8, :5, 8:]
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
result = x[:8, :5, 8:]
cat_tensor = torch.ones((6,4,1), dtype=torch.float32)
return torch.cat((result,cat_tensor), dim=2)
@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):
@ -90,7 +92,7 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
# ==============================================================================
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
class SliceEndSleStartModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -101,7 +103,10 @@ class SliceEndSleStartModule(torch.nn.Module):
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[:0, 4:3, :-7]
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
result = x[:, 4:3, :]
cat_tensor = torch.ones((6,1,7), dtype=torch.float32)
return torch.cat((result, cat_tensor), dim=1)
@register_test_case(module_factory=lambda: SliceEndSleStartModule())
@ -110,7 +115,7 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils):
# ==============================================================================
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
class SliceStartEqEndModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -121,7 +126,10 @@ class SliceStartEqEndModule(torch.nn.Module):
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return x[5:5, 3:3, -1:]
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
result = x[5:5, :, :]
cat_tensor = torch.ones((1,4,7), dtype=torch.float32)
return torch.cat((result, cat_tensor), dim=0)
@register_test_case(module_factory=lambda: SliceStartEqEndModule())

View File

@ -17,13 +17,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"IouOfModule_basic",
}
# Fails due to https://github.com/llvm/torch-mlir/issues/448
SIZE_ZERO_TENSOR_XFAILS = {
"SliceEndSleStartModule_basic",
"SliceStartEqEndModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
}
REFBACKEND_XFAIL_SET = set.union(COMMON_TORCH_MLIR_LOWERING_XFAILS, SIZE_ZERO_TENSOR_XFAILS)
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.

View File

@ -3047,9 +3047,17 @@ public:
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
};
if (op.start().getType().isa<OptionalType>() ||
op.end().getType().isa<OptionalType>())
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
// end >= start ? end : start
Value endSgeStart = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, end, start);
end = rewriter.create<SelectOp>(loc, endSgeStart, end, start);
int64_t step;
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
if (!op.step().getType().isa<Torch::NoneType>())