mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add torch.Device type to backend contract scalar types
Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1266/head
parent
9176b5ed29
commit
8cad02f87e
|
@ -124,6 +124,7 @@ TOSA_PASS_SET = {
|
||||||
"OnesModuleInt_basic",
|
"OnesModuleInt_basic",
|
||||||
"OnesModuleFloat_basic",
|
"OnesModuleFloat_basic",
|
||||||
"OnesModuleFalsePinMemory_basic",
|
"OnesModuleFalsePinMemory_basic",
|
||||||
|
"OnesModuleCPUDevice_basic",
|
||||||
"NewZerosModuleDefaultDtype_basic",
|
"NewZerosModuleDefaultDtype_basic",
|
||||||
"NewZerosModuleInt2D_basic",
|
"NewZerosModuleInt2D_basic",
|
||||||
"NewZerosModuleInt3D_basic",
|
"NewZerosModuleInt3D_basic",
|
||||||
|
|
|
@ -30,7 +30,8 @@ using namespace mlir::torch::Torch;
|
||||||
static LogicalResult checkType(Operation *op, Type type,
|
static LogicalResult checkType(Operation *op, Type type,
|
||||||
bool actuallyEmitDiagnostics) {
|
bool actuallyEmitDiagnostics) {
|
||||||
// Allow various scalar types that backends are expected to be able to handle.
|
// Allow various scalar types that backends are expected to be able to handle.
|
||||||
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType>())
|
if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
|
||||||
|
Torch::DeviceType>())
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
// Backends are not expected to support dynamic computations on these types,
|
// Backends are not expected to support dynamic computations on these types,
|
||||||
|
|
|
@ -195,6 +195,24 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class OnesModuleCPUDevice(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.ones(3, 4, device="cpu")
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: OnesModuleCPUDevice())
|
||||||
|
def OnesModuleCPUDevice_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue