mirror of https://github.com/llvm/torch-mlir
Fix AtenArangeStartStepOp dynamic end support
parent
35e20e04b8
commit
f7cd0fc615
|
@ -5604,11 +5604,13 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
|
|||
subtracted, c(0));
|
||||
|
||||
// iota = torch.tensor(range(len(t))) * nonzero_mask.int()
|
||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value end =
|
||||
rewriter.create<AtenSizeIntOp>(loc, flattenedInput, /*dim=*/constZero);
|
||||
Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
|
||||
loc, cumulativeSumType, c(0),
|
||||
rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(flattenedInputType.getSizes()[0])),
|
||||
one, noneCst, noneCst, noneCst, noneCst);
|
||||
loc, cumulativeSumType, c(0), end, one, noneCst, noneCst, noneCst,
|
||||
noneCst);
|
||||
Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
|
||||
rangeTensor, intMask);
|
||||
|
||||
|
|
Loading…
Reference in New Issue