mirror of https://github.com/llvm/torch-mlir
[Torch] Add decompose for 1d torch.nonzero
parent
06d17897f0
commit
35e20e04b8
|
@ -5523,6 +5523,192 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenNonzeroOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto si64Type = rewriter.getIntegerType(64, true);
|
||||||
|
Value si64Dtype = getDtypeIntValueForType(rewriter, loc, si64Type);
|
||||||
|
// helper for making int constants
|
||||||
|
std::function<Value(int64_t)> c = [&](int64_t val) {
|
||||||
|
Value newIntConstant =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(val));
|
||||||
|
return newIntConstant;
|
||||||
|
};
|
||||||
|
std::function<Value(Value)> makeOneElementList = [&](Value element) {
|
||||||
|
auto listType = Torch::ListType::get(element.getType());
|
||||||
|
return rewriter.create<PrimListConstructOp>(loc, listType,
|
||||||
|
ArrayRef<Value>{element});
|
||||||
|
};
|
||||||
|
|
||||||
|
Value input = op.getSelf();
|
||||||
|
auto inputType = dyn_cast<BaseTensorType>(input.getType());
|
||||||
|
int64_t inputRank = inputType.getSizes().size();
|
||||||
|
|
||||||
|
// original_shape = t.shape
|
||||||
|
auto shapeType = Torch::ValueTensorType::get(
|
||||||
|
rewriter.getContext(), SmallVector<int64_t>{inputRank}, si64Type);
|
||||||
|
Value inputShapeTensor =
|
||||||
|
rewriter.create<Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);
|
||||||
|
|
||||||
|
// t = flatten(t)
|
||||||
|
int64_t flattenedSize = 1;
|
||||||
|
if (inputType.hasSizes()) {
|
||||||
|
for (auto size : inputType.getSizes()) {
|
||||||
|
flattenedSize *= size;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
flattenedSize = kUnknownSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto flattendInputShape = SmallVector<int64_t>{flattenedSize};
|
||||||
|
auto flattenedInputType = rewriter.getType<Torch::ValueTensorType>(
|
||||||
|
flattendInputShape, inputType.getOptionalDtype());
|
||||||
|
|
||||||
|
Value inputDimsStart =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value inputDimsEnd = rewriter.create<ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(inputRank - 1));
|
||||||
|
|
||||||
|
Value flattenedInput = rewriter.create<AtenFlattenUsingIntsOp>(
|
||||||
|
loc, flattenedInputType, input, inputDimsStart, inputDimsEnd);
|
||||||
|
|
||||||
|
// nonzero_mask = (t != 0)
|
||||||
|
auto boolMaskType = inputType.getWithSizesAndDtype(
|
||||||
|
flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
|
||||||
|
Value boolMask = rewriter.create<AtenNeScalarOp>(loc, boolMaskType,
|
||||||
|
flattenedInput, c(0));
|
||||||
|
|
||||||
|
// nonzero_mask = nonzero_mask.int()
|
||||||
|
Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
|
||||||
|
Value noneCst = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
auto intMaskType = flattenedInputType.getWithSizesAndDtype(
|
||||||
|
flattenedInputType.getOptionalSizes(), si64Type); // ####
|
||||||
|
Value intMask = rewriter.create<AtenToDtypeOp>(
|
||||||
|
loc, intMaskType, boolMask, si64Dtype, falseCst, falseCst, noneCst);
|
||||||
|
|
||||||
|
// destination_indices = torch.cumsum(nonzero_mask, 0) - 1
|
||||||
|
auto cumulativeSumType =
|
||||||
|
dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
|
||||||
|
flattenedInputType.getOptionalSizes(), si64Type));
|
||||||
|
Value cumulativeSum = rewriter.create<AtenCumsumOp>(loc, cumulativeSumType,
|
||||||
|
intMask, c(0), noneCst);
|
||||||
|
Value one =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
Value subtracted = rewriter.create<AtenSubScalarOp>(
|
||||||
|
loc, cumulativeSumType, cumulativeSum, one, /*alpha=*/one);
|
||||||
|
|
||||||
|
// destination_indices = torch.clamp(destination_indices, min=0)
|
||||||
|
Value indices = rewriter.create<AtenClampMinOp>(loc, cumulativeSumType,
|
||||||
|
subtracted, c(0));
|
||||||
|
|
||||||
|
// iota = torch.tensor(range(len(t))) * nonzero_mask.int()
|
||||||
|
Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
|
||||||
|
loc, cumulativeSumType, c(0),
|
||||||
|
rewriter.create<ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])),
|
||||||
|
one, noneCst, noneCst, noneCst, noneCst);
|
||||||
|
Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
|
||||||
|
rangeTensor, intMask);
|
||||||
|
|
||||||
|
// scatter_self = torch.zeros_like(t, dtype=torch.int64)
|
||||||
|
// AtenFullLike doesn't support index type so we have to use si64
|
||||||
|
auto zerosTensorType = cumulativeSumType.getWithSizesAndDtype(
|
||||||
|
cumulativeSumType.getOptionalSizes(), si64Type);
|
||||||
|
Value zerosTensor = rewriter.create<AtenZerosLikeOp>(
|
||||||
|
loc, zerosTensorType, cumulativeSum, si64Dtype, noneCst, noneCst,
|
||||||
|
noneCst, noneCst);
|
||||||
|
|
||||||
|
// compacted = scatter_self.scatter_(
|
||||||
|
// dim=0,
|
||||||
|
// index=destination_indices,
|
||||||
|
// src=iota, reduce='add')
|
||||||
|
Value reduceStr = rewriter.create<ConstantStrOp>(loc, "sum");
|
||||||
|
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||||
|
|
||||||
|
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
||||||
|
Value scatteredTensor = rewriter.create<AtenScatterReduceTwoOp>(
|
||||||
|
loc, cumulativeSumType, zerosTensor, /*axis=*/constAxis,
|
||||||
|
/*dims=*/indices, /*src=*/multiplied, reduceStr, cstFalse);
|
||||||
|
|
||||||
|
// result_flat = compacted[:torch.sum(nonzero_mask)]
|
||||||
|
auto scalarType = ValueTensorType::get(rewriter.getContext(),
|
||||||
|
ArrayRef<int64_t>{}, si64Type);
|
||||||
|
Value sumMask =
|
||||||
|
rewriter.create<AtenSumOp>(loc, scalarType, intMask, noneCst);
|
||||||
|
Value numNonzero = rewriter.create<AtenIntTensorOp>(loc, sumMask);
|
||||||
|
|
||||||
|
auto slicedResultType = Torch::ValueTensorType::get(
|
||||||
|
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
|
||||||
|
Value slicedResult =
|
||||||
|
rewriter.create<AtenSliceTensorOp>(loc, slicedResultType,
|
||||||
|
/*self=*/scatteredTensor,
|
||||||
|
/*dim=*/c(0),
|
||||||
|
/*start=*/c(0),
|
||||||
|
/*end=*/numNonzero,
|
||||||
|
/*step=*/one);
|
||||||
|
|
||||||
|
// strides = torch.cumprod(torch.flip(inputShapeTensor, [0]), 0).flip(0)
|
||||||
|
Value flippedShape = rewriter.create<AtenFlipOp>(
|
||||||
|
loc, shapeType, inputShapeTensor, makeOneElementList(c(0)));
|
||||||
|
Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
|
||||||
|
loc, shapeType, flippedShape, c(0), noneCst);
|
||||||
|
Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
|
||||||
|
loc, shapeType, cumulativeProduct, makeOneElementList(c(0)));
|
||||||
|
// strides = torch.cat([strides[1:], torch.tensor([1],
|
||||||
|
// device=t.device)])
|
||||||
|
auto oneTensorType = ValueTensorType::get(
|
||||||
|
rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
|
||||||
|
Value oneTensor = rewriter.create<AtenScalarTensorOp>(
|
||||||
|
loc, oneTensorType, c(1), si64Dtype, noneCst, noneCst, noneCst);
|
||||||
|
|
||||||
|
auto slicedStrideType = Torch::ValueTensorType::get(
|
||||||
|
rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
|
||||||
|
si64Type);
|
||||||
|
Value strideSliceStart = c(1);
|
||||||
|
Value strideSliceEnd = c(inputRank);
|
||||||
|
Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
|
||||||
|
loc, slicedStrideType, flippedCumulativeProduct, /*dim*/ c(0),
|
||||||
|
/*start=*/strideSliceStart, /*end=*/strideSliceEnd, /*step=*/c(1));
|
||||||
|
|
||||||
|
auto tensorListElementType = Torch::ValueTensorType::get(
|
||||||
|
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
|
||||||
|
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(tensorListElementType),
|
||||||
|
SmallVector<Value>{slicedStrides, oneTensor});
|
||||||
|
Value strides =
|
||||||
|
rewriter.create<Torch::AtenCatOp>(loc, shapeType, tensorList, c(0));
|
||||||
|
|
||||||
|
// multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
|
||||||
|
// inputShapeTensor
|
||||||
|
auto unsqueezedResultType = ValueTensorType::get(
|
||||||
|
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, si64Type);
|
||||||
|
Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
|
||||||
|
loc, unsqueezedResultType, slicedResult, c(1));
|
||||||
|
|
||||||
|
auto unsqueezedStridesType = ValueTensorType::get(
|
||||||
|
rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, si64Type);
|
||||||
|
Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
|
||||||
|
loc, unsqueezedStridesType, strides, c(0));
|
||||||
|
|
||||||
|
auto dividedBroadcastType = ValueTensorType::get(
|
||||||
|
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
|
||||||
|
si64Type);
|
||||||
|
Value divided = rewriter.create<AtenFloorDivideOp>(
|
||||||
|
loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides);
|
||||||
|
|
||||||
|
auto resultType = cast<BaseTensorType>(op.getType());
|
||||||
|
Value modded = rewriter.create<AtenRemainderTensorOp>(
|
||||||
|
loc, resultType, divided, inputShapeTensor);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, modded);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
|
// Decompose aten.addmm into aten.mm and aten.add.Tensor op.
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
|
class DecomposeAtenAddmmOp : public OpRewritePattern<AtenAddmmOp> {
|
||||||
|
@ -10573,6 +10759,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
|
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
|
||||||
patterns);
|
patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTanhBackwardOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenNonzeroOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
|
||||||
|
|
|
@ -6255,3 +6255,26 @@ def AtenPolarDoubleModule_basic(module, tu: TestUtils):
|
||||||
module.forward(
|
module.forward(
|
||||||
tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64)
|
tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AtenNonzero1DModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1], torch.bool, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.nonzero(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenNonzero1DModule())
|
||||||
|
def AtenNonzero1DModule_one_nonzero(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor([0, 0, 5, 0, 0, 0], dtype=torch.int))
|
||||||
|
|
Loading…
Reference in New Issue