diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 946657fda..64702c76b 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -35,6 +35,9 @@ Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
 // Helper to convert a tensor to a specific scalar type.
 Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
                            Type dtype);
+// Helper function to get the rank of `Base tensor type`.
+// -1 is returned if the tensorRank can't be determined.
+int getTensorRank(Value tensor);

 } // namespace Torch
 } // namespace torch
diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index c1f82bff6..bd18e596e 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -48,6 +48,28 @@ using namespace mlir::torch::TMTensor;
 // that these patterns become mostly mechanical associations of
 // "aten.foo -> linalg.foo".

+static Value createTMTensorScatterOp(
+    OpBuilder &b, Location loc, Value updates, Value indices, Value original,
+    bool uniqueIndices,
+    function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
+  auto originalTensorType = original.getType().cast<RankedTensorType>();
+  Type originalElementType = originalTensorType.getElementType();
+  auto scatterOp = b.create<TMTensor::ScatterOp>(
+      loc, originalTensorType, ValueRange{updates, indices},
+      ValueRange{original}, uniqueIndices);
+
+  Region &scatterOpRegion = scatterOp.region();
+  auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
+  scatterOpBlock.addArguments({originalElementType, originalElementType},
+                              {loc, loc});
+  OpBuilder regionBuilder(scatterOpRegion);
+  auto blockArgs = scatterOpBlock.getArguments();
+  Value updatesElement = blockArgs[0];
+  Value originalElement = blockArgs[1];
+  bodyBuild(regionBuilder, loc, updatesElement, originalElement);
+  return scatterOp->getResult(0);
+}
+
 namespace {
 // aten::bincount op counts the frequency of each value in a 1-d input tensor of
 // non-negative ints.
@@ -88,7 +110,7 @@ public:
     // TODO: Incorporate the weight argument.
     if (!weights.getType().isa<Torch::NoneType>())
       return rewriter.notifyMatchFailure(
-          op, "Unimplemented, the weights operand is not incorporated.");
+          op, "Unimplemented: the weights operand is not incorporated.");

     // Finding the maximum value in the input tensor.
     SmallVector<int64_t> maxTensorSizes;
@@ -129,9 +151,9 @@ public:
     indices = typeConverter->materializeTargetConversion(
         rewriter, loc, typeConverter->convertType(indices.getType()), indices);

-    Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
-                              .cast<RankedTensorType>()
-                              .getElementType();
+    auto resultType = typeConverter->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
+    Type resultElemType = resultType.getElementType();

     SmallVector<Value> inputSizeDynamic =
         getTensorSizesUntilDim(rewriter, loc, input, 0);
@@ -152,25 +174,14 @@ public:
     Value bincountTensor = createInitTensor(rewriter, loc, {bincountSize},
                                             resultElemType, constantZero);

-    auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
-        loc, bincountTensor.getType(), ValueRange{updatesTensor, indices},
-        ValueRange{bincountTensor},
-        /*unique_indices=*/false);
-
-    Region &scatterOpRegion = scatterOp.region();
-    auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
-    scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
-                                {loc, loc});
-    auto blockArgs = scatterOpBlock.getArguments();
-
-    // Creating an add instruction inside the scatter op region to increment the
-    // frequency counter with one.
-    OpBuilder regionBuilder(scatterOpRegion);
-    Value add = regionBuilder.create<arith::AddIOp>(loc,
-                                                    /*bincount=*/blockArgs[1],
-                                                    constantOne);
-    regionBuilder.create<TMTensor::YieldOp>(loc, add);
-    rewriter.replaceOp(op, scatterOp->getResult(0));
+    Value scatterOp = createTMTensorScatterOp(
+        rewriter, loc, updatesTensor, indices, bincountTensor,
+        /*uniqueIndices=*/false,
+        [&](OpBuilder &b, Location loc, Value _, Value bincountElem) {
+          Value add = b.create<arith::AddIOp>(loc, bincountElem, constantOne);
+          b.create<TMTensor::YieldOp>(loc, add);
+        });
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
     return success();
   }
 };
@@ -192,14 +203,8 @@ public:
     Value values = adaptor.values();
     RankedTensorType inputType = input.getType().cast<RankedTensorType>();
     RankedTensorType valuesType = values.getType().cast<RankedTensorType>();
-    Type resultElemType = typeConverter->convertType(op->getResult(0).getType())
-                              .cast<RankedTensorType>()
-                              .getElementType();
-
-    // TODO: Add support for the input with rank other than one.
-    if (inputType.getRank() != 1)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: input rank other than one is not supported");
+    auto resultType = typeConverter->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();

     // The unsafe should be either `False` or `none`.
     if (!op.unsafe().getType().isa<Torch::NoneType>()) {
@@ -231,23 +236,14 @@ public:
       return rewriter.notifyMatchFailure(
           op, "Indices list size should not be greater than the input rank.");

-    // TODO: Add support for cases with indices list size smaller than the input
-    // rank.
-    if ((int64_t)indicesList.size() < inputType.getRank())
+    // TODO: Add support for cases with indices list size not equal to 1.
+    if (indicesList.size() != 1)
       return rewriter.notifyMatchFailure(
-          op, "Unimplemented, Indices list size smaller than input rank");
+          op, "Unimplemented: Indices list size != 1");
+    Value indexTensor = indicesList[0];

-    if (indicesList[0].getType().isa<Torch::NoneType>())
-      return rewriter.notifyMatchFailure(op,
-                                         "Indices tensor must not be none.");
-
-    // TODO: Add support for the index with rank other than one.
-    int64_t indexRank = typeConverter->convertType(indicesList[0].getType())
-                            .cast<RankedTensorType>()
-                            .getRank();
-    if (indexRank != 1)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: index rank other than one is not supported");
+    if (indexTensor.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(op, "Index tensor must not be None.");

     // Creating a tm_tensor.scatter op with the following mapping:
     // 1.) Index tensor from the `indicesList` maps to the indices in scatter
@@ -255,48 +251,55 @@ public:
     // to i32 as required for the scatter op.
     // 2.) `values` is mapped to `updates` in scatter op.
     // 3.) `input` is mapped to `original` in scatter op.
-    ValueTensorType indexType =
-        indicesList[0].getType().cast<ValueTensorType>();
-    SmallVector<int64_t> expandedIndexSizes{indexType.getSizes()[0], 1};
-    ValueTensorType expandedIndexType = ValueTensorType::get(
-        context, llvm::makeArrayRef(expandedIndexSizes), indexType.getDtype());
+    if (getTensorRank(indexTensor) != 1)
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: index tensor with rank != 1 is not supported");
+    auto indexTensorType = indexTensor.getType().cast<BaseTensorType>();
+    int64_t indexTensorSize = indexTensorType.getSizes()[0];
+    SmallVector<int64_t> expandedIndexTensorSizes{indexTensorSize, 1};
+    ValueTensorType expandedIndexTensorType = ValueTensorType::get(
+        context, llvm::makeArrayRef(expandedIndexTensorSizes),
+        indexTensorType.getDtype());
     Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
     Value expandedIndexTensor = rewriter.create<AtenUnsqueezeOp>(
-        loc, expandedIndexType, indicesList[0], torchCstOne);
+        loc, expandedIndexTensorType, indexTensor, torchCstOne);

-    // Converting the index element type to i32.
+    // `TMTensor::ScatterOp` expects indices of element type i32.
     Value indices = convertTensorToDtype(
         rewriter, loc, expandedIndexTensor,
         mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
     indices = typeConverter->materializeTargetConversion(
         rewriter, loc, typeConverter->convertType(indices.getType()), indices);

-    auto scatterOp = rewriter.create<TMTensor::ScatterOp>(
-        loc, input.getType(), ValueRange{values, indices}, ValueRange{input},
-        /*unique_indices=*/false);
+    bool invalidInputTypeFound = false;
+    Value scatterOp = createTMTensorScatterOp(
+        rewriter, loc, values, indices, input, /*uniqueIndices=*/false,
+        [&](OpBuilder &b, Location loc, Value valuesElement,
+            Value inputElement) {
+          Value yieldValue = valuesElement;
+          if (accumulate) {
+            if (inputElement.getType().isa<mlir::IntegerType>()) {
+              yieldValue =
+                  b.create<arith::AddIOp>(loc, inputElement, valuesElement);
+            } else if (inputElement.getType().isa<mlir::FloatType>()) {
+              yieldValue =
+                  b.create<arith::AddFOp>(loc, inputElement, valuesElement);
+            } else {
+              invalidInputTypeFound = true;
+              return;
+            }
+          }
+          b.create<TMTensor::YieldOp>(loc, yieldValue);
+        });

-    Region &scatterOpRegion = scatterOp.region();
-    auto &scatterOpBlock = scatterOpRegion.emplaceBlock();
-    scatterOpBlock.addArguments(TypeRange{resultElemType, resultElemType},
-                                {loc, loc});
-    auto blockArgs = scatterOpBlock.getArguments();
-
-    OpBuilder regionBuilder(scatterOpRegion);
-    Value update = blockArgs[0];
-    Value original = blockArgs[1];
-    Value yieldValue = update;
-    // Create an add instruction inside the scatter op region to increment the
-    // `original` value with the value from `updates` if the accumulate flag is
-    // true.
-    if (accumulate) {
-      if (inputType.getElementType().isa<mlir::IntegerType>())
-        yieldValue = regionBuilder.create<arith::AddIOp>(loc, original, update);
-      else if (inputType.getElementType().isa<mlir::FloatType>())
-        yieldValue = regionBuilder.create<arith::AddFOp>(loc, original, update);
+    if (invalidInputTypeFound) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "unimplemented: input tensor must be of integer type or float type");
     }
-    regionBuilder.create<TMTensor::YieldOp>(loc, yieldValue);
-    rewriter.replaceOp(op, scatterOp->getResult(0));
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);

     return success();
   }
 };
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index f66bad3db..70a83745b 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -21,19 +21,6 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;

-// Helper funtion to get rank of `Base tensor type`.
-// -1 is returned if the tensorRank can't be determined.
-static int getTensorRank(Value tensor) {
-  int tensorRank = -1;
-  BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
-
-  if (tensorType.hasSizes()) {
-    ArrayRef<int64_t> tensorShape = tensorType.getSizes();
-    tensorRank = tensorShape.size();
-  }
-  return tensorRank;
-}
-
 // Helper function to compute the return type of the reduction function.
 // `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
 // the input tensor.
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index f02d60fa8..e818e895f 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -95,3 +95,14 @@ Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
       loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
   return converted;
 }
+
+int Torch::getTensorRank(Value tensor) {
+  int tensorRank = -1;
+  BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
+
+  if (tensorType.hasSizes()) {
+    ArrayRef<int64_t> tensorShape = tensorType.getSizes();
+    tensorRank = tensorShape.size();
+  }
+  return tensorRank;
+}
diff --git a/python/torch_mlir_e2e_test/test_suite/index_put.py b/python/torch_mlir_e2e_test/test_suite/index_put.py
index 0f2fc3698..6501e7f00 100644
--- a/python/torch_mlir_e2e_test/test_suite/index_put.py
+++ b/python/torch_mlir_e2e_test/test_suite/index_put.py
@@ -11,7 +11,7 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

 # ==============================================================================

-class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
+class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -29,13 +29,61 @@ class IndexPutImplOneDimFloatNonAccumulateModule(torch.nn.Module):
                                                unsafe=False)


-@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatNonAccumulateModule())
-def IndexPutImplOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPutImpl1DFloatNonAccumulateModule())
+def IndexPutImpl1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(100), torch.randint(100, (250,)),
                    tu.rand(250))


-class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
+class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input, (index,), value,
+                                               accumulate=False,
+                                               unsafe=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutImpl2DFloatNonAccumulateModule())
+def IndexPutImpl2DFloatNonAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8), torch.randint(4, (5,)),
+                   tu.rand(5, 8))
+
+
+class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input, (index,), value,
+                                               accumulate=False,
+                                               unsafe=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutImpl3DFloatNonAccumulateModule())
+def IndexPutImpl3DFloatNonAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), torch.randint(4, (5,)),
+                   tu.rand(5, 8, 6))
+
+
+class IndexPutImpl1DIntNonAccumulateModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -53,13 +101,13 @@ class IndexPutImplOneDimIntNonAccumulateModule(torch.nn.Module):
                                                unsafe=False)


-@register_test_case(module_factory=lambda: IndexPutImplOneDimIntNonAccumulateModule())
-def IndexPutImplOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPutImpl1DIntNonAccumulateModule())
+def IndexPutImpl1DIntNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
                    torch.randint(10000, (300,)))


-class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
+class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -79,13 +127,61 @@ class IndexPutImplOneDimFloatAccumulateModule(torch.nn.Module):
                                                unsafe=False)


-@register_test_case(module_factory=lambda: IndexPutImplOneDimFloatAccumulateModule())
-def IndexPutImplOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPutImpl1DFloatAccumulateModule())
+def IndexPutImpl1DFloatAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1000), torch.randint(10, (500,)),
                    tu.rand(500))


-class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
+class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
+                                               accumulate=True,
+                                               unsafe=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutImpl2DFloatAccumulateModule())
+def IndexPutImpl2DFloatAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8), torch.randint(4, (5,)),
+                   tu.rand(5, 8))
+
+
+class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten._index_put_impl_(input.clone(), (index,), value,
+                                               accumulate=True,
+                                               unsafe=False)
+
+
+@register_test_case(module_factory=lambda: IndexPutImpl3DFloatAccumulateModule())
+def IndexPutImpl3DFloatAccumulateModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), torch.randint(4, (5,)),
+                   tu.rand(5, 8, 6))
+
+
+class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -105,14 +201,14 @@ class IndexPutImplOneDimIntAccumulateModule(torch.nn.Module):
                                                unsafe=False)


-@register_test_case(module_factory=lambda: IndexPutImplOneDimIntAccumulateModule())
-def IndexPutImplOneDimIntAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPutImpl1DIntAccumulateModule())
+def IndexPutImpl1DIntAccumulateModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
                    torch.randint(1000, (10,)))

 # ==============================================================================

-class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
+class IndexPut1DFloatNonAccumulateModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -128,13 +224,13 @@ class IndexPutOneDimFloatNonAccumulateModule(torch.nn.Module):
         return torch.ops.aten.index_put(input, (index,), value, accumulate=False)


-@register_test_case(module_factory=lambda: IndexPutOneDimFloatNonAccumulateModule())
-def IndexPutOneDimFloatNonAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPut1DFloatNonAccumulateModule())
+def IndexPut1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(100), torch.randint(100, (250,)),
                    tu.rand(250))


-class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
+class IndexPut1DIntNonAccumulateModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -150,13 +246,13 @@ class IndexPutOneDimIntNonAccumulateModule(torch.nn.Module):
         return torch.ops.aten.index_put(input, (index,), value, accumulate=False)


-@register_test_case(module_factory=lambda: IndexPutOneDimIntNonAccumulateModule())
-def IndexPutOneDimIntNonAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPut1DIntNonAccumulateModule())
+def IndexPut1DIntNonAccumulateModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1000, (200,)), torch.randint(100, (300,)),
                    torch.randint(10000, (300,)))


-class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
+class IndexPut1DFloatAccumulateModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -172,13 +268,13 @@ class IndexPutOneDimFloatAccumulateModule(torch.nn.Module):
         return torch.ops.aten.index_put(input, (index,), value, accumulate=True)


-@register_test_case(module_factory=lambda: IndexPutOneDimFloatAccumulateModule())
-def IndexPutOneDimFloatAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPut1DFloatAccumulateModule())
+def IndexPut1DFloatAccumulateModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1000), torch.randint(10, (500,)),
                    tu.rand(500))


-class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
+class IndexPut1DIntAccumulateModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -194,7 +290,7 @@ class IndexPutOneDimIntAccumulateModule(torch.nn.Module):
         return torch.ops.aten.index_put(input, (index,), value, accumulate=True)


-@register_test_case(module_factory=lambda: IndexPutOneDimIntAccumulateModule())
-def IndexPutOneDimIntAccumulateModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: IndexPut1DIntAccumulateModule())
+def IndexPut1DIntAccumulateModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(100, (10,)), torch.randint(10, (10,)),
                    torch.randint(1000, (10,)))
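
Note (illustration only, not part of the patch): the lowering path touched above handles `aten._index_put_impl_` with a single rank-1 index tensor, which scatters whole slices of `values` into `input` along dim 0; the `accumulate` flag selects between overwriting and adding, matching the `arith.addi`/`arith.addf` bodies built for `tm_tensor.scatter`. Below is a minimal PyTorch sketch of that behaviour; the variable names and shapes are hypothetical examples chosen to mirror the `IndexPutImpl2DFloat*Module` tests.

```python
import torch

# Hypothetical example data: rank-2 input, rank-1 index, slice-shaped values.
inp = torch.zeros(4, 3)
index = torch.tensor([1, 3])
value = torch.ones(2, 3)

# Same call shape as in the e2e tests above.
overwritten = torch.ops.aten._index_put_impl_(inp.clone(), (index,), value,
                                              accumulate=False, unsafe=False)
accumulated = torch.ops.aten._index_put_impl_(inp.clone(), (index,), value,
                                              accumulate=True, unsafe=False)

# Reference via plain indexing: value[i] lands in row index[i] of the input,
# which is the mapping the scatter op receives after the index tensor is
# unsqueezed to 2-D and converted to i32.
expected = torch.zeros(4, 3)
expected[index] = value
assert torch.equal(overwritten, expected)
assert torch.equal(accumulated, expected)  # identical here since inp is zero
```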