mirror of https://github.com/llvm/torch-mlir
[torch] `aten.eye` should use dynamic dims when no static dims are available (#3202)
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>pull/3269/head
parent
72349f7522
commit
315dc6c3e3
|
@ -1059,44 +1059,44 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenEyeMOp op,
|
LogicalResult matchAndRewrite(AtenEyeMOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
int64_t n;
|
auto outType = op.getType().dyn_cast<BaseTensorType>();
|
||||||
|
|
||||||
if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"unimplemented: n must be constant");
|
|
||||||
int64_t m;
|
|
||||||
if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
|
|
||||||
return rewriter.notifyMatchFailure(op,
|
|
||||||
"unimplemented: m must be constant");
|
|
||||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
|
||||||
auto outType = dyn_cast<BaseTensorType>(op.getType());
|
|
||||||
if (!outType)
|
if (!outType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
if (!outType.hasDtype()) {
|
if (!outType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
if (n < 0) {
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
|
|
||||||
}
|
|
||||||
if (m < 0) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto context = op.getContext();
|
auto context = op.getContext();
|
||||||
auto int64Dtype = getDtypeIntValueForType(
|
auto int64Dtype = getDtypeIntValueForType(
|
||||||
rewriter, loc,
|
rewriter, loc,
|
||||||
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
||||||
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
||||||
auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
|
|
||||||
|
int64_t n = kUnknownSize;
|
||||||
|
int64_t m = kUnknownSize;
|
||||||
|
// prioritize getting shape from output shape
|
||||||
|
if (outType.hasSizes() && outType.getSizes().size() == 2) {
|
||||||
|
n = outType.getSizes().front();
|
||||||
|
m = outType.getSizes().back();
|
||||||
|
}
|
||||||
|
// if output shape is not available, try to get shape from input
|
||||||
|
if (n == kUnknownSize)
|
||||||
|
matchPattern(op.getN(), m_TorchConstantInt(&n));
|
||||||
|
if (m == kUnknownSize)
|
||||||
|
matchPattern(op.getM(), m_TorchConstantInt(&m));
|
||||||
|
|
||||||
|
// prepare two unsqueezed ranges that are equal on and only on the diagonal
|
||||||
|
auto rangeNSize = llvm::SmallVector<int64_t, 1>({n});
|
||||||
|
Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
|
||||||
Value rangeN = rewriter.create<AtenArangeOp>(
|
Value rangeN = rewriter.create<AtenArangeOp>(
|
||||||
loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
|
loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
|
||||||
/*device=*/op.getDevice(), /*pin_memory=*/none);
|
/*device=*/op.getDevice(), /*pin_memory=*/none);
|
||||||
|
|
||||||
auto arangeType1 =
|
auto rangeMSize = llvm::SmallVector<int64_t, 1>({m});
|
||||||
outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
|
Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
|
||||||
Value rangeM = rewriter.create<AtenArangeOp>(
|
Value rangeM = rewriter.create<AtenArangeOp>(
|
||||||
loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
|
loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
|
||||||
/*device=*/none, /*pin_memory=*/none);
|
/*device=*/none, /*pin_memory=*/none);
|
||||||
|
|
||||||
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
@ -1109,7 +1109,6 @@ public:
|
||||||
}
|
}
|
||||||
Value unsqzRangeN = *unsqzTensorInfo;
|
Value unsqzRangeN = *unsqzTensorInfo;
|
||||||
|
|
||||||
// compare unsqueezed input with boundaries
|
|
||||||
auto eqType = ValueTensorType::get(
|
auto eqType = ValueTensorType::get(
|
||||||
context, cast<BaseTensorType>(op.getType()).getSizes(),
|
context, cast<BaseTensorType>(op.getType()).getSizes(),
|
||||||
IntegerType::get(context, 1));
|
IntegerType::get(context, 1));
|
||||||
|
|
Loading…
Reference in New Issue