diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index c9541a8a1..d4aaac608 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5189,7 +5189,6 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
         loc, flattenedInputType, input, inputDimsStart, inputDimsEnd);
 
     // nonzero_mask = (t != 0)
-    Value zero = c(0);
     auto boolMaskType = inputType.getWithSizesAndDtype(
         flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
     Value boolMask = rewriter.create<AtenNeScalarOp>(loc, boolMaskType,
@@ -5208,7 +5207,7 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
         dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
             flattenedInputType.getOptionalSizes(), si64Type));
     Value cumulativeSum = rewriter.create<AtenCumsumOp>(loc, cumulativeSumType,
-                                                        intMask, zero, noneCst);
+                                                        intMask, c(0), noneCst);
     Value one = rewriter.create<ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
     Value subtracted = rewriter.create<AtenSubScalarOp>(
@@ -5216,11 +5215,11 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
 
     // destination_indices = torch.clamp(destination_indices, min=0)
     Value indices = rewriter.create<AtenClampOp>(loc, cumulativeSumType,
-                                                 subtracted, zero, noneCst);
+                                                 subtracted, c(0), noneCst);
 
     // iota = torch.tensor(range(len(t))) * nonzero_mask.int()
     Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
-        loc, cumulativeSumType, zero,
+        loc, cumulativeSumType, c(0),
         rewriter.create<ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])),
         one, noneCst, noneCst, noneCst, noneCst);
@@ -5239,10 +5238,15 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
    //     dim=0,
    //     index=destination_indices,
    //     src=iota, reduce='add')
-    Value reduceStr = rewriter.create<ConstantStrOp>(loc, "add");
-    Value scatteredTensor = rewriter.create<AtenScatterReduceOp>(
-        loc, cumulativeSumType, zerosTensor, zero, indices, multiplied,
-        reduceStr);
+    Value reduceStr = rewriter.create<ConstantStrOp>(loc, "sum");
+    Value constAxis = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getType<Torch::IntType>(),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
+    Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+    Value scatteredTensor = rewriter.create<AtenScatterReduceTwoOp>(
+        loc, cumulativeSumType, zerosTensor, /*dim=*/constAxis,
+        /*index=*/indices, /*src=*/multiplied, reduceStr, cstFalse);
 
     // result_flat = compacted[:torch.sum(nonzero_mask)]
     auto scalarType = ValueTensorType::get(rewriter.getContext(),
@@ -5254,13 +5258,13 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     auto slicedResultType = Torch::ValueTensorType::get(
         rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
     Value slicedResult = rewriter.create<AtenSliceTensorOp>(
-        loc, slicedResultType, scatteredTensor, zero, zero, numNonzero, one);
+        loc, slicedResultType, scatteredTensor, c(0), c(0), numNonzero, one);
 
     // strides = torch.cumprod(torch.flip(inputShapeTensor, [0]), 0).flip(0)
     Value flippedShape = rewriter.create<AtenFlipOp>(
         loc, shapeType, inputShapeTensor, makeOneElementList(c(0)));
     Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
-        loc, shapeType, flippedShape, zero, noneCst);
+        loc, shapeType, flippedShape, c(0), noneCst);
     Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
         loc, shapeType, cumulativeProduct, makeOneElementList(c(0)));
     // strides = torch.cat([strides[1:], torch.tensor([1],
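
Editor's note, not part of the patch: the decomposition's comments describe the compaction in PyTorch pseudocode, so here is a minimal runnable PyTorch sketch of the 1-D case using scatter_reduce with reduce="sum" and include_self=False, the semantics the new AtenScatterReduceTwoOp emission assumes. The function name nonzero_flat and local variable names are illustrative, not taken from the patch.

import torch

def nonzero_flat(t: torch.Tensor) -> torch.Tensor:
    # nonzero_mask = (t != 0)
    mask = t != 0
    # destination_indices = torch.cumsum(nonzero_mask, 0) - 1, clamped at 0
    dest = torch.clamp(torch.cumsum(mask.long(), 0) - 1, min=0)
    # iota = torch.tensor(range(len(t))) * nonzero_mask.int()
    iota = torch.arange(len(t)) * mask.long()
    # scatter_reduce.two with include_self=False: slots never written keep the
    # zero initializer; each written slot holds the sum of its scattered values
    # (the masked iota zeros out the non-selected positions, so the sum is
    # exactly the index of the surviving nonzero element).
    compacted = torch.zeros_like(t, dtype=torch.long).scatter_reduce(
        0, dest, iota, reduce="sum", include_self=False)
    # result_flat = compacted[:torch.sum(nonzero_mask)]
    return compacted[: int(mask.sum())]

print(nonzero_flat(torch.tensor([0, 3, 0, 5, 7])))  # tensor([1, 3, 4])

With include_self=False the self values are excluded from the reduction, so the compacted prefix does not depend on the destination tensor's initial contents, unlike the old aten.scatter.reduce path with reduce='add', which folded the zeros tensor into the sum.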