[MLIR][TORCH] Add torch.Device type to backend contract scalar types

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1266/head
Vivek Khandelwal 2022-08-19 12:23:21 +05:30
parent 9176b5ed29
commit 8cad02f87e
3 changed files with 21 additions and 1 deletions

View File

@ -124,6 +124,7 @@ TOSA_PASS_SET = {
"OnesModuleInt_basic",
"OnesModuleFloat_basic",
"OnesModuleFalsePinMemory_basic",
"OnesModuleCPUDevice_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",

View File

@ -30,7 +30,8 @@ using namespace mlir::torch::Torch;
static LogicalResult checkType(Operation *op, Type type,
bool actuallyEmitDiagnostics) {
// Allow various scalar types that backends are expected to be able to handle.
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType>())
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
Torch::DeviceType>())
return success();
// Backends are not expected to support dynamic computations on these types,

View File

@ -195,6 +195,24 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
module.forward()
class OnesModuleCPUDevice(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4, device="cpu")
@register_test_case(module_factory=lambda: OnesModuleCPUDevice())
def OnesModuleCPUDevice_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================