Add support for int types in gtScalar op.

Support for integer types in gtScalar op has been added.
The code share same logic with gtTensor op and can be merged
which is added as a TODO.
pull/480/head snapshot-20211213.142
Prashant Kumar 2021-12-13 21:02:47 +05:30
parent 8d4879feb0
commit 6dabf185f5
2 changed files with 67 additions and 10 deletions

View File

@ -326,7 +326,7 @@ def ElementwiseMaximumModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseGtScalarModule(torch.nn.Module): class ElementwiseGtFloatScalarModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -339,10 +339,47 @@ class ElementwiseGtScalarModule(torch.nn.Module):
return torch.gt(x, 0.6) return torch.gt(x, 0.6)
@register_test_case(module_factory=lambda: ElementwiseGtScalarModule()) @register_test_case(module_factory=lambda: ElementwiseGtFloatScalarModule())
def ElementwiseGtScalarModule_basic(module, tu: TestUtils): def ElementwiseGtFloatScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5)) module.forward(tu.rand(3, 5))
class ElementwiseGtIntScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.gt(x, 10)
@register_test_case(module_factory=lambda: ElementwiseGtIntScalarModule())
def ElementwiseGtIntScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3,4)))
class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.gt(x, 7)
@register_test_case(module_factory=lambda: ElementwiseGtMixed2ScalarModule())
def ElementwiseGtMixed2ScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(-10, 15, (3,4)).to(torch.int32))
class ElementwiseGtFloatTensorModule(torch.nn.Module): class ElementwiseGtFloatTensorModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -361,6 +398,7 @@ class ElementwiseGtFloatTensorModule(torch.nn.Module):
def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils): def ElementwiseGtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(5)) module.forward(tu.rand(3, 5), tu.rand(5))
class ElementwiseGtIntTensorModule(torch.nn.Module): class ElementwiseGtIntTensorModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -1735,14 +1735,33 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) { if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
Type dtype = gtScalar.self().getType().cast<ValueTensorType>().getDtype(); Type dtype = gtScalar.self().getType().cast<BaseTensorType>().getDtype();
if (!dtype.isa<mlir::FloatType>()) {
gtScalar.emitError("unimplemented: non-floating point operand dtype"); // TODO: `gtTensor` and `gtScalar` share similar code and can be called from
return nullptr; // one static function.
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (dtype.isa<mlir::FloatType>())
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dtype.dyn_cast<mlir::IntegerType>()) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor args from integer to float.
gtScalar.emitError(
"unimplemented: type promotion from tensor to scalar.");
return nullptr;
}
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
payloadArgs[0], otherPromoted);
} }
Value otherPromoted = convertScalarToDtype(b, loc, operands[1], dtype); gtScalar.emitError("unimplemented: dtype isn't supported.");
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, return nullptr;
payloadArgs[0], otherPromoted);
} }
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) { if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {