mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for bool type in convertScalarToDtype utility
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/792/merge
parent
708a51ae2e
commit
4605dc9c99
|
@ -242,8 +242,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
|||
return false;
|
||||
};
|
||||
|
||||
if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
|
||||
dtype.isSignlessInteger(1)) {
|
||||
if (isByteOrChar(scalarType) || isByteOrChar(dtype)) {
|
||||
// TODO: Handle to-boolean conversion(from-boolean conversion is handled).
|
||||
mlir::emitError(loc)
|
||||
<< "unsupported byte, char or bool type for convertScalarToDtype "
|
||||
|
@ -251,6 +250,24 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// If the dtype is i1, i.e., a boolean type.
|
||||
if (dtype.isSignlessInteger(1)) {
|
||||
Type scalarType = scalar.getType();
|
||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
|
||||
if (scalarType.isa<mlir::FloatType>()) {
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
|
||||
cstZero);
|
||||
} else if (scalarType.isa<mlir::IntegerType>()) {
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
|
||||
cstZero);
|
||||
} else {
|
||||
mlir::emitError(loc)
|
||||
<< "unsupported scalar type for convertScalarToDtype " << scalarType
|
||||
<< "(scalar type) -> " << dtype << "(dtype)";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
|
||||
if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
|
||||
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
|
||||
|
|
|
@ -191,3 +191,26 @@ class ToDtypeLayoutStridedModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ToDtypeLayoutStridedModule())
|
||||
def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.to(x,
|
||||
dtype=torch.bool,
|
||||
layout=None,
|
||||
device=None,
|
||||
pin_memory=None,
|
||||
non_blocking=False,
|
||||
copy=False,
|
||||
memory_format=None)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
|
||||
def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
|
Loading…
Reference in New Issue