mirror of https://github.com/llvm/torch-mlir
[torch] `aten.eye` should use dynamic dims when no static dims are available (#3202)
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
pull/3269/head
parent 72349f7522
commit 315dc6c3e3
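For context, the pattern being modified decomposes `aten.eye` into two `arange` tensors: the row range is unsqueezed into a column, compared elementwise against the column range, and the comparison is true exactly on the diagonal. A minimal scalar sketch of that idea in plain C++ (hypothetical `eyeLike` helper, not torch-mlir code):

#include <cstdint>
#include <iostream>
#include <vector>

// Scalar model of the decomposition: the row range is unsqueezed to shape
// [n, 1]; comparing it against the column range (shape [m]) broadcasts to
// [n, m] and is true exactly where row == col, i.e. on the diagonal.
std::vector<std::vector<int64_t>> eyeLike(int64_t n, int64_t m) {
  std::vector<std::vector<int64_t>> out(n, std::vector<int64_t>(m));
  for (int64_t row = 0; row < n; ++row)   // unsqueezed rangeN
    for (int64_t col = 0; col < m; ++col) // rangeM
      out[row][col] = (row == col) ? 1 : 0;
  return out;
}

int main() {
  // Prints the same 0/1 pattern as torch.eye(3, 4).
  for (const auto &row : eyeLike(3, 4)) {
    for (int64_t v : row)
      std::cout << v << ' ';
    std::cout << '\n';
  }
}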
@@ -1059,44 +1059,44 @@ public:
   LogicalResult matchAndRewrite(AtenEyeMOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    int64_t n;
-
-    if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
-      return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: n must be constant");
-    int64_t m;
-    if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
-      return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: m must be constant");
-    Value none = rewriter.create<ConstantNoneOp>(loc);
-    auto outType = op.getType().dyn_cast<BaseTensorType>();
+    auto outType = dyn_cast<BaseTensorType>(op.getType());
     if (!outType)
       return rewriter.notifyMatchFailure(
           op, "Only tensor types input are currently supported");
     if (!outType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
-    if (n < 0) {
-      return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
-    }
-    if (m < 0) {
-      return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
-    }

+    Value none = rewriter.create<ConstantNoneOp>(loc);
     auto context = op.getContext();
     auto int64Dtype = getDtypeIntValueForType(
         rewriter, loc,
         rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
     auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
-    auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
+
+    int64_t n = kUnknownSize;
+    int64_t m = kUnknownSize;
+    // prioritize getting shape from output shape
+    if (outType.hasSizes() && outType.getSizes().size() == 2) {
+      n = outType.getSizes().front();
+      m = outType.getSizes().back();
+    }
+    // if output shape is not available, try to get shape from input
+    if (n == kUnknownSize)
+      matchPattern(op.getN(), m_TorchConstantInt(&n));
+    if (m == kUnknownSize)
+      matchPattern(op.getM(), m_TorchConstantInt(&m));
+
+    // prepare two unsqueezed ranges that are equal on and only on the diagonal
+    auto rangeNSize = llvm::SmallVector<int64_t, 1>({n});
+    Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
     Value rangeN = rewriter.create<AtenArangeOp>(
-        loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
+        loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
         /*device=*/op.getDevice(), /*pin_memory=*/none);

-    auto arangeType1 =
-        outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
+    auto rangeMSize = llvm::SmallVector<int64_t, 1>({m});
+    Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
     Value rangeM = rewriter.create<AtenArangeOp>(
-        loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
+        loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
         /*device=*/none, /*pin_memory=*/none);

     Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
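The hunk above replaces the hard requirement that `n` and `m` fold to constants with a fallback order: take static sizes from the 2-D result type when present, otherwise try to fold the constant operands, otherwise leave the dimension dynamic (`kUnknownSize`, which torch-mlir defines as -1). A standalone sketch of that resolution order, with `resolveEyeDims` a hypothetical name rather than a torch-mlir helper:

#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>

// Mirrors torch-mlir's convention for a dynamic dimension.
constexpr int64_t kUnknownSize = -1;

// Sketch of the fallback order in the patch: prefer the result type's static
// sizes (which may themselves be dynamic), then constant n/m operands, else
// stay dynamic.
std::pair<int64_t, int64_t>
resolveEyeDims(std::optional<std::pair<int64_t, int64_t>> outTypeSizes,
               std::optional<int64_t> constN, std::optional<int64_t> constM) {
  int64_t n = kUnknownSize, m = kUnknownSize;
  // prioritize getting shape from output shape
  if (outTypeSizes) {
    n = outTypeSizes->first;
    m = outTypeSizes->second;
  }
  // if output shape is not available, try to get shape from input
  if (n == kUnknownSize && constN)
    n = *constN;
  if (m == kUnknownSize && constM)
    m = *constM;
  return {n, m};
}

int main() {
  // Neither the result type nor the operands are static: both dims stay
  // dynamic, which is exactly the case the old code rejected.
  auto [n, m] = resolveEyeDims(std::nullopt, std::nullopt, std::nullopt);
  std::cout << n << ' ' << m << '\n'; // -1 -1
}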
@@ -1109,7 +1109,6 @@ public:
     }
     Value unsqzRangeN = *unsqzTensorInfo;

     // compare unsqueezed input with boundaries
     auto eqType = ValueTensorType::get(
         context, cast<BaseTensorType>(op.getType()).getSizes(),
         IntegerType::get(context, 1));
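The comparison result here is typed as the output's sizes with `i1` (boolean) elements, so dynamic dims from the earlier steps carry through; the mask is presumably converted to the requested result dtype afterwards. A small sketch of how such a type would render, using a hypothetical printer rather than the MLIR type API (torch-mlir prints `kUnknownSize` dims as `?`):

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

constexpr int64_t kUnknownSize = -1;

// Hypothetical printer for the eq result type: the output's sizes with i1
// elements, dynamic dims rendered as `?`.
std::string printEqType(const std::vector<int64_t> &sizes) {
  std::ostringstream os;
  os << "!torch.vtensor<[";
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (i)
      os << ',';
    if (sizes[i] == kUnknownSize)
      os << '?';
    else
      os << sizes[i];
  }
  os << "],i1>";
  return os.str();
}

int main() {
  std::cout << printEqType({3, 4}) << '\n';                       // static shape
  std::cout << printEqType({kUnknownSize, kUnknownSize}) << '\n'; // fully dynamic
}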