diff --git a/e2e_testing/torchscript/slice_like.py b/e2e_testing/torchscript/slice_like.py
index ecc3cdd64..7a968dbe5 100644
--- a/e2e_testing/torchscript/slice_like.py
+++ b/e2e_testing/torchscript/slice_like.py
@@ -32,7 +32,6 @@ def SliceModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
 class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -43,8 +42,11 @@ class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
-        return x[:8, :5, 8:]
-
+        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
+        result = x[:8, :5, 8:]
+        cat_tensor = torch.ones((6,4,1), dtype=torch.float32)
+        return torch.cat((result,cat_tensor), dim=2)
+
 
 @register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
 def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):
@@ -90,7 +92,7 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
+
 class SliceEndSleStartModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -101,7 +103,10 @@ class SliceEndSleStartModule(torch.nn.Module):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
-        return x[:0, 4:3, :-7]
+        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
+        result = x[:, 4:3, :]
+        cat_tensor = torch.ones((6,1,7), dtype=torch.float32)
+        return torch.cat((result, cat_tensor), dim=1)
 
 
 @register_test_case(module_factory=lambda: SliceEndSleStartModule())
@@ -110,7 +115,7 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
+
class SliceStartEqEndModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -121,7 +126,10 @@ class SliceStartEqEndModule(torch.nn.Module):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
-        return x[5:5, 3:3, -1:]
+        # TODO: remove hacky cat tensor once refbackend supports 0 size dim
+        result = x[5:5, :, :]
+        cat_tensor = torch.ones((1,4,7), dtype=torch.float32)
+        return torch.cat((result, cat_tensor), dim=0)
 
 
 @register_test_case(module_factory=lambda: SliceStartEqEndModule())
diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 6ce28c876..a21150fd4 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -17,13 +17,7 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
     "QuantizedMLP_basic",
     "IouOfModule_basic",
 }
-# Fails due to https://github.com/llvm/torch-mlir/issues/448
-SIZE_ZERO_TENSOR_XFAILS = {
-    "SliceEndSleStartModule_basic",
-    "SliceStartEqEndModule_basic",
-    "SliceOutOfUpperBoundIndexModule_basic",
-}
-REFBACKEND_XFAIL_SET = set.union(COMMON_TORCH_MLIR_LOWERING_XFAILS, SIZE_ZERO_TENSOR_XFAILS)
+REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
 
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 4fe360077..92ccd5ce3 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -3047,9 +3047,17 @@ public:
       return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
     };
 
+    if (op.start().getType().isa<OptionalType>() ||
+        op.end().getType().isa<OptionalType>())
+      return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
     Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
     Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
 
+    // end >= start ? end : start
+    Value endSgeStart = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, end, start);
+    end = rewriter.create<SelectOp>(loc, endSgeStart, end, start);
+
     int64_t step;
     if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
       if (!op.step().getType().isa<Torch::NoneType>())
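A note on why the test changes above take this shape: PyTorch accepts slices whose start is past the end or out of bounds and simply produces a dimension of extent zero, but the refbackend cannot yet return tensors with a 0-size dim (https://github.com/llvm/torch-mlir/issues/448), so the tests concatenate a ones tensor along the emptied dimension to keep every returned extent nonzero. A minimal sketch of that behavior in plain PyTorch, assuming the (6, 4, 7) input shape implied by the cat tensors in these tests:

    import torch

    x = torch.rand(6, 4, 7)

    # Empty or out-of-bound slices are legal in PyTorch and yield 0-size dims:
    #   x[:8, :5, 8:]  -> (6, 4, 0)   start beyond the upper bound
    #   x[:, 4:3, :]   -> (6, 0, 7)   end < start
    #   x[5:5, :, :]   -> (0, 4, 7)   start == end
    result = x[:, 4:3, :]
    assert result.shape == (6, 0, 7)

    # The workaround used by the tests: cat a ones tensor along the empty dim
    # so the value handed back to the refbackend has no 0-size extent.
    out = torch.cat((result, torch.ones((6, 1, 7), dtype=torch.float32)), dim=1)
    assert out.shape == (6, 1, 7)

The C++ hunk mirrors the same semantics in the lowering: end is clamped to max(end, start), so the computed slice size can be zero but never negative.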