add e2e support for torch.eye operations (aten.eye, aten.eye.m) (#2478)

saienduri 2023-11-01 11:23:28 -07:00 committed by GitHub
parent e12937c642
commit a2e694df40
6 changed files with 344 additions and 0 deletions
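
For context, the two overloads this change supports are torch.eye(n) and torch.eye(n, m). A minimal eager-mode illustration of their behavior (plain PyTorch, not part of the diff):

import torch

# torch.eye(n) builds an n x n identity matrix (float32 by default).
print(torch.eye(3))
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]])

# torch.eye(n, m) builds an n x m matrix with ones on the main diagonal.
print(torch.eye(3, 4, dtype=torch.int64))
# tensor([[1, 0, 0, 0],
#         [0, 1, 0, 0],
#         [0, 0, 1, 0]])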


@@ -1283,6 +1283,14 @@ TOSA_PASS_SET = {
"_SoftmaxModule_basic",
"ElementwiseAddScalarInt8Module_basic",
"ElementwiseSubTensorInt8Module_basic",
"AtenEyeMModuleCPUDevice_basic",
"AtenEyeMModuleDefaultDtype_basic",
"AtenEyeMModuleFalsePinMemory_basic",
"AtenEyeMModuleFloat2D_basic",
"AtenEyeModuleCPUDevice_basic",
"AtenEyeModuleDefaultDtype_basic",
"AtenEyeModuleFalsePinMemory_basic",
"AtenEyeModuleFloat2D_basic",
}
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@@ -1310,6 +1318,9 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
"IndexPutImpl1DIntNonAccumulateModule_basic",
# RuntimeError: The size of tensor a (7) must match the size of tensor b (3) at non-singleton dimension 1
"Add_Module_basic",
# failed to legalize operation 'torch.aten.to.dtype' that was explicitly marked illegal
"AtenEyeModuleInt2D_basic",
"AtenEyeMModuleInt2D_basic",
}
if torch_version_for_comparison() < version.parse("2.1.0.dev"):


@@ -7273,6 +7273,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.zeros\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.eye\"(%arg0: !torch.int, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct %arg0, %arg0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.eye.m\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.ones\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@@ -10561,6 +10569,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.eye\"(%arg0: !torch.int, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.eye.m\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ones\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"


@@ -419,6 +419,103 @@ public:
};
} // namespace
namespace {
class DecomposeAtenEyeOp : public OpRewritePattern<AtenEyeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenEyeOp op,
PatternRewriter &rewriter) const override {
    Value n = op.getN();
    // aten.eye(n) produces a square matrix, so the second dimension is n as well.
    Value m = op.getN();
rewriter.replaceOpWithNewOp<AtenEyeMOp>(op, op.getType(), n, m,
op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory());
return success();
}
};
} // namespace
namespace {
class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenEyeMOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
int64_t n;
if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
return rewriter.notifyMatchFailure(op,
"unimplemented: n must be constant");
int64_t m;
if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
return rewriter.notifyMatchFailure(op,
"unimplemented: m must be constant");
Value none = rewriter.create<ConstantNoneOp>(loc);
auto outType = op.getType().dyn_cast<BaseTensorType>();
if (!outType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
if (!outType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
if (n < 0) {
      return rewriter.notifyMatchFailure(op, "n must be greater than or equal to 0");
}
if (m < 0) {
      return rewriter.notifyMatchFailure(op, "m must be greater than or equal to 0");
}
auto context = op.getContext();
auto int64Dtype = getDtypeIntValueForType(
rewriter, loc,
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
Value rangeN = rewriter.create<AtenArangeOp>(
loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/op.getDevice(), /*pin_memory=*/none);
auto arangeType1 =
outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
Value rangeM = rewriter.create<AtenArangeOp>(
loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
auto unsqzTensorInfo =
unsqueezeTensor(rewriter, op, rangeN, /*dim=*/constMinusOne);
if (failed(unsqzTensorInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}
Value unsqzRangeN = *unsqzTensorInfo;
    // Compare the broadcast row indices with the column indices; equality holds
    // exactly on the main diagonal.
auto eqType = ValueTensorType::get(
context, op.getType().cast<BaseTensorType>().getSizes(),
IntegerType::get(context, 1));
Value eqTensor =
rewriter.create<AtenEqTensorOp>(loc, eqType, unsqzRangeN, rangeM);
Value dtype = op.getDtype();
if (dtype.getType().isa<Torch::BoolType>()) {
rewriter.replaceOp(op, eqTensor);
return success();
} else {
auto zero =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
auto one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value outTensor =
rewriter.create<AtenWhereScalarOp>(loc, outType, eqTensor, one, zero);
rewriter.replaceOp(op, outTensor);
return success();
}
}
};
} // namespace
namespace {
class DecomposeAtenIsnanOp : public OpRewritePattern<AtenIsnanOp> {
public:
@@ -5358,6 +5455,8 @@ public:
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);

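As a rough eager-mode sketch of what the two decomposition patterns above produce (plain PyTorch; the function name below is illustrative, not part of the patch): aten.eye(n) is first rewritten to aten.eye.m(n, n), and aten.eye.m(n, m) is then lowered to an arange / unsqueeze / eq / where sequence.

import torch

def decomposed_eye_m(n: int, m: int, dtype=torch.float32):
    # Row indices as a column vector of shape [n, 1] (arange followed by unsqueeze(-1)).
    rows = torch.arange(n, dtype=torch.int64).unsqueeze(-1)
    # Column indices as a row vector of shape [m].
    cols = torch.arange(m, dtype=torch.int64)
    # Broadcasted comparison is true exactly on the main diagonal: shape [n, m].
    eq = rows == cols
    if dtype is torch.bool:
        # For a bool result the comparison already is the answer.
        return eq
    # Otherwise pick 1 on the diagonal and 0 elsewhere, then cast to the requested
    # dtype (the C++ pattern uses the scalar form of aten.where for this step).
    return torch.where(eq, torch.tensor(1.0), torch.tensor(0.0)).to(dtype)

# aten.eye(n) becomes aten.eye.m(n, n), i.e. the square case.
assert torch.equal(decomposed_eye_m(4, 4), torch.eye(4))
assert torch.equal(decomposed_eye_m(3, 4, torch.int64), torch.eye(3, 4, dtype=torch.int64))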

@@ -423,6 +423,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenBernoulliPOp>();
target.addIllegalOp<AtenBernoulliTensorOp>();
target.addIllegalOp<AtenZeroOp>();
target.addIllegalOp<AtenEyeOp>();
target.addIllegalOp<AtenEyeMOp>();
target.addIllegalOp<AtenIsnanOp>();
target.addIllegalOp<AtenRandLikeOp>();
target.addIllegalOp<AtenHardsigmoidOp>();


@@ -660,6 +660,12 @@ def aten〇scaled_dot_product_attention〡shape(query: List[int], key: List[int]
def aten〇zeros〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return size

def aten〇eye〡shape(n: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return [n, n]

def aten〇eye〇m〡shape(n: int, m: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return [n, m]

def aten〇ones〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
    return size
@@ -3255,6 +3261,20 @@ def aten〇tensor〇bool〡dtype(t: bool, dtype: Optional[int] = None, device: O
def aten〇zeros〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
    return torch.float32 if dtype is None else dtype

@check_dtype_function([Invocation(2),
                       Invocation(2, dtype=torch.int32),
                       Invocation(2, dtype=torch.float16),
                       Invocation(2, dtype=torch.complex64)])
def aten〇eye〡dtype(n: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
    return torch.float32 if dtype is None else dtype

@check_dtype_function([Invocation(2, 3),
                       Invocation(2, 3, dtype=torch.int32),
                       Invocation(2, 3, dtype=torch.float16),
                       Invocation(2, 3, dtype=torch.complex64)])
def aten〇eye〇m〡dtype(n: int, m: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
    return torch.float32 if dtype is None else dtype

@check_dtype_function([Invocation([1]),
Invocation([1], dtype=torch.int32),
Invocation([1], dtype=torch.float16),

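The dtype rules added above mirror PyTorch's eager default: eye yields float32 unless an explicit dtype is given. A quick sanity check (plain PyTorch, not part of the diff):

import torch

assert torch.eye(2).dtype == torch.float32                   # default dtype
assert torch.eye(2, dtype=torch.int32).dtype == torch.int32
assert torch.eye(2, 3, dtype=torch.complex64).dtype == torch.complex64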

@@ -215,6 +215,186 @@ def OnesModuleCPUDevice_basic(module, tu: TestUtils):
# ==============================================================================
class AtenEyeModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.eye(3)
@register_test_case(module_factory=lambda: AtenEyeModuleDefaultDtype())
def AtenEyeModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward()
class AtenEyeModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.eye(3, dtype=torch.int64)
@register_test_case(module_factory=lambda: AtenEyeModuleInt2D())
def AtenEyeModuleInt2D_basic(module, tu: TestUtils):
module.forward()
class AtenEyeModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.eye(3, dtype=torch.float32)
@register_test_case(module_factory=lambda: AtenEyeModuleFloat2D())
def AtenEyeModuleFloat2D_basic(module, tu: TestUtils):
module.forward()
class AtenEyeModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.eye(3, dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: AtenEyeModuleFalsePinMemory())
def AtenEyeModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
class AtenEyeModuleCPUDevice(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.eye(3, device="cpu")
@register_test_case(module_factory=lambda: AtenEyeModuleCPUDevice())
def AtenEyeModuleCPUDevice_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class AtenEyeMModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.aten.eye(3, 4)
@register_test_case(module_factory=lambda: AtenEyeMModuleDefaultDtype())
def AtenEyeMModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward()
class AtenEyeMModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.eye(3, 4, dtype=torch.int64)
@register_test_case(module_factory=lambda: AtenEyeMModuleInt2D())
def AtenEyeMModuleInt2D_basic(module, tu: TestUtils):
module.forward()
class AtenEyeMModuleFloat2D(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.eye(3, 4, dtype=torch.float32)
@register_test_case(module_factory=lambda: AtenEyeMModuleFloat2D())
def AtenEyeMModuleFloat2D_basic(module, tu: TestUtils):
module.forward()
class AtenEyeMModuleFalsePinMemory(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.eye(3, 4, dtype=torch.float32, pin_memory=False)
@register_test_case(module_factory=lambda: AtenEyeMModuleFalsePinMemory())
def AtenEyeMModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
class AtenEyeMModuleCPUDevice(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.eye(3, 4, device="cpu")
@register_test_case(module_factory=lambda: AtenEyeMModuleCPUDevice())
def AtenEyeMModuleCPUDevice_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class EmptyContiguousModule(torch.nn.Module):