Use AtenScatterReduceTwoOp instead of AtenScatterReduceOp, because AtenScatterReduceOp fails for unknown reasons

Xida Ren 2024-09-25 16:14:01 +00:00
parent 1ba1506266
commit 8a999e17ed
1 changed file with 14 additions and 10 deletions

View File

@ -5189,7 +5189,6 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
loc, flattenedInputType, input, inputDimsStart, inputDimsEnd);
// nonzero_mask = (t != 0)
Value zero = c(0);
auto boolMaskType = inputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
Value boolMask = rewriter.create<AtenNeScalarOp>(loc, boolMaskType,
@ -5208,7 +5207,7 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), si64Type));
Value cumulativeSum = rewriter.create<AtenCumsumOp>(loc, cumulativeSumType,
intMask, zero, noneCst);
intMask, c(0), noneCst);
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value subtracted = rewriter.create<AtenSubScalarOp>(
@ -5216,11 +5215,11 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
// destination_indices = torch.clamp(destination_indices, min=0)
Value indices = rewriter.create<AtenClampOp>(loc, cumulativeSumType,
subtracted, zero, noneCst);
subtracted, c(0), noneCst);
// iota = torch.tensor(range(len(t))) * nonzero_mask.int()
Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
loc, cumulativeSumType, zero,
loc, cumulativeSumType, c(0),
rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])),
one, noneCst, noneCst, noneCst, noneCst);
@ -5239,10 +5238,15 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
// dim=0,
// index=destination_indices,
// src=iota, reduce='add')
Value reduceStr = rewriter.create<ConstantStrOp>(loc, "add");
Value scatteredTensor = rewriter.create<AtenScatterReduceOp>(
loc, cumulativeSumType, zerosTensor, zero, indices, multiplied,
reduceStr);
Value reduceStr = rewriter.create<ConstantStrOp>(loc, "sum");
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
Value scatteredTensor = rewriter.create<AtenScatterReduceTwoOp>(
loc, cumulativeSumType, zerosTensor, /*axis=*/constAxis,
/*dims=*/indices, /*src=*/multiplied, reduceStr, cstFalse);
// result_flat = compacted[:torch.sum(nonzero_mask)]
auto scalarType = ValueTensorType::get(rewriter.getContext(),
@ -5254,13 +5258,13 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
auto slicedResultType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
Value slicedResult = rewriter.create<AtenSliceTensorOp>(
loc, slicedResultType, scatteredTensor, zero, zero, numNonzero, one);
loc, slicedResultType, scatteredTensor, c(0), c(0), numNonzero, one);
// strides = torch.cumprod(torch.flip(inputShapeTensor, [0]), 0).flip(0)
Value flippedShape = rewriter.create<AtenFlipOp>(
loc, shapeType, inputShapeTensor, makeOneElementList(c(0)));
Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
loc, shapeType, flippedShape, zero, noneCst);
loc, shapeType, flippedShape, c(0), noneCst);
Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
loc, shapeType, cumulativeProduct, makeOneElementList(c(0)));
// strides = torch.cat([strides[1:], torch.tensor([1],