mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add integer dtype support for aten.rsub.Scalar op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
Branch: pull/947/head
parent b90837ee24
commit aed5517fda
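This change extends the linalg lowering of aten.rsub.Scalar, which computes
other - alpha * self (a subtraction with reversed operands), from
floating-point element types to integer element types. The existing float
e2e tests are renamed from RsubModule* to RsubFloatModule*, and new
RsubIntModule* tests cover the integer path. For reference, a minimal plain
PyTorch illustration of the integer semantics being added:

    import torch

    x = torch.tensor([[1, 2], [3, 4]])
    # rsub with integer scalars: other - alpha * self
    print(torch.rsub(x, 2, alpha=3))
    # tensor([[ -1,  -4],
    #         [ -7, -10]])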
@@ -63,8 +63,8 @@ TOSA_PASS_SET = {
     "MmDagModule_basic",
     "Matmul_dot",
     "Matmul_3d",
-    "RsubModule_basic",
-    "RsubModule_noalpha_basic",
+    "RsubFloatModule_basic",
+    "RsubFloatModule_noalpha_basic",
     "ElementwiseGtFloatScalarModule_basic",
     "ElementwiseGtIntScalarModule_basic",
     "ElementwiseGtMixed2ScalarModule_basic",
@@ -736,15 +736,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(rsub.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      rsub.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    Value self = payloadArgs[0];
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
     Value other = convertScalarToDtype(b, loc, operands[1], dtype);
     Value alpha = convertScalarToDtype(b, loc, operands[2], dtype);
-    Value mult = b.create<arith::MulFOp>(loc, self, alpha);
-    return b.create<arith::SubFOp>(loc, other, mult);
+    if (dtype.isa<mlir::FloatType>()) {
+      Value mult = b.create<arith::MulFOp>(loc, self, alpha);
+      return b.create<arith::SubFOp>(loc, other, mult);
+    } else if (dtype.isa<mlir::IntegerType>()) {
+      Value mult = b.create<arith::MulIOp>(loc, self, alpha);
+      return b.create<arith::SubIOp>(loc, other, mult);
+    }
+    rsub.emitError("unimplemented: dtype other than float and integer "
+                   "types are not supported.");
+    return nullptr;
   }
   if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
     Type dtype = converter->convertType(mulScalar.getType())
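As a reference for the branch above (an illustrative sketch, not torch-mlir
code; rsub_reference is a hypothetical helper), the payload now converts all
three values to the result element type and computes other - self * alpha,
choosing arith.mulf/arith.subf for float dtypes and arith.muli/arith.subi
for integer dtypes:

    import torch

    # Hypothetical reference, assuming the result dtype matches the input
    # dtype; mirrors convertScalarToDtype by casting the scalars up front.
    def rsub_reference(self_val: torch.Tensor, other, alpha):
        dtype = self_val.dtype
        other_c = torch.tensor(other, dtype=dtype)
        alpha_c = torch.tensor(alpha, dtype=dtype)
        # other - self * alpha, with mul/sub picked by element type
        # (float -> mulf/subf, integer -> muli/subi in the lowered IR).
        return other_c - self_val * alpha_c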
@@ -603,7 +603,7 @@ def ElementwiseClampMaxModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class RsubModule(torch.nn.Module):
+class RsubFloatModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -617,15 +617,15 @@ class RsubModule(torch.nn.Module):
         return torch.rsub(x, 3.0, alpha=1.0)
 
 
-@register_test_case(module_factory=lambda: RsubModule())
-def RsubModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: RsubFloatModule())
+def RsubFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
 
 # ==============================================================================
 
 
-class RsubModule_noalpha(torch.nn.Module):
+class RsubFloatModule_noalpha(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
@@ -639,14 +639,58 @@ class RsubModule_noalpha(torch.nn.Module):
         return torch.rsub(x, 2.0)
 
 
-@register_test_case(module_factory=lambda: RsubModule_noalpha())
-def RsubModule_noalpha_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: RsubFloatModule_noalpha())
+def RsubFloatModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
 
+# ==============================================================================
+
+
+class RsubIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.rsub(x, 2, alpha=3)
+
+
+@register_test_case(module_factory=lambda: RsubIntModule())
+def RsubIntModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(100, (3, 4)))
+
+
+# ==============================================================================
+
+
+class RsubIntModule_noalpha(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.rsub(x, 2.)
+
+
+@register_test_case(module_factory=lambda: RsubIntModule_noalpha())
+def RsubIntModule_noalpha_basic(module, tu: TestUtils):
+    module.forward(torch.randint(100, (3, 4)))
+
+
 # ==============================================================================
 
 
 class ElementwiseMulScalarIntModule(torch.nn.Module):
 
     def __init__(self):
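The new integer tests assert the same semantics end to end. A quick sanity
check of what they compute, in plain PyTorch (note that rsub of an integer
tensor with a Python float scalar promotes the result to the default float
dtype):

    import torch

    x = torch.randint(100, (3, 4))
    # RsubIntModule: integer scalars keep the result integer.
    assert torch.equal(torch.rsub(x, 2, alpha=3), 2 - 3 * x)
    # RsubIntModule_noalpha: the float scalar 2. promotes to float32.
    assert torch.rsub(x, 2.).dtype == torch.float32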