mirror of https://github.com/llvm/torch-mlir
parent a6c3050dd0
commit 396ab35c9d
@@ -32,7 +32,6 @@ def SliceModule_basic(module, tu: TestUtils):

 # ==============================================================================

-# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
 class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -43,7 +42,10 @@ class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
-        return x[:8, :5, 8:]
+        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
+        result = x[:8, :5, 8:]
+        cat_tensor = torch.ones((6, 4, 1), dtype=torch.float32)
+        return torch.cat((result, cat_tensor), dim=2)


 @register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
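The hunk above works around https://github.com/llvm/torch-mlir/issues/448: refbackend cannot return a tensor with a 0-size dim. A minimal sketch of the shapes involved, assuming the e2e harness feeds this module a 6x4x7 input (implied by the (6, 4, 1) cat_tensor):

    import torch

    x = torch.rand(6, 4, 7)
    result = x[:8, :5, 8:]    # start 8 exceeds dim size 7, so the slice is empty
    print(result.shape)       # torch.Size([6, 4, 0]) -- a 0-size dim
    padded = torch.cat((result, torch.ones((6, 4, 1))), dim=2)
    print(padded.shape)       # torch.Size([6, 4, 1]) -- safe for refbackend to return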
@@ -90,7 +92,7 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):

 # ==============================================================================

-# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
+
 class SliceEndSleStartModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -101,7 +103,10 @@ class SliceEndSleStartModule(torch.nn.Module):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
-        return x[:0, 4:3, :-7]
+        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
+        result = x[:, 4:3, :]
+        cat_tensor = torch.ones((6, 1, 7), dtype=torch.float32)
+        return torch.cat((result, cat_tensor), dim=1)


 @register_test_case(module_factory=lambda: SliceEndSleStartModule())
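The same workaround appears here for the end < start case: PyTorch returns an empty slice, not an error, when the end index precedes the start. A short sketch, again assuming a 6x4x7 input:

    import torch

    x = torch.rand(6, 4, 7)
    result = x[:, 4:3, :]    # end 3 < start 4 -> empty along dim 1
    print(result.shape)      # torch.Size([6, 0, 7])
    print(torch.cat((result, torch.ones((6, 1, 7))), dim=1).shape)  # torch.Size([6, 1, 7])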
@@ -110,7 +115,7 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils):

 # ==============================================================================

-# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
+
 class SliceStartEqEndModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -121,7 +126,10 @@ class SliceStartEqEndModule(torch.nn.Module):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
-        return x[5:5, 3:3, -1:]
+        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
+        result = x[5:5, :, :]
+        cat_tensor = torch.ones((1, 4, 7), dtype=torch.float32)
+        return torch.cat((result, cat_tensor), dim=0)


 @register_test_case(module_factory=lambda: SliceStartEqEndModule())
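And for start == end the slice is likewise empty; a sketch under the same 6x4x7 assumption:

    import torch

    x = torch.rand(6, 4, 7)
    result = x[5:5, :, :]    # start == end -> empty along dim 0
    print(result.shape)      # torch.Size([0, 4, 7])
    print(torch.cat((result, torch.ones((1, 4, 7))), dim=0).shape)  # torch.Size([1, 4, 7])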
@@ -17,13 +17,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
     "QuantizedMLP_basic",
     "IouOfModule_basic",
 }
-# Fails due to https://github.com/llvm/torch-mlir/issues/448
-SIZE_ZERO_TENSOR_XFAILS = {
-    "SliceEndSleStartModule_basic",
-    "SliceStartEqEndModule_basic",
-    "SliceOutOfUpperBoundIndexModule_basic",
-}
-REFBACKEND_XFAIL_SET = set.union(COMMON_TORCH_MLIR_LOWERING_XFAILS, SIZE_ZERO_TENSOR_XFAILS)
+REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS

 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
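One hedged observation on the replacement line: set.union returned a fresh set, while the new assignment makes REFBACKEND_XFAIL_SET an alias of COMMON_TORCH_MLIR_LOWERING_XFAILS, so in-place edits to either name show through both. A minimal sketch:

    common = {"QuantizedMLP_basic", "IouOfModule_basic"}
    alias = common                       # new style: both names bind one set object
    snapshot = set.union(common, set())  # old style: a brand-new set

    common.add("SomeNewXfail_basic")     # hypothetical test name, for illustration
    print("SomeNewXfail_basic" in alias)     # True  -- the alias sees the mutation
    print("SomeNewXfail_basic" in snapshot)  # False -- the union was a snapshot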
@@ -3047,9 +3047,17 @@ public:
       return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
     };

+    if (op.start().getType().isa<OptionalType>() ||
+        op.end().getType().isa<OptionalType>())
+      return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
     Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
     Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
+
+    // end >= start ? end : start
+    Value endSgeStart = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, end, start);
+    end = rewriter.create<SelectOp>(loc, endSgeStart, end, start);

     int64_t step;
     if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
       if (!op.step().getType().isa<Torch::NoneType>())
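The added cmpi/select pair is what makes the end-before-start tests above pass: after both indices are wrapped and clamped, end is raised to at least start so the slice length end - start can never go negative. A rough Python model of that index math (helper names here are illustrative, not the actual C++ functions):

    def lowered_slice_bounds(start, end, dim_size):
        # mirror what adjustStartOrEnd does: wrap negatives, clamp into [0, dim_size]
        def adjust(v, default):
            if v is None:
                return default
            if v < 0:
                v += dim_size
            return max(0, min(v, dim_size))

        start = adjust(start, 0)
        end = adjust(end, dim_size)
        # the new cmpi/select: end >= start ? end : start
        end = max(end, start)
        return start, end  # slice length = end - start >= 0

    print(lowered_slice_bounds(4, 3, 4))     # (4, 4) -> length 0, matches x[:, 4:3, :]
    print(lowered_slice_bounds(8, None, 7))  # (7, 7) -> length 0, matches x[..., 8:]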