[torch] `aten.eye` should use dynamic dims when no static dims are available (#3202)

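Previously this decomposition of `aten.eye.m` bailed out unless both `n` and `m` were constant integers. The pattern now resolves the two sizes in order of preference: static sizes from the result tensor type first, then constant `n`/`m` operands, and otherwise `kUnknownSize`, so the emitted `arange` ops carry dynamic dims.
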
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Xida Ren (Cedar) 2024-04-30 13:41:03 -04:00 committed by GitHub
parent 72349f7522
commit 315dc6c3e3
1 changed file with 23 additions and 24 deletions

@@ -1059,44 +1059,44 @@ public:
   LogicalResult matchAndRewrite(AtenEyeMOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    int64_t n;
-    if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
-      return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: n must be constant");
-    int64_t m;
-    if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
-      return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: m must be constant");
-    Value none = rewriter.create<ConstantNoneOp>(loc);
-    auto outType = dyn_cast<BaseTensorType>(op.getType());
+    auto outType = op.getType().dyn_cast<BaseTensorType>();
     if (!outType)
       return rewriter.notifyMatchFailure(
           op, "Only tensor types input are currently supported");
     if (!outType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
-    if (n < 0) {
-      return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
-    }
-    if (m < 0) {
-      return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
-    }
+    Value none = rewriter.create<ConstantNoneOp>(loc);
     auto context = op.getContext();
     auto int64Dtype = getDtypeIntValueForType(
         rewriter, loc,
         rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
     auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
-    auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
+    int64_t n = kUnknownSize;
+    int64_t m = kUnknownSize;
+    // prioritize getting shape from output shape
+    if (outType.hasSizes() && outType.getSizes().size() == 2) {
+      n = outType.getSizes().front();
+      m = outType.getSizes().back();
+    }
+    // if output shape is not available, try to get shape from input
+    if (n == kUnknownSize)
+      matchPattern(op.getN(), m_TorchConstantInt(&n));
+    if (m == kUnknownSize)
+      matchPattern(op.getM(), m_TorchConstantInt(&m));
+    // prepare two unsqueezed ranges that are equal on and only on the diagonal
+    auto rangeNSize = llvm::SmallVector<int64_t, 1>({n});
+    Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
     Value rangeN = rewriter.create<AtenArangeOp>(
-        loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
+        loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
         /*device=*/op.getDevice(), /*pin_memory=*/none);
-    auto arangeType1 =
-        outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
+    auto rangeMSize = llvm::SmallVector<int64_t, 1>({m});
+    Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
     Value rangeM = rewriter.create<AtenArangeOp>(
-        loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
+        loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
         /*device=*/none, /*pin_memory=*/none);
     Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
@@ -1109,7 +1109,6 @@ public:
     }
     Value unsqzRangeN = *unsqzTensorInfo;
-    // compare unsqueezed input with boundaries
     auto eqType = ValueTensorType::get(
         context, cast<BaseTensorType>(op.getType()).getSizes(),
         IntegerType::get(context, 1));
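
As context for the review: the pattern materializes `eye(n, m)` by comparing two broadcast index ranges (`rangeN` unsqueezed to shape `[n, 1]` against `rangeM` of shape `[m]`), which are equal on and only on the diagonal; the comparison yields the boolean (`i1`) tensor typed by `eqType` above. Below is a minimal standalone sketch of those semantics in plain C++, not the torch-mlir API; the `eye` helper and the scalar loops are illustrative only, not part of the patch.

```cpp
// Standalone illustration of what the decomposition computes:
// eye(n, m)[i][j] == (arange(n).unsqueeze(-1) == arange(m))[i][j].
#include <cstdio>
#include <vector>

// Hypothetical helper (not in the patch): builds eye(n, m) via the same
// broadcast-compare trick the decomposition uses.
std::vector<std::vector<int>> eye(int n, int m) {
  std::vector<std::vector<int>> out(n, std::vector<int>(m, 0));
  for (int i = 0; i < n; ++i)        // plays the role of rangeN, shape [n, 1]
    for (int j = 0; j < m; ++j)      // plays the role of rangeM, shape [m]
      out[i][j] = (i == j) ? 1 : 0;  // equal on and only on the diagonal
  return out;
}

int main() {
  for (const auto &row : eye(3, 4)) { // non-square case works the same way
    for (int v : row)
      std::printf("%d ", v);
    std::printf("\n");
  }
  return 0;
}
```

Because only the index ranges are compared, neither size has to be statically known, and the non-square case (e.g. `eye(3, 4)`) is why `rangeN` and `rangeM` are built separately.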