[tmtensor] Add support for i64 index type for tm_tensor.scatter

This restriction is a bit needless and could be wrong for larger memory
blocks.
pull/3687/head
Rob Suderman 2024-09-04 17:05:09 -07:00
parent 9a4c8c606c
commit cd6ca7021e
2 changed files with 3 additions and 6 deletions

View File

@ -793,7 +793,6 @@ public:
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
MLIRContext *context = op->getContext();
Value input = op.getSelf();
Value values = op.getValues();
auto inputType = cast<ValueTensorType>(input.getType());
@ -915,9 +914,6 @@ public:
rewriter.create<AtenViewOp>(loc, valuesType, values, valuesDimsList);
// `TMTensor::ScatterOp` expects indices of element type i32.
indices = convertTensorToDtype(
rewriter, loc, indices,
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
input = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(input.getType()), input);

View File

@ -543,8 +543,9 @@ LogicalResult ScatterOp::verify() {
auto indicesType = getIndicesType();
if (indicesType.getRank() != 2 ||
!indicesType.getElementType().isInteger(32)) {
return emitOpError("expected indices to be of rank 2 of i32 element type");
!isa<IntegerType>(indicesType.getElementType())) {
return emitOpError(
"expected indices to be of rank 2 of integer element type");
}
auto indexDepth = getIndexDepth();
if (ShapedType::isDynamic(indexDepth)) {