mirror of https://github.com/llvm/torch-mlir
parent
a6c3050dd0
commit
396ab35c9d
|
@ -32,7 +32,6 @@ def SliceModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
|
||||
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -43,7 +42,10 @@ class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
|
|||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[:8, :5, 8:]
|
||||
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
|
||||
result = x[:8, :5, 8:]
|
||||
cat_tensor = torch.ones((6,4,1), dtype=torch.float32)
|
||||
return torch.cat((result,cat_tensor), dim=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
|
||||
|
@ -90,7 +92,7 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
|
||||
|
||||
class SliceEndSleStartModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -101,7 +103,10 @@ class SliceEndSleStartModule(torch.nn.Module):
|
|||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[:0, 4:3, :-7]
|
||||
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
|
||||
result = x[:, 4:3, :]
|
||||
cat_tensor = torch.ones((6,1,7), dtype=torch.float32)
|
||||
return torch.cat((result, cat_tensor), dim=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceEndSleStartModule())
|
||||
|
@ -110,7 +115,7 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
|
||||
|
||||
class SliceStartEqEndModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -121,7 +126,10 @@ class SliceStartEqEndModule(torch.nn.Module):
|
|||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[5:5, 3:3, -1:]
|
||||
# TODO: remove hacky cat tensor once refbackend supports 0 size dim
|
||||
result = x[5:5, :, :]
|
||||
cat_tensor = torch.ones((1,4,7), dtype=torch.float32)
|
||||
return torch.cat((result, cat_tensor), dim=0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceStartEqEndModule())
|
||||
|
|
|
@ -17,13 +17,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
"QuantizedMLP_basic",
|
||||
"IouOfModule_basic",
|
||||
}
|
||||
# Fails due to https://github.com/llvm/torch-mlir/issues/448
|
||||
SIZE_ZERO_TENSOR_XFAILS = {
|
||||
"SliceEndSleStartModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
"SliceOutOfUpperBoundIndexModule_basic",
|
||||
}
|
||||
REFBACKEND_XFAIL_SET = set.union(COMMON_TORCH_MLIR_LOWERING_XFAILS, SIZE_ZERO_TENSOR_XFAILS)
|
||||
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
|
|
|
@ -3047,9 +3047,17 @@ public:
|
|||
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
|
||||
};
|
||||
|
||||
if (op.start().getType().isa<OptionalType>() ||
|
||||
op.end().getType().isa<OptionalType>())
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||||
Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
|
||||
Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
|
||||
|
||||
// end >= start ? end : start
|
||||
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, end, start);
|
||||
end = rewriter.create<SelectOp>(loc, endSgeStart, end, start);
|
||||
|
||||
int64_t step;
|
||||
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
|
||||
if (!op.step().getType().isa<Torch::NoneType>())
|
||||
|
|
Loading…
Reference in New Issue