mirror of https://github.com/llvm/torch-mlir
Add support for multi-dim input to `index_put_impl` (#722)
This commit adds support for multi-dimensional tensors as input to the `_index_put_impl_` op. The support was to some degree already there, since `ScatterOp` already supports multi-dimensional tensors. This commit also adds a bit more error checking to `index_put` and refactors the code for creating `ScatterOp`s to mimic the way one would make a `Linalg::GenericOp`.pull/726/head snapshot-20220331.361
parent
ccf924d3df
commit
51d4d55f8a
|
@ -35,6 +35,9 @@ Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
|||
// Helper to convert a tensor to a specific scalar type.
|
||||
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
|
||||
Type dtype);
|
||||
// Helper funtion to get rank of `Base tensor type`.
|
||||
// -1 is returned if the tensorRank can't be determined.
|
||||
int getTensorRank(Value tensor);
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
|
|
|
@ -48,6 +48,28 @@ using namespace mlir::torch::TMTensor;
|
|||
// that these patterns become mostly mechanical associations of
|
||||
// "aten.foo -> linalg.foo".
|
||||
|
||||
static Value createTMTensorScatterOp(
|
||||
OpBuilder &b, Location loc, Value updates, Value indices, Value original,
|
||||
bool uniqueIndices,
|
||||
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
|
||||
auto originalTensorType = original.getType().cast<RankedTensorType>();
|
||||
Type originalElementType = originalTensorType.getElementType();
|
||||
auto scatterOp = b.create<TMTensor::ScatterOp>(
|
||||
loc, originalTensorType, ValueRange{updates, indices},
|
||||
ValueRange{original}, uniqueIndices);
|
||||
|
||||
Region &scatterOpRegion = scatterOp.region();
|
||||
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
|
||||
scatterOpBlock.addArguments({originalElementType, originalElementType},
|
||||
{loc, loc});
|
||||
OpBuilder regionBuilder(scatterOpRegion);
|
||||
auto blockArgs = scatterOpBlock.getArguments();
|
||||
Value updatesElement = blockArgs[0];
|
||||
Value originalElement = blockArgs[1];
|
||||
bodyBuild(regionBuilder, loc, updatesElement, originalElement);
|
||||
return scatterOp->getResult(0);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
|
||||
// non-negative ints.
|
||||
|
@ -88,7 +110,7 @@ public:
|
|||
// TODO: Incorporate the weight argument.
|
||||
if (!weights.getType().isa<mlir::torch::Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented, the weights operand is not incorporated.");
|
||||
op, "Unimplemented: the weights operand is not incorporated.");
|
||||
|
||||
// Finding the maximum value in the input tensor.
|
||||
SmallVector<int64_t> maxTensorSizes;
|
||||
|
@ -129,9 +151,9 @@ public:
|
|||
indices = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||
|
||||
Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
Type resultElemType = resultType.getElementType();
|
||||
|
||||
SmallVector<Value, 1> inputSizeDynamic =
|
||||
getTensorSizesUntilDim(rewriter, loc, input, 0);
|
||||
|
@ -152,25 +174,14 @@ public:
|
|||
Value bincountTensor = createInitTensor(rewriter, loc, {bincountSize},
|
||||
resultElemType, constantZero);
|
||||
|
||||
auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
|
||||
loc, bincountTensor.getType(), ValueRange{updatesTensor, indices},
|
||||
ValueRange{bincountTensor},
|
||||
/*unique_indices=*/false);
|
||||
|
||||
Region &scatterOpRegion = scatterOp.region();
|
||||
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
|
||||
scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
|
||||
{loc, loc});
|
||||
auto blockArgs = scatterOpBlock.getArguments();
|
||||
|
||||
// Creating an add instruction inside the scatter op region to increment the
|
||||
// frequency counter with one.
|
||||
OpBuilder regionBuilder(scatterOpRegion);
|
||||
Value add = regionBuilder.create<arith::AddIOp>(loc,
|
||||
/*bincount=*/blockArgs[1],
|
||||
constantOne);
|
||||
regionBuilder.create<TMTensor::YieldOp>(loc, add);
|
||||
rewriter.replaceOp(op, scatterOp->getResult(0));
|
||||
Value scatterOp = createTMTensorScatterOp(
|
||||
rewriter, loc, updatesTensor, indices, bincountTensor,
|
||||
/*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value _, Value bincountElem) {
|
||||
Value add = b.create<arith::AddIOp>(loc, bincountElem, constantOne);
|
||||
b.create<TMTensor::YieldOp>(loc, add);
|
||||
});
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -192,14 +203,8 @@ public:
|
|||
Value values = adaptor.values();
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType valuesType = values.getType().cast<RankedTensorType>();
|
||||
Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
|
||||
// TODO: Add support for the input with rank other than one.
|
||||
if (inputType.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: input rank other than one is not supported");
|
||||
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
// The unsafe should be either `False` or `none`.
|
||||
if (!op.unsafe().getType().isa<Torch::NoneType>()) {
|
||||
|
@ -231,23 +236,14 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "Indices list size should not be greater than the input rank.");
|
||||
|
||||
// TODO: Add support for cases with indices list size smaller than the input
|
||||
// rank.
|
||||
if ((int64_t)indicesList.size() < inputType.getRank())
|
||||
// TODO: Add support for cases with indices list size not equal to 1.
|
||||
if (indicesList.size() != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented, Indices list size smaller than input rank");
|
||||
op, "Unimplemented: Indices list size != 1");
|
||||
Value indexTensor = indicesList[0];
|
||||
|
||||
if (indicesList[0].getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Indices tensor must not be none.");
|
||||
|
||||
// TODO: Add support for the index with rank other than one.
|
||||
int64_t indexRank = typeConverter->convertType(indicesList[0].getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getRank();
|
||||
if (indexRank != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: index rank other than one is not supported");
|
||||
if (indexTensor.getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(op, "Index tensor must not be None.");
|
||||
|
||||
// Creating a tm_tensor.scatter op with the following mapping:
|
||||
// 1.) Index tensor from the `indicesList` maps to the indices in scatter
|
||||
|
@ -255,48 +251,55 @@ public:
|
|||
// to i32 as required for the scatter op.
|
||||
// 2.) `values` is mapped to `updates` in scatter op.
|
||||
// 3.) `input` is mapped to `original` in scatter op.
|
||||
ValueTensorType indexType =
|
||||
indicesList[0].getType().cast<ValueTensorType>();
|
||||
SmallVector<int64_t> expandedIndexSizes{indexType.getSizes()[0], 1};
|
||||
ValueTensorType expandedIndexType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedIndexSizes), indexType.getDtype());
|
||||
if (getTensorRank(indexTensor) != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: index tensor with rank != 1 is not supported");
|
||||
auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
|
||||
int64_t indexTensorSize = indexTensorType.getSizes()[0];
|
||||
SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
|
||||
ValueTensorType expandedIndexTensorType = ValueTensorType::get(
|
||||
context, llvm::makeArrayRef(expandedIndexTensorSizes),
|
||||
indexTensorType.getDtype());
|
||||
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
|
||||
loc, expandedIndexType, indicesList[0], torchCstOne);
|
||||
loc, expandedIndexTensorType, indexTensor, torchCstOne);
|
||||
|
||||
// Converting the index element type to i32.
|
||||
// `TMTensor::ScatterOp` expects indices of element type i32.
|
||||
Value indices = convertTensorToDtype(
|
||||
rewriter, loc, expandedIndexTensor,
|
||||
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
||||
indices = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||
|
||||
auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
|
||||
loc, input.getType(), ValueRange{values, indices}, ValueRange{input},
|
||||
/*unique_indices=*/false);
|
||||
|
||||
Region &scatterOpRegion = scatterOp.region();
|
||||
auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
|
||||
scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
|
||||
{loc, loc});
|
||||
auto blockArgs = scatterOpBlock.getArguments();
|
||||
|
||||
OpBuilder regionBuilder(scatterOpRegion);
|
||||
Value update = blockArgs[0];
|
||||
Value original = blockArgs[1];
|
||||
Value yieldValue = update;
|
||||
// Create an add instruction inside the scatter op region to increment the
|
||||
// `original` value with the value from `updates` if the accumulate flag is
|
||||
// true.
|
||||
bool invalidInputTypeFound = false;
|
||||
Value scatterOp = createTMTensorScatterOp(
|
||||
rewriter, loc, values, indices, input, /*uniqueIndices=*/false,
|
||||
[&](OpBuilder &b, Location loc, Value valuesElement,
|
||||
Value inputElement) {
|
||||
Value yieldValue = valuesElement;
|
||||
if (accumulate) {
|
||||
if (inputType.getElementType().isa<mlir::IntegerType>())
|
||||
yieldValue = regionBuilder.create<arith::AddIOp>(loc, original, update);
|
||||
else if (inputType.getElementType().isa<mlir::FloatType>())
|
||||
yieldValue = regionBuilder.create<arith::AddFOp>(loc, original, update);
|
||||
if (inputElement.getType().isa<mlir::IntegerType>()) {
|
||||
yieldValue =
|
||||
b.create<arith::AddIOp>(loc, inputElement, valuesElement);
|
||||
} else if (inputElement.getType().isa<mlir::FloatType>()) {
|
||||
yieldValue =
|
||||
b.create<arith::AddFOp>(loc, inputElement, valuesElement);
|
||||
} else {
|
||||
invalidInputTypeFound = true;
|
||||
return;
|
||||
}
|
||||
regionBuilder.create<TMTensor::YieldOp>(loc, yieldValue);
|
||||
rewriter.replaceOp(op, scatterOp->getResult(0));
|
||||
}
|
||||
b.create<TMTensor::YieldOp>(loc, yieldValue);
|
||||
});
|
||||
|
||||
if (invalidInputTypeFound) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"unimplemented: input tensor must be of integer type or float type");
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -21,19 +21,6 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// Helper funtion to get rank of `Base tensor type`.
|
||||
// -1 is returned if the tensorRank can't be determined.
|
||||
static int getTensorRank(Value tensor) {
|
||||
int tensorRank = -1;
|
||||
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
|
||||
|
||||
if (tensorType.hasSizes()) {
|
||||
ArrayRef<int64_t> tensorShape = tensorType.getSizes();
|
||||
tensorRank = tensorShape.size();
|
||||
}
|
||||
return tensorRank;
|
||||
}
|
||||
|
||||
// Helper function to compute the return type of the reduction function.
|
||||
// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
|
||||
// the input tensor.
|
||||
|
|
|
@ -95,3 +95,14 @@ Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
|
|||
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
|
||||
return converted;
|
||||
}
|
||||
|
||||
int Torch::getTensorRank(Value tensor) {
|
||||
int tensorRank = -1;
|
||||
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
|
||||
|
||||
if (tensorType.hasSizes()) {
|
||||
ArrayRef<int64_t> tensorShape = tensorType.getSizes();
|
||||
tensorRank = tensorShape.size();
|
||||
}
|
||||
return tensorRank;
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
|
||||
class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -29,13 +29,61 @@ class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
|
|||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatNonAccumulateModule())
|
||||
def IndexPutImplOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: IndexPutImpl1DFloatNonAccumulateModule())
|
||||
def IndexPutImpl1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(100), torch.randint(100, (250,)),
|
||||
tu.rand(250))
|
||||
|
||||
|
||||
class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
|
||||
class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index,), value,
|
||||
accumulate=False,
|
||||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImpl2DFloatNonAccumulateModule())
|
||||
def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8), torch.randint(4, (5,)),
|
||||
tu.rand(5, 8))
|
||||
|
||||
|
||||
class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index,), value,
|
||||
accumulate=False,
|
||||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule())
|
||||
def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5,)),
|
||||
tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
class IndexPutImpl1DIntNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -53,13 +101,13 @@ class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
|
|||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImplOneDimIntNonAccumulateModule())
|
||||
def IndexPutImplOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: IndexPutImpl1DIntNonAccumulateModule())
|
||||
def IndexPutImpl1DIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
|
||||
torch.randint(10000, (300,)))
|
||||
|
||||
|
||||
class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
|
||||
class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -79,13 +127,61 @@ class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
|
|||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatAccumulateModule())
|
||||
def IndexPutImplOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: IndexPutImpl1DFloatAccumulateModule())
|
||||
def IndexPutImpl1DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1000), torch.randint(10, (500,)),
|
||||
tu.rand(500))
|
||||
|
||||
|
||||
class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
|
||||
class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
|
||||
accumulate=True,
|
||||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImpl2DFloatAccumulateModule())
|
||||
def IndexPutImpl2DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8), torch.randint(4, (5,)),
|
||||
tu.rand(5, 8))
|
||||
|
||||
|
||||
class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
|
||||
accumulate=True,
|
||||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImpl3DFloatAccumulateModule())
|
||||
def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 8, 6), torch.randint(4, (5,)),
|
||||
tu.rand(5, 8, 6))
|
||||
|
||||
|
||||
class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -105,14 +201,14 @@ class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
|
|||
unsafe=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutImplOneDimIntAccumulateModule())
|
||||
def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule())
|
||||
def IndexPutImpl1DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
|
||||
torch.randint(1000, (10,)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
|
||||
class IndexPut1DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -128,13 +224,13 @@ class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
|
|||
return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutOneDimFloatNonAccumulateModule())
|
||||
def IndexPutOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: IndexPut1DFloatNonAccumulateModule())
|
||||
def IndexPut1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(100), torch.randint(100, (250,)),
|
||||
tu.rand(250))
|
||||
|
||||
|
||||
class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
|
||||
class IndexPut1DIntNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -150,13 +246,13 @@ class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
|
|||
return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutOneDimIntNonAccumulateModule())
|
||||
def IndexPutOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: IndexPut1DIntNonAccumulateModule())
|
||||
def IndexPut1DIntNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
|
||||
torch.randint(10000, (300,)))
|
||||
|
||||
|
||||
class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
|
||||
class IndexPut1DFloatAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -172,13 +268,13 @@ class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
|
|||
return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutOneDimFloatAccumulateModule())
|
||||
def IndexPutOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: IndexPut1DFloatAccumulateModule())
|
||||
def IndexPut1DFloatAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1000), torch.randint(10, (500,)),
|
||||
tu.rand(500))
|
||||
|
||||
|
||||
class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
|
||||
class IndexPut1DIntAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -194,7 +290,7 @@ class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
|
|||
return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: IndexPutOneDimIntAccumulateModule())
|
||||
def IndexPutOneDimIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: IndexPut1DIntAccumulateModule())
|
||||
def IndexPut1DIntAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
|
||||
torch.randint(1000, (10,)))
|
||||
|
|
Loading…
Reference in New Issue