From da88efad893f5470213b902f6980c7123ad8c2f2 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Mon, 1 Apr 2024 21:25:02 +0800 Subject: [PATCH] [Torch] Fix bug of DecomposeAtenSelectIntOp (#3087) Fix bug of DecomposeAtenSelectIntOp. Because it may use resultTy when resultTy has not been inferred. ``` auto resultTy = op.getType().cast(); if (sliceTy.getSizes().size() == resultTy.getSizes().size()) { rewriter.replaceOp(op, slice); return success(); } ``` So I add restriction. --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index ba140bc72..8c69d3f0d 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -764,6 +764,12 @@ public: Value dim = op.getDim(); Value self = op.getSelf(); + auto resultTy = op.getType().cast(); + if (!resultTy.hasSizes() || !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "expected result type to have sizes and dtype"); + } + // convert `start` to non-negative: start += int(start < 0) * dimSize Value zero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); @@ -785,7 +791,6 @@ public: op.getSelf(), dim, start, startPlusOne, /*step=*/one); auto sliceTy = cast(slice.getType()); - auto resultTy = cast(op.getResult().getType()); if (sliceTy.getSizes().size() == resultTy.getSizes().size()) { rewriter.replaceOp(op, slice); return success();