mirror of https://github.com/llvm/torch-mlir
use ScatterReduceTwo instead because ScatterReduce fails for unknown reasons
parent 1ba1506266
commit 8a999e17ed
@@ -5189,7 +5189,6 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
         loc, flattenedInputType, input, inputDimsStart, inputDimsEnd);
 
     // nonzero_mask = (t != 0)
-    Value zero = c(0);
     auto boolMaskType = inputType.getWithSizesAndDtype(
         flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
     Value boolMask = rewriter.create<AtenNeScalarOp>(loc, boolMaskType,
@@ -5208,7 +5207,7 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
         dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
             flattenedInputType.getOptionalSizes(), si64Type));
     Value cumulativeSum = rewriter.create<AtenCumsumOp>(loc, cumulativeSumType,
-                                                        intMask, zero, noneCst);
+                                                        intMask, c(0), noneCst);
     Value one =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
     Value subtracted = rewriter.create<AtenSubScalarOp>(
@@ -5216,11 +5215,11 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
 
     // destination_indices = torch.clamp(destination_indices, min=0)
     Value indices = rewriter.create<AtenClampOp>(loc, cumulativeSumType,
-                                                 subtracted, zero, noneCst);
+                                                 subtracted, c(0), noneCst);
 
     // iota = torch.tensor(range(len(t))) * nonzero_mask.int()
     Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
-        loc, cumulativeSumType, zero,
+        loc, cumulativeSumType, c(0),
         rewriter.create<ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])),
         one, noneCst, noneCst, noneCst, noneCst);
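
[Note] To make the surrounding decomposition concrete, here is a minimal PyTorch sketch of the mask / cumsum / clamp / arange steps the hunks above build (toy values; the variable names are illustrative, not taken from the pass):

import torch

t = torch.tensor([0, 3, 0, 5])                         # toy flattened input
mask = (t != 0).to(torch.int64)                        # nonzero_mask -> [0, 1, 0, 1]
dest = torch.clamp(torch.cumsum(mask, 0) - 1, min=0)   # destination_indices -> [0, 0, 0, 1]
iota = torch.arange(t.numel()) * mask                  # masked iota -> [0, 1, 0, 3]

Each nonzero element gets the front-packed slot it should move to; masked-out elements collapse onto a neighbor's slot but carry a zero payload, so the scatter-add below leaves the result unchanged.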
@@ -5239,10 +5238,15 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     //                                  dim=0,
     //                                  index=destination_indices,
     //                                  src=iota, reduce='add')
-    Value reduceStr = rewriter.create<ConstantStrOp>(loc, "add");
-    Value scatteredTensor = rewriter.create<AtenScatterReduceOp>(
-        loc, cumulativeSumType, zerosTensor, zero, indices, multiplied,
-        reduceStr);
+    Value reduceStr = rewriter.create<ConstantStrOp>(loc, "sum");
+    Value constAxis = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getType<Torch::IntType>(),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
+    Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+    Value scatteredTensor = rewriter.create<AtenScatterReduceTwoOp>(
+        loc, cumulativeSumType, zerosTensor, /*axis=*/constAxis,
+        /*dims=*/indices, /*src=*/multiplied, reduceStr, cstFalse);
 
     // result_flat = compacted[:torch.sum(nonzero_mask)]
     auto scalarType = ValueTensorType::get(rewriter.getContext(),
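
[Note] This hunk is the substance of the commit: the scatter switches from AtenScatterReduceOp with the reduce string "add" to AtenScatterReduceTwoOp, which takes an explicit integer axis and an include_self flag and uses the reduce string "sum". A self-contained sketch of the PyTorch-level equivalent of the new call, reusing the toy values from the note above (names illustrative):

import torch

dest = torch.tensor([0, 0, 0, 1])   # destination_indices from the sketch above
iota = torch.tensor([0, 1, 0, 3])   # masked iota from the sketch above
# scatter_reduce: sum sources that collide on an index; include_self=False
# ignores the destination's initial value at positions that receive a source.
compacted = torch.zeros_like(iota).scatter_reduce(
    0, dest, iota, reduce="sum", include_self=False)
# compacted -> [1, 3, 0, 0]: the flat indices of the nonzeros, front-packed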
@@ -5254,13 +5258,13 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
     auto slicedResultType = Torch::ValueTensorType::get(
         rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
     Value slicedResult = rewriter.create<AtenSliceTensorOp>(
-        loc, slicedResultType, scatteredTensor, zero, zero, numNonzero, one);
+        loc, slicedResultType, scatteredTensor, c(0), c(0), numNonzero, one);
 
     // strides = torch.cumprod(torch.flip(inputShapeTensor, [0]), 0).flip(0)
     Value flippedShape = rewriter.create<AtenFlipOp>(
         loc, shapeType, inputShapeTensor, makeOneElementList(c(0)));
     Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
-        loc, shapeType, flippedShape, zero, noneCst);
+        loc, shapeType, flippedShape, c(0), noneCst);
     Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
         loc, shapeType, cumulativeProduct, makeOneElementList(c(0)));
     // strides = torch.cat([strides[1:], torch.tensor([1],
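
[Note] End to end, the pattern decomposes aten.nonzero roughly as follows. This is a hedged, self-contained Python sketch assembled from the comments in the hunks above; the strides/unravel tail continues past the last hunk shown, and the helper name is mine, not the pass's:

import torch

def nonzero_decomposed(t: torch.Tensor) -> torch.Tensor:
    flat = t.flatten()
    mask = (flat != 0).to(torch.int64)                    # nonzero_mask = (t != 0)
    dest = torch.clamp(torch.cumsum(mask, 0) - 1, min=0)  # destination_indices
    iota = torch.arange(flat.numel()) * mask              # masked iota
    compacted = torch.zeros_like(iota).scatter_reduce(    # scatter_reduce, 'sum'
        0, dest, iota, reduce="sum", include_self=False)
    flat_idx = compacted[: int(mask.sum())]               # result_flat
    # strides = cat(cumprod(flip(shape)).flip()[1:], [1]): row-major strides
    shape = torch.tensor(t.shape)
    flipped_cumprod = torch.flip(torch.cumprod(torch.flip(shape, [0]), 0), [0])
    strides = torch.cat([flipped_cumprod[1:], torch.tensor([1])])
    # unravel each flat index: coordinate along dim d is (i // strides[d]) % shape[d]
    return (flat_idx.unsqueeze(1) // strides) % shape

# agrees with torch.nonzero on a toy input
x = torch.tensor([[0, 7], [5, 0]])
assert torch.equal(nonzero_decomposed(x), torch.nonzero(x))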