mirror of https://github.com/llvm/torch-mlir
[tmtensor] Add support for i64 index type for tm_tensor.scatter
This restriction is a bit needless and could be wrong for larger memory blocks.pull/3687/head
parent
9a4c8c606c
commit
cd6ca7021e
|
@ -793,7 +793,6 @@ public:
|
|||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
Value input = op.getSelf();
|
||||
Value values = op.getValues();
|
||||
auto inputType = cast<ValueTensorType>(input.getType());
|
||||
|
@ -915,9 +914,6 @@ public:
|
|||
rewriter.create<AtenViewOp>(loc, valuesType, values, valuesDimsList);
|
||||
|
||||
// `TMTensor::ScatterOp` expects indices of element type i32.
|
||||
indices = convertTensorToDtype(
|
||||
rewriter, loc, indices,
|
||||
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
||||
|
||||
input = typeConverter->materializeTargetConversion(
|
||||
rewriter, loc, typeConverter->convertType(input.getType()), input);
|
||||
|
|
|
@ -543,8 +543,9 @@ LogicalResult ScatterOp::verify() {
|
|||
|
||||
auto indicesType = getIndicesType();
|
||||
if (indicesType.getRank() != 2 ||
|
||||
!indicesType.getElementType().isInteger(32)) {
|
||||
return emitOpError("expected indices to be of rank 2 of i32 element type");
|
||||
!isa<IntegerType>(indicesType.getElementType())) {
|
||||
return emitOpError(
|
||||
"expected indices to be of rank 2 of integer element type");
|
||||
}
|
||||
auto indexDepth = getIndexDepth();
|
||||
if (ShapedType::isDynamic(indexDepth)) {
|
||||
|
|
Loading…
Reference in New Issue