mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add torch.Device type to backend contract scalar types
Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1266/head
parent
9176b5ed29
commit
8cad02f87e
|
@ -124,6 +124,7 @@ TOSA_PASS_SET = {
|
|||
"OnesModuleInt_basic",
|
||||
"OnesModuleFloat_basic",
|
||||
"OnesModuleFalsePinMemory_basic",
|
||||
"OnesModuleCPUDevice_basic",
|
||||
"NewZerosModuleDefaultDtype_basic",
|
||||
"NewZerosModuleInt2D_basic",
|
||||
"NewZerosModuleInt3D_basic",
|
||||
|
|
|
@ -30,7 +30,8 @@ using namespace mlir::torch::Torch;
|
|||
static LogicalResult checkType(Operation *op, Type type,
|
||||
bool actuallyEmitDiagnostics) {
|
||||
// Allow various scalar types that backends are expected to be able to handle.
|
||||
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType>())
|
||||
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
|
||||
Torch::DeviceType>())
|
||||
return success();
|
||||
|
||||
// Backends are not expected to support dynamic computations on these types,
|
||||
|
|
|
@ -195,6 +195,24 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
|
|||
module.forward()
|
||||
|
||||
|
||||
class OnesModuleCPUDevice(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.ones(3, 4, device="cpu")
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: OnesModuleCPUDevice())
|
||||
def OnesModuleCPUDevice_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue