diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 33a1c9f91..1aaf546c2 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -72,6 +72,9 @@ bool isBuiltInType(Type type);
 // std::nullopt is returned if the tensorRank can't be determined.
 std::optional<unsigned> getTensorRank(Value tensor);
 
+// Helper function to get the number of elements in a tensor.
+std::optional<int64_t> getTensorNumel(Value tensor);
+
 bool isViewLikeOp(Operation *op);
 
 Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 54b852dcf..5ec22233b 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3371,6 +3371,104 @@ public:
 };
 } // namespace
 
+// Decompose aten.masked_scatter:
+// def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor:
+//     mask_int = mask + torch.zeros_like(self)
+//     prefix_sum = torch.cumsum(mask_int.flatten(), dim=0)
+//     mask_prefix = torch.clamp(prefix_sum - 1, min=0)
+//     mask = mask.to(torch.bool)
+//     source = source.flatten()[mask_prefix].reshape(mask.shape)
+//     return torch.where(mask, source, self)
+namespace {
+class DecomposeAtenMaskedScatterOp
+    : public OpRewritePattern<AtenMaskedScatterOp> {
+public:
+  using OpRewritePattern<AtenMaskedScatterOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMaskedScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto context = op.getContext();
+    Value mask = op.getMask();
+    Value source = op.getSource();
+    Value self = op.getSelf();
+
+    auto selfTy = cast<BaseTensorType>(self.getType());
+    auto resTy = cast<BaseTensorType>(op.getType());
+    auto sourceTy = cast<BaseTensorType>(source.getType());
+
+    if (!resTy || !resTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result should have dtype");
+    }
+    if (!selfTy || !selfTy.areAllSizesKnown())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: no implementation for rankless tensor");
+    if (!sourceTy || !sourceTy.areAllSizesKnown() || !sourceTy.hasDtype())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: no implementation for rankless tensor");
+
+    int64_t selfNumel = getTensorNumel(self).value(); // as selfTy has sizes
+    int64_t sourceNumel =
+        getTensorNumel(source).value(); // as sourceTy has sizes
+    int64_t selfRank = selfTy.getSizes().size();
+    int64_t sourceRank = sourceTy.getSizes().size();
+
+    Value constZero = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(0));
+    Value constOne = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    Value constNone = rewriter.create<ConstantNoneOp>(loc);
+    Value selfLastDim = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(selfRank - 1));
+    Value sourceLastDim = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(sourceRank - 1));
+
+    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
+    auto int64Dtype = getDtypeIntValueForType(
+        rewriter, loc,
+        rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
+    auto selfIntType = selfTy.getWithSizesAndDtype(selfTy.getSizes(), si64Type);
+
+    Value zerosLike = rewriter.create<AtenZerosLikeOp>(
+        loc, selfIntType, self, int64Dtype, constNone, constNone, constNone,
+        constNone);
+    Value maskInt = rewriter.create<AtenAddTensorOp>(
+        loc, selfIntType, mask, zerosLike, constOne);
+
+    auto flattenMaskedType = selfTy.getWithSizesAndDtype(
+        /*optionalSizes=*/{selfNumel}, si64Type);
+    Value maskIntFlatten = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc, flattenMaskedType, maskInt, constZero, selfLastDim);
+    Value prefixSum = rewriter.create<AtenCumsumOp>(
+        loc, flattenMaskedType, maskIntFlatten,
+        /*dim=*/constZero, constNone);
+    Value prefixSumMinusOne = rewriter.create<AtenSubScalarOp>(
+        loc, flattenMaskedType, prefixSum, constOne, constOne);
+    Value maskPrefix = rewriter.create<AtenClampOp>(
+        loc, flattenMaskedType, prefixSumMinusOne, /*min=*/constZero,
+        /*max=*/constNone);
+
+    auto sourceFlattenType = sourceTy.getWithSizesAndDtype(
+        /*optionalSizes=*/{sourceNumel}, sourceTy.getDtype());
+    Value sourceFlatten = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc, sourceFlattenType, source, constZero, sourceLastDim);
+
+    auto selectSourceType = sourceTy.getWithSizesAndDtype(
+        /*optionalSizes=*/{selfNumel}, sourceTy.getDtype());
+    Value selectSource = rewriter.create<AtenIndexSelectOp>(
+        loc, selectSourceType, sourceFlatten, constZero, maskPrefix);
+
+    // Reshape the gathered source values back to the original input shape.
+    auto selfShape = rewriter.create<AtenSizeOp>(
+        loc, Torch::ListType::get(IntType::get(context)), self);
+    Value sourceReshape = rewriter.create<AtenViewOp>(
+        loc, selfTy, selectSource, selfShape);
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resTy, mask,
+                                                 sourceReshape, self);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten._convolution-like to aten.convolution
 namespace {
 template <typename ConvolutionLikeOp>
@@ -7839,6 +7937,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedScatterOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index bda2d258a..0ca7ea9c4 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -390,6 +390,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
     target.addIllegalOp();
     target.addIllegalOp();
     target.addIllegalOp();
+    target.addIllegalOp<AtenMaskedScatterOp>();
     target.addIllegalOp();
     target.addIllegalOp();
     target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index ed035b303..8101a2a5b 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -209,6 +209,19 @@ std::optional<unsigned> Torch::getTensorRank(Value tensor) {
   return tensorType.getSizes().size();
 }
 
+std::optional<int64_t> Torch::getTensorNumel(Value tensor) {
+  BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
+  if (!tensorType.hasSizes())
+    return std::nullopt;
+  int64_t numel = 1;
+  for (auto dim : tensorType.getSizes()) {
+    if (dim == ShapedType::kDynamic)
+      return ShapedType::kDynamic;
+    numel *= dim;
+  }
+  return numel;
+}
+
 bool Torch::isViewLikeOp(Operation *op) {
   // AtenContiguousOp might return a view, so this is conservatively
   // correct. We could potentially be more precise and identify the cases
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index c6af8cbf4..fb6a09fa0 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1072,6 +1072,7 @@ STABLEHLO_PASS_SET = {
     "LinspaceTwoSizeModule_basic",
     "MaskedFillScalarFloatValueStaticModule_basic",
     "MaskedFillScalarIntValueStaticModule_basic",
+    "MaskedScatterStaticBasic_basic",
     "Matmul4dStatic_basic",
     "Matmul_2d",
     "Matmul_dot",
@@ -2366,6 +2367,7 @@ ONNX_XFAIL_SET = {
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LogSoftmaxBackwardModule_basic",
+    "MaskedScatterStaticBasic_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dEmptyStrideStaticModule_basic",
     "MaxPool1dModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
index cc4970573..8f7ea3291 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
@@ -12,6 +12,31 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
 # ==============================================================================
 
 
+class MaskedScatterStaticBasic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4, 4], torch.float32, True),
+            ([4, 4], torch.bool, True),
+            ([8, 8], torch.float32, True),
+        ]
+    )
+    def forward(self, x, mask, y):
+        return torch.masked_scatter(x, mask, y)
+
+
+@register_test_case(module_factory=lambda: MaskedScatterStaticBasic())
+def MaskedScatterStaticBasic_basic(module, tu: TestUtils):
+    x = torch.rand(4, 4)
+    mask = torch.rand(4, 4) > 0.5
+    y = torch.rand(8, 8)
+    module.forward(x, mask, y)
+
+
 class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
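For reference, the decomposition recipe quoted in the comment at the top of DecomposeAtenMaskedScatterOp can be checked end to end in plain PyTorch. The sketch below is illustrative only and is not part of the patch: the helper name masked_scatter_decomposed is made up here, the running index is built in int64 (mirroring the int64Dtype the C++ pattern passes to zeros_like), and the shapes match the MaskedScatterStaticBasic test case.

```python
import torch


def masked_scatter_decomposed(self, mask, source):
    # Broadcast the mask to self's shape as int64 so the cumsum below
    # produces integer indices (the pattern uses int64Dtype for zeros_like).
    mask_int = mask + torch.zeros_like(self, dtype=torch.int64)
    # Entry i holds the number of True mask positions in [0, i].
    prefix_sum = torch.cumsum(mask_int.flatten(), dim=0)
    # At a True position, prefix_sum - 1 is the index of the source element
    # that lands there; clamp keeps leading False positions at a valid index.
    mask_prefix = torch.clamp(prefix_sum - 1, min=0)
    # Gather from the flattened source and reshape back to self's shape.
    gathered = source.flatten()[mask_prefix].reshape(self.shape)
    # Keep gathered values where the mask is set, original values elsewhere.
    return torch.where(mask.to(torch.bool), gathered, self)


x = torch.rand(4, 4)
mask = torch.rand(4, 4) > 0.5
y = torch.rand(8, 8)
assert torch.equal(masked_scatter_decomposed(x, mask, y),
                   torch.masked_scatter(x, mask, y))
```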