mirror of https://github.com/llvm/torch-mlir
[torch] Support implicit batch for index_put (#3128)
If only a single value is scattered, there can be an implicit batch dimension. This adds a check for the implicit batch dimension when reshaping the update tensor, and includes an e2e test to verify correctness.
parent d4a30b7e67
commit a1fe307a76
branch pull/3149/head
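For context, an illustrative eager-mode sketch (not part of the commit): with a rank-2 input, aten.index_put accepts update values both with and without an explicit leading batch dimension, and the lowering must detect which case applies.

import torch

x = torch.rand(10, 8)

# Explicit batch: one index tensor of shape (5,) plus rank-2 values (5, 8);
# each indexed row gets its own update row.
x.index_put_((torch.randint(4, (5,)),), torch.rand(5, 8))

# Implicit batch: one index tensor of shape (1,) plus rank-1 values (8,);
# the values broadcast against the selected rows of shape (1, 8).
x.index_put_((torch.randint(4, (1,)),), torch.rand(8))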
@@ -778,12 +778,18 @@ public:
         loc, rewriter.getI64IntegerAttr(0));
     Value one = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
-    llvm::SmallVector<int64_t> valuesShape{valuesType.getSizes().front()};
+    llvm::SmallVector<int64_t> valuesShape;
     llvm::SmallVector<Value> valuesDims;
-    valuesDims.push_back(
-        rewriter.create<Torch::AtenSizeIntOp>(loc, values, zero));
-    int vDim = 1;
+    int vDim = 0;
+
+    if (optionalIndicesCount + valuesType.getSizes().size() >
+        inputType.getSizes().size()) {
+      valuesShape.push_back(valuesType.getSizes().front());
+      valuesDims.push_back(
+          rewriter.create<Torch::AtenSizeIntOp>(loc, values, zero));
+      vDim++;
+    }
+
     for (int i = 0, s = inputType.getSizes().size(); i < s; ++i) {
       if (i < optionalIndicesCount &&
           !isa<Torch::NoneType>(optionalIndicesList[i].getType())) {
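A minimal sketch of the rank arithmetic behind the new guard (hypothetical helper name; the C++ above performs this check inline):

def has_explicit_batch(num_index_tensors: int, values_rank: int,
                       input_rank: int) -> bool:
    # If the index tensors plus the update values together carry more
    # dimensions than the input, the leading values dimension is an explicit
    # batch and is kept when building the scatter shape; otherwise a batch
    # dimension of size 1 is implicit and no leading dimension is pushed.
    return num_index_tensors + values_rank > input_rank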
@@ -1731,6 +1731,7 @@ ONNX_XFAIL_SET = {
     "HardtanhBackward_basic",
     "IndexPutImpl1DFloatAccumulateModule_basic",
     "IndexPutImpl1DFloatNonAccumulateModule_basic",
+    "IndexPutImpl2DImplicitModule_basic",
     "IndexPutImpl1DIntAccumulateModule_basic",
     "IndexPutImpl1DIntNonAccumulateModule_basic",
     "IndexPutImpl2DFloatAccumulateModule_basic",
@@ -61,6 +61,30 @@ class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
 def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))
 
 
+class IndexPutImpl2DImplicitModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 8], torch.float32, True),
+        ([1], torch.int64, True),
+        ([8], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input, (index, ),
+                                               value,
+                                               accumulate=False,
+                                               unsafe=False)
+
+
+@register_test_case(
+    module_factory=lambda: IndexPutImpl2DImplicitModule())
+def IndexPutImpl2DImplicitModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8), tu.randint(1, high=4), tu.rand(8))
+
+
 class IndexPutImpl2DNoneIndexStaticModule(torch.nn.Module):
 
     def __init__(self):
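Plugging the test shapes into the guard (using the hypothetical has_explicit_batch helper from the earlier sketch): the pre-existing 2D test has an explicit batch, while the new test exercises the implicit one.

# IndexPutImpl2DFloatNonAccumulateModule: 1 index tensor, rank-2 values,
# rank-2 input; 1 + 2 > 2, so the leading values dim of 5 is an explicit batch.
assert has_explicit_batch(1, 2, 2)

# IndexPutImpl2DImplicitModule: 1 index tensor, rank-1 values, rank-2 input;
# 1 + 1 == 2, so the batch dimension is implicit and nothing extra is pushed.
assert not has_explicit_batch(1, 1, 2)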