[MLIR][TORCH] Add integer dtype support for aten.rsub.Scalar op

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/947/head
Vivek Khandelwal 2022-06-14 18:01:30 +05:30
parent b90837ee24
commit aed5517fda
3 changed files with 63 additions and 15 deletions

View File

@ -63,8 +63,8 @@ TOSA_PASS_SET = {
"MmDagModule_basic",
"Matmul_dot",
"Matmul_3d",
"RsubModule_basic",
"RsubModule_noalpha_basic",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"ElementwiseGtFloatScalarModule_basic",
"ElementwiseGtIntScalarModule_basic",
"ElementwiseGtMixed2ScalarModule_basic",

View File

@ -736,15 +736,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(rsub.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
rsub.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value self = payloadArgs[0];
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
return b.create<arith::SubFOp>(loc, other, mult);
if (dtype.isa<mlir::FloatType>()) {
Value mult = b.create<arith::MulFOp>(loc, self, alpha);
return b.create<arith::SubFOp>(loc, other, mult);
} else if (dtype.isa<mlir::IntegerType>()) {
Value mult = b.create<arith::MulIOp>(loc, self, alpha);
return b.create<arith::SubIOp>(loc, other, mult);
}
rsub.emitError("unimplemented: dtype other than float and integer "
"types are not supported.");
return nullptr;
}
if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
Type dtype = converter->convertType(mulScalar.getType())

View File

@ -603,7 +603,7 @@ def ElementwiseClampMaxModule_basic(module, tu: TestUtils):
# ==============================================================================
class RsubModule(torch.nn.Module):
class RsubFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -617,15 +617,15 @@ class RsubModule(torch.nn.Module):
return torch.rsub(x, 3.0, alpha=1.0)
@register_test_case(module_factory=lambda: RsubModule())
def RsubModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: RsubFloatModule())
def RsubFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class RsubModule_noalpha(torch.nn.Module):
class RsubFloatModule_noalpha(torch.nn.Module):
def __init__(self):
super().__init__()
@ -639,14 +639,58 @@ class RsubModule_noalpha(torch.nn.Module):
return torch.rsub(x, 2.0)
@register_test_case(module_factory=lambda: RsubModule_noalpha())
def RsubModule_noalpha_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: RsubFloatModule_noalpha())
def RsubFloatModule_noalpha_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class RsubIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.rsub(x, 2, alpha=3)
@register_test_case(module_factory=lambda: RsubIntModule())
def RsubIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (3, 4)))
# ==============================================================================
class RsubIntModule_noalpha(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.rsub(x, 2.)
@register_test_case(module_factory=lambda: RsubIntModule_noalpha())
def RsubIntModule_noalpha_basic(module, tu: TestUtils):
module.forward(torch.randint(100, (3, 4)))
# ==============================================================================
class ElementwiseMulScalarIntModule(torch.nn.Module):
def __init__(self):