Fix AtenArangeStartStepOp dynamic end support

pull/3876/head
AmosLewis 2024-11-19 19:41:29 -08:00
parent 35e20e04b8
commit f7cd0fc615
1 changed files with 6 additions and 4 deletions

View File

@ -5604,11 +5604,13 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
subtracted, c(0));
// iota = torch.tensor(range(len(t))) * nonzero_mask.int()
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value end =
rewriter.create<AtenSizeIntOp>(loc, flattenedInput, /*dim=*/constZero);
Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
loc, cumulativeSumType, c(0),
rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])),
one, noneCst, noneCst, noneCst, noneCst);
loc, cumulativeSumType, c(0), end, one, noneCst, noneCst, noneCst,
noneCst);
Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
rangeTensor, intMask);