mirror of https://github.com/llvm/torch-mlir
Add support for multi-dim input to `index_put_impl` (#722)
This commit adds support for multi-dimensional tensors as input to the `_index_put_impl_` op. Much of the support was already in place, since `ScatterOp` already handles multi-dimensional tensors. This commit also adds more error checking to `index_put` and refactors the code for creating `ScatterOp`s to mimic the way one would build a `linalg::GenericOp` (see the call-site sketch below).
parent ccf924d3df
commit 51d4d55f8a
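The refactor gives `ScatterOp` creation the same shape as building a `linalg::GenericOp`: callers pass the scatter operands plus a lambda that fills in the op's region. A minimal call-site sketch, lifted from the bincount lowering in the diff below (the values `rewriter`, `loc`, `updatesTensor`, `indices`, `bincountTensor`, `constantOne`, `resultType`, and `op` are assumed to be in scope at the call site):

    Value scatterOp = createTMTensorScatterOp(
        rewriter, loc, updatesTensor, indices, bincountTensor,
        /*uniqueIndices=*/false,
        [&](OpBuilder &b, Location loc, Value updatesElem, Value bincountElem) {
          // The body only specifies how an update combines with the original
          // element; the helper builds the op, its region, and block arguments.
          Value add = b.create<arith::AddIOp>(loc, bincountElem, constantOne);
          b.create<TMTensor::YieldOp>(loc, add);
        });
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);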
@@ -35,6 +35,9 @@ Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
 // Helper to convert a tensor to a specific scalar type.
 Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
                            Type dtype);
 
+// Helper function to get the rank of a `BaseTensorType`.
+// -1 is returned if the tensor rank cannot be determined.
+int getTensorRank(Value tensor);
+
 } // namespace Torch
 } // namespace torch
@@ -48,6 +48,28 @@ using namespace mlir::torch::TMTensor;
 // that these patterns become mostly mechanical associations of
 // "aten.foo -> linalg.foo".
 
+static Value createTMTensorScatterOp(
+    OpBuilder &b, Location loc, Value updates, Value indices, Value original,
+    bool uniqueIndices,
+    function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
+  auto originalTensorType = original.getType().cast<RankedTensorType>();
+  Type originalElementType = originalTensorType.getElementType();
+  auto scatterOp = b.create<TMTensor::ScatterOp>(
+      loc, originalTensorType, ValueRange{updates, indices},
+      ValueRange{original}, uniqueIndices);
+
+  Region &scatterOpRegion = scatterOp.region();
+  auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
+  scatterOpBlock.addArguments({originalElementType, originalElementType},
+                              {loc, loc});
+  OpBuilder regionBuilder(scatterOpRegion);
+  auto blockArgs = scatterOpBlock.getArguments();
+  Value updatesElement = blockArgs[0];
+  Value originalElement = blockArgs[1];
+  bodyBuild(regionBuilder, loc, updatesElement, originalElement);
+  return scatterOp->getResult(0);
+}
+
 namespace {
 // aten::bincount op counts the frequency of each value in a 1-d input tensor of
 // non-negative ints.
@@ -88,7 +110,7 @@ public:
     // TODO: Incorporate the weight argument.
     if (!weights.getType().isa<mlir::torch::Torch::NoneType>())
       return rewriter.notifyMatchFailure(
-          op, "Unimplemented, the weights operand is not incorporated.");
+          op, "Unimplemented: the weights operand is not incorporated.");
 
     // Finding the maximum value in the input tensor.
     SmallVector<int64_t> maxTensorSizes;
@@ -129,9 +151,9 @@ public:
     indices = typeConverter->materializeTargetConversion(
         rewriter, loc, typeConverter->convertType(indices.getType()), indices);
 
-    Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
-                              .cast<RankedTensorType>()
-                              .getElementType();
+    auto resultType = typeConverter->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
+    Type resultElemType = resultType.getElementType();
 
     SmallVector<Value, 1> inputSizeDynamic =
         getTensorSizesUntilDim(rewriter, loc, input, 0);
@@ -152,25 +174,14 @@ public:
     Value bincountTensor = createInitTensor(rewriter, loc, {bincountSize},
                                             resultElemType, constantZero);
 
-    auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
-        loc, bincountTensor.getType(), ValueRange{updatesTensor, indices},
-        ValueRange{bincountTensor},
-        /*unique_indices=*/false);
-
-    Region &scatterOpRegion = scatterOp.region();
-    auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
-    scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
-                                {loc, loc});
-    auto blockArgs = scatterOpBlock.getArguments();
-
-    // Creating an add instruction inside the scatter op region to increment the
-    // frequency counter with one.
-    OpBuilder regionBuilder(scatterOpRegion);
-    Value add = regionBuilder.create<arith::AddIOp>(loc,
-                                                    /*bincount=*/blockArgs[1],
-                                                    constantOne);
-    regionBuilder.create<TMTensor::YieldOp>(loc, add);
-    rewriter.replaceOp(op, scatterOp->getResult(0));
+    Value scatterOp = createTMTensorScatterOp(
+        rewriter, loc, updatesTensor, indices, bincountTensor,
+        /*uniqueIndices=*/false,
+        [&](OpBuilder &b, Location loc, Value _, Value bincountElem) {
+          Value add = b.create<arith::AddIOp>(loc, bincountElem, constantOne);
+          b.create<TMTensor::YieldOp>(loc, add);
+        });
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
     return success();
   }
 };
@@ -192,14 +203,8 @@ public:
     Value values = adaptor.values();
     RankedTensorType inputType = input.getType().cast<RankedTensorType>();
     RankedTensorType valuesType = values.getType().cast<RankedTensorType>();
-    Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
-                              .cast<RankedTensorType>()
-                              .getElementType();
-
-    // TODO: Add support for the input with rank other than one.
-    if (inputType.getRank() != 1)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: input rank other than one is not supported");
+    auto resultType = typeConverter->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
 
     // The unsafe should be either `False` or `none`.
     if (!op.unsafe().getType().isa<Torch::NoneType>()) {
@@ -231,23 +236,14 @@ public:
       return rewriter.notifyMatchFailure(
           op, "Indices list size should not be greater than the input rank.");
 
-    // TODO: Add support for cases with indices list size smaller than the input
-    // rank.
-    if ((int64_t)indicesList.size() < inputType.getRank())
+    // TODO: Add support for cases with indices list size not equal to 1.
+    if (indicesList.size() != 1)
       return rewriter.notifyMatchFailure(
-          op, "Unimplemented, Indices list size smaller than input rank");
+          op, "Unimplemented: Indices list size != 1");
+    Value indexTensor = indicesList[0];
 
-    if (indicesList[0].getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(op,
-                                         "Indices tensor must not be none.");
-
-    // TODO: Add support for the index with rank other than one.
-    int64_t indexRank = typeConverter->convertType(indicesList[0].getType())
-                            .cast<RankedTensorType>()
-                            .getRank();
-    if (indexRank != 1)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: index rank other than one is not supported");
+    if (indexTensor.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(op, "Index tensor must not be None.");
 
     // Creating a tm_tensor.scatter op with the following mapping:
     // 1.) Index tensor from the `indicesList` maps to the indices in scatter
@@ -255,48 +251,55 @@ public:
     // to i32 as required for the scatter op.
     // 2.) `values` is mapped to `updates` in scatter op.
     // 3.) `input` is mapped to `original` in scatter op.
-    ValueTensorType indexType =
-        indicesList[0].getType().cast<ValueTensorType>();
-    SmallVector<int64_t> expandedIndexSizes{indexType.getSizes()[0], 1};
-    ValueTensorType expandedIndexType = ValueTensorType::get(
-        context, llvm::makeArrayRef(expandedIndexSizes), indexType.getDtype());
+    if (getTensorRank(indexTensor) != 1)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: index tensor with rank != 1 is not supported");
+    auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
+    int64_t indexTensorSize = indexTensorType.getSizes()[0];
+    SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
+    ValueTensorType expandedIndexTensorType = ValueTensorType::get(
+        context, llvm::makeArrayRef(expandedIndexTensorSizes),
+        indexTensorType.getDtype());
     Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
     Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
-        loc, expandedIndexType, indicesList[0], torchCstOne);
+        loc, expandedIndexTensorType, indexTensor, torchCstOne);
 
-    // Converting the index element type to i32.
+    // `TMTensor::ScatterOp` expects indices of element type i32.
     Value indices = convertTensorToDtype(
         rewriter, loc, expandedIndexTensor,
         mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
     indices = typeConverter->materializeTargetConversion(
        rewriter, loc, typeConverter->convertType(indices.getType()), indices);
 
-    auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
-        loc, input.getType(), ValueRange{values, indices}, ValueRange{input},
-        /*unique_indices=*/false);
-
-    Region &scatterOpRegion = scatterOp.region();
-    auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
-    scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
-                                {loc, loc});
-    auto blockArgs = scatterOpBlock.getArguments();
-
-    OpBuilder regionBuilder(scatterOpRegion);
-    Value update = blockArgs[0];
-    Value original = blockArgs[1];
-    Value yieldValue = update;
-    // Create an add instruction inside the scatter op region to increment the
-    // `original` value with the value from `updates` if the accumulate flag is
-    // true.
-    if (accumulate) {
-      if (inputType.getElementType().isa<mlir::IntegerType>())
-        yieldValue = regionBuilder.create<arith::AddIOp>(loc, original, update);
-      else if (inputType.getElementType().isa<mlir::FloatType>())
-        yieldValue = regionBuilder.create<arith::AddFOp>(loc, original, update);
-    }
-    regionBuilder.create<TMTensor::YieldOp>(loc, yieldValue);
-    rewriter.replaceOp(op, scatterOp->getResult(0));
+    bool invalidInputTypeFound = false;
+    Value scatterOp = createTMTensorScatterOp(
+        rewriter, loc, values, indices, input, /*uniqueIndices=*/false,
+        [&](OpBuilder &b, Location loc, Value valuesElement,
+            Value inputElement) {
+          Value yieldValue = valuesElement;
+          if (accumulate) {
+            if (inputElement.getType().isa<mlir::IntegerType>()) {
+              yieldValue =
+                  b.create<arith::AddIOp>(loc, inputElement, valuesElement);
+            } else if (inputElement.getType().isa<mlir::FloatType>()) {
+              yieldValue =
+                  b.create<arith::AddFOp>(loc, inputElement, valuesElement);
+            } else {
+              invalidInputTypeFound = true;
+              return;
+            }
+          }
+          b.create<TMTensor::YieldOp>(loc, yieldValue);
+        });
+
+    if (invalidInputTypeFound) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "unimplemented: input tensor must be of integer type or float type");
+    }
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
     return success();
   }
 };
@@ -21,19 +21,6 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
-// Helper funtion to get rank of `Base tensor type`.
-// -1 is returned if the tensorRank can't be determined.
-static int getTensorRank(Value tensor) {
-  int tensorRank = -1;
-  BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
-
-  if (tensorType.hasSizes()) {
-    ArrayRef<int64_t> tensorShape = tensorType.getSizes();
-    tensorRank = tensorShape.size();
-  }
-  return tensorRank;
-}
-
 // Helper function to compute the return type of the reduction function.
 // `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
 // the input tensor.
@@ -95,3 +95,14 @@ Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
       loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
   return converted;
 }
+
+int Torch::getTensorRank(Value tensor) {
+  int tensorRank = -1;
+  BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
+
+  if (tensorType.hasSizes()) {
+    ArrayRef<int64_t> tensorShape = tensorType.getSizes();
+    tensorRank = tensorShape.size();
+  }
+  return tensorRank;
+}
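For reference, `Torch::getTensorRank` returns -1 when the tensor type carries no size information, so a single comparison can reject both unranked index tensors and index tensors of an unsupported rank. A minimal sketch of the guard used by the lowering above (`indexTensor`, `op`, and `rewriter` stand in for the values available inside the conversion pattern):

    // getTensorRank returns -1 for tensors without known sizes, so this also
    // bails out on unranked index tensors.
    if (getTensorRank(indexTensor) != 1)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: index tensor with rank != 1 is not supported");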
@@ -11,7 +11,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 
 # ==============================================================================
 
-class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
+class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -29,13 +29,61 @@ class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
                                                unsafe=False)
 
 
-@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatNonAccumulateModule())
-def IndexPutImplOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPutImpl1DFloatNonAccumulateModule())
+def IndexPutImpl1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(100), torch.randint(100, (250,)),
                    tu.rand(250))
 
 
-class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
+class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input, (index,), value,
+                                               accumulate=False,
+                                               unsafe=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutImpl2DFloatNonAccumulateModule())
+def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8), torch.randint(4, (5,)),
+                   tu.rand(5, 8))
+
+
+class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input, (index,), value,
+                                               accumulate=False,
+                                               unsafe=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule())
+def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), torch.randint(4, (5,)),
+                   tu.rand(5, 8, 6))
+
+
+class IndexPutImpl1DIntNonAccumulateModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -53,13 +101,13 @@ class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
                                                unsafe=False)
 
 
-@register_test_case(module_factory=lambda: IndexPutImplOneDimIntNonAccumulateModule())
-def IndexPutImplOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPutImpl1DIntNonAccumulateModule())
+def IndexPutImpl1DIntNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
                    torch.randint(10000, (300,)))
 
 
-class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
+class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -79,13 +127,61 @@ class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
                                                unsafe=False)
 
 
-@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatAccumulateModule())
-def IndexPutImplOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPutImpl1DFloatAccumulateModule())
+def IndexPutImpl1DFloatAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1000), torch.randint(10, (500,)),
                    tu.rand(500))
 
 
-class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
+class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
+                                               accumulate=True,
+                                               unsafe=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutImpl2DFloatAccumulateModule())
+def IndexPutImpl2DFloatAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8), torch.randint(4, (5,)),
+                   tu.rand(5, 8))
+
+
+class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
+                                               accumulate=True,
+                                               unsafe=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutImpl3DFloatAccumulateModule())
+def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), torch.randint(4, (5,)),
+                   tu.rand(5, 8, 6))
+
+
+class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -105,14 +201,14 @@ class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
                                                unsafe=False)
 
 
-@register_test_case(module_factory=lambda: IndexPutImplOneDimIntAccumulateModule())
-def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule())
+def IndexPutImpl1DIntAccumulateModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
                    torch.randint(1000, (10,)))
 
 # ==============================================================================
 
-class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
+class IndexPut1DFloatNonAccumulateModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -128,13 +224,13 @@ class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
         return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
 
 
-@register_test_case(module_factory=lambda: IndexPutOneDimFloatNonAccumulateModule())
-def IndexPutOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPut1DFloatNonAccumulateModule())
+def IndexPut1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(100), torch.randint(100, (250,)),
                    tu.rand(250))
 
 
-class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
+class IndexPut1DIntNonAccumulateModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -150,13 +246,13 @@ class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
         return torch.ops.aten.index_put(input, (index,), value, accumulate=False)
 
 
-@register_test_case(module_factory=lambda: IndexPutOneDimIntNonAccumulateModule())
-def IndexPutOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPut1DIntNonAccumulateModule())
+def IndexPut1DIntNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
                    torch.randint(10000, (300,)))
 
 
-class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
+class IndexPut1DFloatAccumulateModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -172,13 +268,13 @@ class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
         return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
 
 
-@register_test_case(module_factory=lambda: IndexPutOneDimFloatAccumulateModule())
-def IndexPutOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPut1DFloatAccumulateModule())
+def IndexPut1DFloatAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1000), torch.randint(10, (500,)),
                    tu.rand(500))
 
 
-class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
+class IndexPut1DIntAccumulateModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -194,7 +290,7 @@ class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
         return torch.ops.aten.index_put(input, (index,), value, accumulate=True)
 
 
-@register_test_case(module_factory=lambda: IndexPutOneDimIntAccumulateModule())
-def IndexPutOneDimIntAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPut1DIntAccumulateModule())
+def IndexPut1DIntAccumulateModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
                    torch.randint(1000, (10,)))