mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for bool type in convertScalarToDtype utility
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
parent 708a51ae2e
commit 4605dc9c99
@@ -242,8 +242,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
     return false;
   };
 
-  if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
-      dtype.isSignlessInteger(1)) {
+  if (isByteOrChar(scalarType) || isByteOrChar(dtype)) {
     // TODO: Handle to-boolean conversion(from-boolean conversion is handled).
     mlir::emitError(loc)
         << "unsupported byte, char or bool type for convertScalarToDtype "
@@ -251,6 +250,24 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
     return nullptr;
   }
 
+  // If the dtype is i1, i.e., a boolean type.
+  if (dtype.isSignlessInteger(1)) {
+    Type scalarType = scalar.getType();
+    Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
+    if (scalarType.isa<mlir::FloatType>()) {
+      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
+                                     cstZero);
+    } else if (scalarType.isa<mlir::IntegerType>()) {
+      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
+                                     cstZero);
+    } else {
+      mlir::emitError(loc)
+          << "unsupported scalar type for convertScalarToDtype " << scalarType
+          << "(scalar type) -> " << dtype << "(dtype)";
+      return nullptr;
+    }
+  }
+
   if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
     if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
       if (scalarFloat.getWidth() > dtypeFloat.getWidth())
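
Taken together, these two hunks stop rejecting an i1 target dtype and instead lower scalar-to-bool conversion as a compare-against-zero: any nonzero scalar becomes true, zero becomes false. Below is a minimal standalone sketch of the same pattern for readers who want it outside the diff. The helper name convertScalarToBool and the include paths are my additions (assuming an MLIR tree of this era, where the arith dialect lived under Dialect/Arithmetic); the builder calls themselves are the ones the patch uses.

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper mirroring the new i1 branch of convertScalarToDtype:
// returns an i1 Value that is true iff `scalar` is nonzero.
static Value convertScalarToBool(OpBuilder &b, Location loc, Value scalar) {
  Type scalarType = scalar.getType();
  // Zero constant of the scalar's own type, used as the comparison operand.
  Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
  if (scalarType.isa<FloatType>())
    // UNE (unordered or not-equal) also maps NaN to true, matching
    // PyTorch, where converting NaN to bool yields True.
    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
                                   zero);
  if (scalarType.isa<IntegerType>())
    return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
                                   zero);
  // Any other scalar type is left for the caller to diagnose.
  return nullptr;
}

Since arith.cmpf and arith.cmpi already produce i1, no further extension or truncation is needed; picking UNE rather than ONE is what preserves the NaN-to-True behavior for floats.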
@@ -191,3 +191,26 @@ class ToDtypeLayoutStridedModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ToDtypeLayoutStridedModule())
 def ToDtypeLayoutStridedModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
+
+
+class ToDtypeBoolLayoutNoneModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.to(x,
+                                 dtype=torch.bool,
+                                 layout=None,
+                                 device=None,
+                                 pin_memory=None,
+                                 non_blocking=False,
+                                 copy=False,
+                                 memory_format=None)
+
+
+@register_test_case(module_factory=lambda: ToDtypeBoolLayoutNoneModule())
+def ToDtypeBoolLayoutNoneModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
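
A note on the new e2e test: tu.rand(3, 5) presumably wraps torch.rand, i.e. uniform floats in [0, 1), so the forward pass exercises the new float-to-i1 comparison; the harness checks the compiled result against eager PyTorch, which maps zero to False and every nonzero sample to True.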