[torch] Support implicit batch for index_put (#3128)

If only a single value is scattered, there can be an implicit batch
dimension. This change adds a check for the implicit batch dimension
when reshaping the update tensor, and includes an e2e test to verify
correctness.
Rob Suderman 2024-04-11 10:18:03 -07:00 committed by GitHub
parent d4a30b7e67
commit a1fe307a76
3 changed files with 35 additions and 4 deletions
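
For context, "implicit batch" means the update tensor may omit its leading
size-1 dimension when only a single value is scattered. A minimal
eager-PyTorch sketch, not part of the commit (shapes mirror the e2e test
added below):

import torch

x = torch.rand(10, 8)
index = torch.randint(0, 4, (1,))          # one scattered row
x.index_put_((index,), torch.rand(1, 8))   # explicit batch dimension
x.index_put_((index,), torch.rand(8))      # implicit batch dimension, same effect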

@@ -778,12 +778,18 @@ public:
         loc, rewriter.getI64IntegerAttr(0));
     Value one = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
-    llvm::SmallVector<int64_t> valuesShape{valuesType.getSizes().front()};
+    llvm::SmallVector<int64_t> valuesShape;
     llvm::SmallVector<Value> valuesDims;
-    valuesDims.push_back(
-        rewriter.create<Torch::AtenSizeIntOp>(loc, values, zero));
-    int vDim = 1;
+    int vDim = 0;
+    if (optionalIndicesCount + valuesType.getSizes().size() >
+        inputType.getSizes().size()) {
+      valuesShape.push_back(valuesType.getSizes().front());
+      valuesDims.push_back(
+          rewriter.create<Torch::AtenSizeIntOp>(loc, values, zero));
+      vDim++;
+    }
     for (int i = 0, s = inputType.getSizes().size(); i < s; ++i) {
       if (i < optionalIndicesCount &&
           !isa<Torch::NoneType>(optionalIndicesList[i].getType())) {
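
The new guard encodes a rank count: with an explicit leading batch
dimension, the number of index tensors plus the rank of the update tensor
exceeds the input rank; when they are equal, the batch is implicit and no
leading values dimension is peeled off. A Python rendering of the same
arithmetic (function and argument names are illustrative, not from the C++
above):

def has_explicit_batch(num_indices: int, values_rank: int, input_rank: int) -> bool:
    # Trailing (non-indexed) input dims: input_rank - num_indices.
    # An explicit batch adds one leading dim on top of those.
    return num_indices + values_rank > input_rank

assert has_explicit_batch(1, 2, 2)      # values [5, 8] into input [10, 8]
assert not has_explicit_batch(1, 1, 2)  # values [8] into input [10, 8]: implicit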

@@ -1731,6 +1731,7 @@ ONNX_XFAIL_SET = {
     "HardtanhBackward_basic",
     "IndexPutImpl1DFloatAccumulateModule_basic",
     "IndexPutImpl1DFloatNonAccumulateModule_basic",
+    "IndexPutImpl2DImplicitModule_basic",
     "IndexPutImpl1DIntAccumulateModule_basic",
     "IndexPutImpl1DIntNonAccumulateModule_basic",
     "IndexPutImpl2DFloatAccumulateModule_basic",

@@ -61,6 +61,30 @@ class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
 def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(10, 8), tu.randint(5, high=4), tu.rand(5, 8))
 
 
+class IndexPutImpl2DImplicitModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 8], torch.float32, True),
+        ([1], torch.int64, True),
+        ([8], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input, (index, ),
+                                               value,
+                                               accumulate=False,
+                                               unsafe=False)
+
+
+@register_test_case(
+    module_factory=lambda: IndexPutImpl2DImplicitModule())
+def IndexPutImpl2DImplicitModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8), tu.randint(1, high=4), tu.rand(8))
+
+
 class IndexPutImpl2DNoneIndexStaticModule(torch.nn.Module):
 
     def __init__(self):
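
As a sanity check (again not part of the commit), the implicit-batch call
exercised by the new test should match the explicit form in eager PyTorch:

import torch

x = torch.rand(10, 8)
index = torch.randint(0, 4, (1,))
value = torch.rand(8)

implicit = torch.ops.aten._index_put_impl_(x.clone(), (index,), value, False, False)
explicit = torch.ops.aten._index_put_impl_(x.clone(), (index,), value.reshape(1, 8), False, False)
assert torch.equal(implicit, explicit)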