Add scalar type promotion for mul and div (#454)

pull/452/head snapshot-20211203.122
Daniel Garvey 2021-12-03 13:51:25 -06:00 committed by GitHub
parent c9c9b68d1f
commit b0cb49ca93
4 changed files with 96 additions and 15 deletions
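For context, the new e2e tests below exercise PyTorch's usual result-type promotion for elementwise `mul` and `div` when the two operands have different dtypes. A minimal sketch of the behavior the tests expect, in plain PyTorch (illustrative only; the dtype pairs are taken from the test cases added in this commit):

```python
import torch

# float32 * float64 promotes to float64 (ElementwiseMulTensorFloatModule).
a = torch.rand(4)                    # torch.float32
b = torch.rand(4).to(torch.float64)  # torch.float64
assert torch.mul(a, b).dtype == torch.float64

# int32 * int64 promotes to int64 (ElementwiseMulTensorIntModule).
i = torch.randint(10, [4], dtype=torch.int32)
j = torch.randint(10, [4])           # torch.int64 by default
assert torch.mul(i, j).dtype == torch.int64

# float32 / float64 promotes to float64 (ElementwiseDivTensorFloatModule).
assert torch.div(a, b).dtype == torch.float64
```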

@@ -784,6 +784,7 @@ def AddCDivModule_basic(module, tu: TestUtils):
 # ==============================================================================
 class DropoutModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -363,6 +363,8 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 # ==============================================================================
 class ElementwiseMulScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -379,6 +381,51 @@ class ElementwiseMulScalarModule(torch.nn.Module):
 def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
+
+
+class ElementwiseMulTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.mul(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseMulTensorFloatModule())
+def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(4),
+        tu.rand(4).type(torch.float64))
+
+
+class ElementwiseMulTensorIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int32, True),
+        ([-1], torch.int64, True),
+    ])
+    def forward(self, a, b):
+        return torch.mul(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseMulTensorIntModule())
+def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.randint(10, [4]).type(torch.int32),
+        torch.randint(10, [4]))
+
+
 # ==============================================================================
 class ElementwiseLogModule(torch.nn.Module):
     def __init__(self):
@@ -553,7 +600,32 @@ class ElementwiseDivScalarModule(torch.nn.Module):
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
+
+
+class ElementwiseDivTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseDivTensorFloatModule())
+def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(4),
+        tu.rand(4).type(torch.float64))
+
+
 # ==============================================================================
 class ElementwiseAndIntegerModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -573,3 +645,5 @@ class ElementwiseAndIntegerModule(torch.nn.Module):
 def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
                    torch.randint(-10, 10, (3, 4)))

@@ -111,3 +111,5 @@ class TypePromotionAlphaWiderModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
 def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand())

@@ -1531,24 +1531,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
   }
   if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
-    if (!mul.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
-      mul.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
+    AtenMulTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(mul.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    if (dtype.isa<mlir::FloatType>()) {
+      return b.create<arith::MulFOp>(loc, lhs, rhs);
+    } else {
+      return b.create<arith::MulIOp>(loc, lhs, rhs);
     }
-    return b.create<arith::MulFOp>(loc, payloadArgs[0], payloadArgs[1]);
   }
   if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
-    if (!div.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
-      div.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
-    return b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1]);
+    AtenDivTensorOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(div.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::FloatType>())
+      div.emitError("unimplemented: non-floating point dtype");
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<arith::DivFOp>(loc, lhs, rhs);
   }
   if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
     if (!pow.getType()
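One plausible reason (my reading, not stated in the diff) why only the mul path gains an integer branch (arith::MulIOp) while div keeps a float-only DivFOp: torch.div is true division, so even integer inputs promote to a floating result dtype, whereas mul keeps an integer result dtype. A quick illustration of that standard PyTorch behavior:

```python
import torch

x = torch.randint(10, [4])      # torch.int64
y = torch.randint(1, 10, [4])   # torch.int64, non-zero divisor

# mul preserves the integer dtype, so the lowering needs an integer op.
assert torch.mul(x, y).dtype == torch.int64

# div is true division: the result dtype is floating even for int inputs,
# so a float-only DivFOp path (after dtype promotion) covers this case.
assert torch.div(x, y).dtype == torch.float32
```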