[MLIR][TORCH] Add support for bool type in convertScalarToDtype utility

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/792/merge
Vivek Khandelwal 2022-06-14 19:35:22 +05:30
parent 708a51ae2e
commit 4605dc9c99
2 changed files with 42 additions and 2 deletions

View File

@ -242,8 +242,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
return false; return false;
}; };
if (isByteOrChar(scalarType) || isByteOrChar(dtype) || if (isByteOrChar(scalarType) || isByteOrChar(dtype)) {
dtype.isSignlessInteger(1)) {
// TODO: Handle to-boolean conversion(from-boolean conversion is handled). // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
mlir::emitError(loc) mlir::emitError(loc)
<< "unsupported byte, char or bool type for convertScalarToDtype " << "unsupported byte, char or bool type for convertScalarToDtype "
@ -251,6 +250,24 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
return nullptr; return nullptr;
} }
// If the dtype is i1, i.e., a boolean type.
if (dtype.isSignlessInteger(1)) {
Type scalarType = scalar.getType();
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
if (scalarType.isa<mlir::FloatType>()) {
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
cstZero);
} else if (scalarType.isa<mlir::IntegerType>()) {
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
cstZero);
} else {
mlir::emitError(loc)
<< "unsupported scalar type for convertScalarToDtype " << scalarType
<< "(scalar type) -> " << dtype << "(dtype)";
return nullptr;
}
}
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) { if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) { if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
if (scalarFloat.getWidth() > dtypeFloat.getWidth()) if (scalarFloat.getWidth() > dtypeFloat.getWidth())

View File

@ -191,3 +191,26 @@ class ToDtypeLayoutStridedModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ToDtypeLayoutStridedModule()) @register_test_case(module_factory=lambda: ToDtypeLayoutStridedModule())
def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils): def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5)) module.forward(tu.rand(3, 5))
class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([None, ([-1, -1], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.to(x,
dtype=torch.bool,
layout=None,
device=None,
pin_memory=None,
non_blocking=False,
copy=False,
memory_format=None)
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))