mirror of https://github.com/llvm/torch-mlir
add e2e support for torch.eye operations (aten.eye, aten.eye.m) (#2478)
parent
e12937c642
commit
a2e694df40
|
@ -1283,6 +1283,14 @@ TOSA_PASS_SET = {
|
|||
"_SoftmaxModule_basic",
|
||||
"ElementwiseAddScalarInt8Module_basic",
|
||||
"ElementwiseSubTensorInt8Module_basic",
|
||||
"AtenEyeMModuleCPUDevice_basic",
|
||||
"AtenEyeMModuleDefaultDtype_basic",
|
||||
"AtenEyeMModuleFalsePinMemory_basic",
|
||||
"AtenEyeMModuleFloat2D_basic",
|
||||
"AtenEyeModuleCPUDevice_basic",
|
||||
"AtenEyeModuleDefaultDtype_basic",
|
||||
"AtenEyeModuleFalsePinMemory_basic",
|
||||
"AtenEyeModuleFloat2D_basic",
|
||||
}
|
||||
|
||||
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
||||
|
@ -1310,6 +1318,9 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
|
|||
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||
# RuntimeError: The size of tensor a (7) must match the size of tensor b (3) at non-singleton dimension 1
|
||||
"Add_Module_basic",
|
||||
# failed to legalize operation 'torch.aten.to.dtype' that was explicitly marked illegal
|
||||
"AtenEyeModuleInt2D_basic",
|
||||
"AtenEyeMModuleInt2D_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
|
||||
|
|
|
@ -7273,6 +7273,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.zeros\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.eye\"(%arg0: !torch.int, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" %0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.eye.m\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.ones\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
|
@ -10561,6 +10569,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.eye\"(%arg0: !torch.int, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.eye.m\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.ones\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
|
|
|
@ -419,6 +419,103 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenEyeOp : public OpRewritePattern<AtenEyeOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenEyeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value n = op.getN();
|
||||
Value m = op.getN();
|
||||
rewriter.replaceOpWithNewOp<AtenEyeMOp>(op, op.getType(), n, m,
|
||||
op.getDtype(), op.getLayout(),
|
||||
op.getDevice(), op.getPinMemory());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenEyeMOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
int64_t n;
|
||||
if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unimplemented: n must be constant");
|
||||
int64_t m;
|
||||
if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unimplemented: m must be constant");
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
auto outType = op.getType().dyn_cast<BaseTensorType>();
|
||||
if (!outType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
if (!outType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||
}
|
||||
if (n < 0) {
|
||||
return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
|
||||
}
|
||||
if (m < 0) {
|
||||
return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
|
||||
}
|
||||
|
||||
auto context = op.getContext();
|
||||
auto int64Dtype = getDtypeIntValueForType(
|
||||
rewriter, loc,
|
||||
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
|
||||
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
||||
auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
|
||||
Value rangeN = rewriter.create<AtenArangeOp>(
|
||||
loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
|
||||
/*device=*/op.getDevice(), /*pin_memory=*/none);
|
||||
|
||||
auto arangeType1 =
|
||||
outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
|
||||
Value rangeM = rewriter.create<AtenArangeOp>(
|
||||
loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
|
||||
/*device=*/none, /*pin_memory=*/none);
|
||||
|
||||
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(-1));
|
||||
auto unsqzTensorInfo =
|
||||
unsqueezeTensor(rewriter, op, rangeN, /*dim=*/constMinusOne);
|
||||
if (failed(unsqzTensorInfo)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"cannot generate unsqueeze tensor");
|
||||
}
|
||||
Value unsqzRangeN = *unsqzTensorInfo;
|
||||
|
||||
// compare unsqueezed input with boundaries
|
||||
auto eqType = ValueTensorType::get(
|
||||
context, op.getType().cast<BaseTensorType>().getSizes(),
|
||||
IntegerType::get(context, 1));
|
||||
Value eqTensor =
|
||||
rewriter.create<AtenEqTensorOp>(loc, eqType, unsqzRangeN, rangeM);
|
||||
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().isa<Torch::BoolType>()) {
|
||||
rewriter.replaceOp(op, eqTensor);
|
||||
return success();
|
||||
} else {
|
||||
auto zero =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||
auto one =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||
Value outTensor =
|
||||
rewriter.create<AtenWhereScalarOp>(loc, outType, eqTensor, one, zero);
|
||||
rewriter.replaceOp(op, outTensor);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenIsnanOp : public OpRewritePattern<AtenIsnanOp> {
|
||||
public:
|
||||
|
@ -5358,6 +5455,8 @@ public:
|
|||
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
||||
|
|
|
@ -423,6 +423,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenBernoulliPOp>();
|
||||
target.addIllegalOp<AtenBernoulliTensorOp>();
|
||||
target.addIllegalOp<AtenZeroOp>();
|
||||
target.addIllegalOp<AtenEyeOp>();
|
||||
target.addIllegalOp<AtenEyeMOp>();
|
||||
target.addIllegalOp<AtenIsnanOp>();
|
||||
target.addIllegalOp<AtenRandLikeOp>();
|
||||
target.addIllegalOp<AtenHardsigmoidOp>();
|
||||
|
|
|
@ -660,6 +660,12 @@ def aten〇scaled_dot_product_attention〡shape(query: List[int], key: List[int]
|
|||
def aten〇zeros〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
def aten〇eye〡shape(n: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return [n, n]
|
||||
|
||||
def aten〇eye〇m〡shape(n: int, m: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return [n, m]
|
||||
|
||||
def aten〇ones〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
|
@ -3255,6 +3261,20 @@ def aten〇tensor〇bool〡dtype(t: bool, dtype: Optional[int] = None, device: O
|
|||
def aten〇zeros〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
return torch.float32 if dtype is None else dtype
|
||||
|
||||
@check_dtype_function([Invocation(2),
|
||||
Invocation(2, dtype=torch.int32),
|
||||
Invocation(2, dtype=torch.float16),
|
||||
Invocation(2, dtype=torch.complex64)])
|
||||
def aten〇eye〡dtype(n: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
return torch.float32 if dtype is None else dtype
|
||||
|
||||
@check_dtype_function([Invocation(2, 3),
|
||||
Invocation(2, 3, dtype=torch.int32),
|
||||
Invocation(2, 3, dtype=torch.float16),
|
||||
Invocation(2, 3, dtype=torch.complex64)])
|
||||
def aten〇eye〇m〡dtype(n: int, m: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
return torch.float32 if dtype is None else dtype
|
||||
|
||||
@check_dtype_function([Invocation([1]),
|
||||
Invocation([1], dtype=torch.int32),
|
||||
Invocation([1], dtype=torch.float16),
|
||||
|
|
|
@ -215,6 +215,186 @@ def OnesModuleCPUDevice_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenEyeModuleDefaultDtype(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.ops.aten.eye(3)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeModuleDefaultDtype())
|
||||
def AtenEyeModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class AtenEyeModuleInt2D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.eye(3, dtype=torch.int64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeModuleInt2D())
|
||||
def AtenEyeModuleInt2D_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class AtenEyeModuleFloat2D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.eye(3, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeModuleFloat2D())
|
||||
def AtenEyeModuleFloat2D_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
class AtenEyeModuleFalsePinMemory(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.eye(3, dtype=torch.float32, pin_memory=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeModuleFalsePinMemory())
|
||||
def AtenEyeModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class AtenEyeModuleCPUDevice(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.eye(3, device="cpu")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeModuleCPUDevice())
|
||||
def AtenEyeModuleCPUDevice_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class AtenEyeMModuleDefaultDtype(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.ops.aten.eye(3, 4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeMModuleDefaultDtype())
|
||||
def AtenEyeMModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class AtenEyeMModuleInt2D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.eye(3, 4, dtype=torch.int64)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeMModuleInt2D())
|
||||
def AtenEyeMModuleInt2D_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class AtenEyeMModuleFloat2D(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.eye(3, 4, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeMModuleFloat2D())
|
||||
def AtenEyeMModuleFloat2D_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
class AtenEyeMModuleFalsePinMemory(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.eye(3, 4, dtype=torch.float32, pin_memory=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeMModuleFalsePinMemory())
|
||||
def AtenEyeMModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class AtenEyeMModuleCPUDevice(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.eye(3, 4, device="cpu")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEyeMModuleCPUDevice())
|
||||
def AtenEyeMModuleCPUDevice_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class EmptyContiguousModule(torch.nn.Module):
|
||||
|
||||
|
|
Loading…
Reference in New Issue