mirror of https://github.com/llvm/torch-mlir
parent
46a2189a41
commit
a52aded0b9
|
@ -100,6 +100,8 @@ class MmTanhModule(torch.nn.Module):
|
|||
def matmul(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MmTanhModule())
|
||||
def MmTanhModule_basic(module, tu: TestUtils):
|
||||
|
@ -192,6 +194,8 @@ class AdaptiveAvgPool2dModule(torch.nn.Module):
|
|||
def AdaptiveAvgPool2dModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 3, 8, 9))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class FlattenStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -211,6 +215,8 @@ class FlattenStaticModule(torch.nn.Module):
|
|||
def FlattenStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class FlattenRank0Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -230,6 +236,8 @@ class FlattenRank0Module(torch.nn.Module):
|
|||
def FlattenRank0Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor(4.0))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class FlattenDynamicModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -249,6 +257,8 @@ class FlattenDynamicModule(torch.nn.Module):
|
|||
def FlattenDynamicModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 3, 8, 9, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MaxPool2dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -266,6 +276,8 @@ class MaxPool2dModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return self.mp2d(x)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MaxPool2dModule())
|
||||
def MaxPool2dModule_basic(module, tu: TestUtils):
|
||||
|
@ -284,6 +296,8 @@ class TransposeIntModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.transpose(x, 0, 1)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TransposeIntModule())
|
||||
def TransposeIntModule_basic(module, tu: TestUtils):
|
||||
|
@ -305,6 +319,8 @@ class PermuteModule(torch.nn.Module):
|
|||
def PermuteModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TransposeIntNegDimsModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -317,6 +333,8 @@ class TransposeIntNegDimsModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.transpose(x, -1, -2)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TransposeIntNegDimsModule())
|
||||
def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
|
||||
|
@ -335,6 +353,8 @@ class PermuteNegativeIndexModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return x.permute(0, -1, 1)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
@register_test_case(module_factory=lambda: PermuteNegativeIndexModule())
|
||||
def PermuteNegativeIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 2))
|
||||
|
@ -357,6 +377,8 @@ class TensorsConcatModule(torch.nn.Module):
|
|||
def TensorsConcatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 2, 4), tu.rand(2, 1, 4), tu.rand(2, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class GatherModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -376,6 +398,8 @@ class GatherModule(torch.nn.Module):
|
|||
def GatherModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]]))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AddSizeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -396,6 +420,8 @@ class AddSizeIntModule(torch.nn.Module):
|
|||
def AddSizeIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AddSizeIntNegDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -417,6 +443,8 @@ class AddSizeIntNegDimModule(torch.nn.Module):
|
|||
def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class EmbeddingModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -438,6 +466,7 @@ class EmbeddingModule(torch.nn.Module):
|
|||
def EmbeddingModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (3, 3)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SoftmaxIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -474,6 +503,8 @@ class _SoftmaxModule(torch.nn.Module):
|
|||
def _SoftmaxModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SoftmaxIntNegDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -494,6 +525,8 @@ class SoftmaxIntNegDimModule(torch.nn.Module):
|
|||
def SoftmaxIntNegDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class SoftmaxIntArgTypeF64Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -513,6 +546,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
|
|||
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4).double())
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BroadcastToModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -531,6 +565,8 @@ class BroadcastToModule(torch.nn.Module):
|
|||
def BroadcastToModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 1))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ExpandModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -548,6 +584,9 @@ class ExpandModule(torch.nn.Module):
|
|||
def ExpandModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 1))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class OnesModuleInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -563,6 +602,8 @@ class OnesModuleInt(torch.nn.Module):
|
|||
def OnesModuleInt_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class OnesModuleFloat(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -594,6 +635,7 @@ class OnesModuleFalsePinMemory(torch.nn.Module):
|
|||
def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ContiguousModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -681,6 +723,7 @@ class NumToTensorFloatModule(torch.nn.Module):
|
|||
def NumToTensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
# This test can be removed once we have one real op returning 3 float32 tensors
|
||||
class ReturnThreeTensorFloat32(torch.nn.Module):
|
||||
|
|
|
@ -42,6 +42,7 @@ from . import matmul
|
|||
from . import view
|
||||
from . import scalar
|
||||
from . import squeeze
|
||||
from . import slice_like
|
||||
|
||||
def _get_argparse():
|
||||
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
|
||||
|
|
|
@ -0,0 +1,227 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SliceModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[0:5:1, 1:3:1, 2:4:1]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceModule())
|
||||
def SliceModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
|
||||
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[:8, :5, 8:]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceOutOfUpperBoundIndexModule())
|
||||
def SliceOutOfUpperBoundIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[:-8,-7:,:]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundEndIndexModule())
|
||||
def SliceOutOfLowerBoundEndIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SliceOutOfLowerBoundStartIndexModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[-8:3:1, 1:3:1, 2:4:1]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexModule())
|
||||
def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
|
||||
class SliceEndSleStartModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[:0, 4:3, :-7]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceEndSleStartModule())
|
||||
def SliceEndSleStartModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
# This Test currently xfails due to https://github.com/llvm/torch-mlir/issues/448
|
||||
class SliceStartEqEndModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[5:5, 3:3, -1:]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceStartEqEndModule())
|
||||
def SliceStartEqEndModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4,7))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SliceSizeTwoStepModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[0:5:2, 0:3:2, 0:4:2]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceSizeTwoStepModule())
|
||||
def SliceSizeTwoStepModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10,5,17))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SliceNegIdxModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[:-1, -2:-1]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceNegIdxModule())
|
||||
def SliceNegIdxModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3,9))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SliceSingleIdxModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[0]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceSingleIdxModule())
|
||||
def SliceSingleIdxModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,8))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SliceWholeTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x[:, :]
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceWholeTensorModule())
|
||||
def SliceWholeTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,8))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SelectIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.select(0,0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SelectIntModule())
|
||||
def SelectIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (5,5)))
|
||||
|
||||
# ==============================================================================
|
||||
|
|
@ -17,8 +17,13 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
"QuantizedMLP_basic",
|
||||
"IouOfModule_basic",
|
||||
}
|
||||
|
||||
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
# Fails due to https://github.com/llvm/torch-mlir/issues/448
|
||||
SIZE_ZERO_TENSOR_XFAILS = {
|
||||
"SliceEndSleStartModule_basic",
|
||||
"SliceStartEqEndModule_basic",
|
||||
"SliceOutOfUpperBoundIndexModule_basic",
|
||||
}
|
||||
REFBACKEND_XFAIL_SET = set.union(COMMON_TORCH_MLIR_LOWERING_XFAILS, SIZE_ZERO_TENSOR_XFAILS)
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
|
|
|
@ -2712,6 +2712,107 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenSliceTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
auto input = adaptor.self();
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
int64_t resultRank = resultType.getRank();
|
||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("unimplemented: dim is not constant");
|
||||
|
||||
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||||
Value dimSize = inputShape[dim];
|
||||
|
||||
auto adjustStartOrEnd = [&](Value startOrEndTorchType,
|
||||
Value startOrEndBuiltin, Value valueForNone) {
|
||||
if (startOrEndTorchType.getType().isa<Torch::NoneType>())
|
||||
return valueForNone;
|
||||
auto dimSizeAsInt = castIndexToInt(rewriter, loc, dimSize);
|
||||
Value startOrEndToPositive =
|
||||
toPositiveDimDynamic(rewriter, loc, startOrEndBuiltin, dimSizeAsInt);
|
||||
// startOrEnd < 0 ? 0 : startOrEnd
|
||||
Value cst0 = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
||||
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0);
|
||||
Value startOrEndAtLeastZero = rewriter.create<SelectOp>(
|
||||
loc, predDimSltZero, cst0, startOrEndToPositive);
|
||||
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
|
||||
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
|
||||
Value startOrEndBoundedByDimSize = rewriter.create<SelectOp>(
|
||||
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
|
||||
|
||||
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
|
||||
};
|
||||
|
||||
Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
|
||||
Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
|
||||
|
||||
int64_t step;
|
||||
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
|
||||
if (!op.step().getType().isa<Torch::NoneType>())
|
||||
return op->emitError("unimplemented: step is not constant");
|
||||
step = 1;
|
||||
}
|
||||
|
||||
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
|
||||
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
|
||||
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
|
||||
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
|
||||
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
|
||||
resultSize =
|
||||
rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
|
||||
|
||||
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
|
||||
resultShape[dim] = resultSize;
|
||||
|
||||
SmallVector<Value> offsets(inputType.getRank(), zero);
|
||||
SmallVector<Value> strides(inputType.getRank(), one);
|
||||
offsets[dim] = start;
|
||||
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
|
||||
|
||||
Value result = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, input, offsets, resultShape, strides);
|
||||
|
||||
// TODO: This code is for selectOp, remove once squeeze dim is added
|
||||
if (resultRank < inputType.getRank()) {
|
||||
SmallVector<ReassociationIndices> reassociation(resultRank);
|
||||
int64_t resultIdx = 0;
|
||||
for (auto i : llvm::seq<int64_t>(0, inputType.getRank())) {
|
||||
if (resultIdx < resultRank)
|
||||
reassociation[resultIdx].push_back(i);
|
||||
if (i != dim)
|
||||
resultIdx++;
|
||||
}
|
||||
result = rewriter.create<linalg::TensorCollapseShapeOp>(loc, result,
|
||||
reassociation);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
|
||||
public:
|
||||
|
@ -3265,6 +3366,8 @@ public:
|
|||
patterns.add<ConvertAtenFill_ScalarOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNumelOp>();
|
||||
patterns.add<ConvertAtenNumelOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSliceTensorOp>();
|
||||
patterns.add<ConvertAtenSliceTensorOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -126,6 +126,26 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenSelectIntOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value end =
|
||||
rewriter.create<AtenAddIntOp>(loc, one.getType(), op.index(), one);
|
||||
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(op, op.getResult().getType(),
|
||||
op.self(), op.dim(),
|
||||
op.index(), end, one);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
||||
// exp(x)/sum(exp(x)).
|
||||
template <typename OpTy>
|
||||
|
@ -487,6 +507,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenAddmmOp>();
|
||||
patterns.add<DecomposeAtenMeanOp>(context);
|
||||
target.addIllegalOp<AtenMeanOp>();
|
||||
patterns.add<DecomposeAtenSelectIntOp>(context);
|
||||
target.addIllegalOp<AtenSelectIntOp>();
|
||||
patterns.add<DecomposeAtenMatmulOp>(context);
|
||||
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
|
||||
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
||||
|
|
|
@ -92,8 +92,8 @@ public:
|
|||
} else if (isa<AtenSqueezeOp, AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
|
||||
AtenTransposeIntOp, TensorStaticInfoCastOp,
|
||||
AtenBroadcastToOp, AtenToDtypeOp, AtenContiguousOp,
|
||||
AtenPermuteOp, AtenViewOp, AtenExpandOp,
|
||||
AtenFill_ScalarOp>(op)) {
|
||||
AtenPermuteOp, AtenViewOp, AtenExpandOp, AtenFill_ScalarOp,
|
||||
AtenSliceTensorOp, AtenSelectIntOp>(op)) {
|
||||
// AtenContiguousOp might return a view, so this is conservatively
|
||||
// correct. We could potentially be more precise and identify the cases
|
||||
// that it does not return a view and treat those as having value
|
||||
|
|
|
@ -390,13 +390,11 @@ public:
|
|||
};
|
||||
return visitSliceLikeOp(indexSelect, operands, setDim);
|
||||
} else if (auto selectInt = dyn_cast<AtenSelectIntOp>(op)) {
|
||||
// Select one element from the target dim. All the other dims are the same
|
||||
// as input.
|
||||
// Slices along dim at index. Result shape same as input except dim is
|
||||
// removed.
|
||||
auto setDim = [](int64_t &targetDim, int64_t dim,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
targetDim = 1;
|
||||
};
|
||||
return visitSliceLikeOp(selectInt, operands, setDim);
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {};
|
||||
return visitSliceLikeOp(selectInt, operands, setDim, /*keepDim=*/false);
|
||||
} else if (auto sliceTensor = dyn_cast<AtenSliceTensorOp>(op)) {
|
||||
// Select several elements from the target dim according to the start,
|
||||
// end, step. All the other dims are the same as input.
|
||||
|
@ -540,7 +538,7 @@ private:
|
|||
template <typename OpTy>
|
||||
ChangeResult
|
||||
visitSliceLikeOp(OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands,
|
||||
SetDimSizeFn setDim);
|
||||
SetDimSizeFn setDim, bool keepDim = true);
|
||||
ChangeResult
|
||||
visitAtenGatherOp(AtenGatherOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
|
@ -1222,7 +1220,7 @@ ChangeResult TypeAnalyzer::visitTypeConversionOp(
|
|||
template <typename OpTy>
|
||||
ChangeResult TypeAnalyzer::visitSliceLikeOp(
|
||||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands,
|
||||
SetDimSizeFn setDim) {
|
||||
SetDimSizeFn setDim, bool keepDim) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
|
@ -1248,6 +1246,8 @@ ChangeResult TypeAnalyzer::visitSliceLikeOp(
|
|||
}
|
||||
knowledge.sizes = input.sizes;
|
||||
setDim(knowledge.sizes[dim], dim, operands);
|
||||
if (!keepDim)
|
||||
knowledge.sizes.erase(knowledge.sizes.begin() + dim);
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
|
|
|
@ -750,8 +750,8 @@ builtin.func @torch.aten.index_select$unknown_dim(%input: !torch.tensor<[2,3,4],
|
|||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
||||
// CHECK-SAME: %[[INDEX:.*]]: !torch.int) -> !torch.tensor {
|
||||
// CHECK: %[[DIM:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[RET:.*]] = torch.aten.select.int %[[INPUT]], %[[DIM]], %[[INDEX]] : !torch.tensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.tensor<[2,1,4],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,1,4],f32> to !torch.tensor
|
||||
// CHECK: %[[RET:.*]] = torch.aten.select.int %[[INPUT]], %[[DIM]], %[[INDEX]] : !torch.tensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.tensor<[2,4],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[2,4],f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
|
||||
builtin.func @torch.aten.select.int(%input: !torch.tensor<[2,3,4], f32>, %index: !torch.int) -> !torch.tensor {
|
||||
|
|
Loading…
Reference in New Issue