mirror of https://github.com/llvm/torch-mlir
parent
c9c9b68d1f
commit
b0cb49ca93
|
@ -784,6 +784,7 @@ def AddCDivModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class DropoutModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -363,6 +363,8 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMulScalarModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -379,6 +381,51 @@ class ElementwiseMulScalarModule(torch.nn.Module):
|
|||
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
|
||||
class ElementwiseMulTensorFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.mul(a, b)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseMulTensorFloatModule())
|
||||
def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(4),
|
||||
tu.rand(4).type(torch.float64))
|
||||
|
||||
class ElementwiseMulTensorIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.mul(a, b)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseMulTensorIntModule())
|
||||
def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
torch.randint(10, [4]).type(torch.int32),
|
||||
torch.randint(10, [4]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class ElementwiseLogModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -553,7 +600,32 @@ class ElementwiseDivScalarModule(torch.nn.Module):
|
|||
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
class ElementwiseDivTensorFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivTensorFloatModule())
|
||||
def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(4),
|
||||
tu.rand(4).type(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAndIntegerModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -573,3 +645,5 @@ class ElementwiseAndIntegerModule(torch.nn.Module):
|
|||
def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
|
||||
torch.randint(-10, 10, (3, 4)))
|
||||
|
||||
|
||||
|
|
|
@ -111,3 +111,5 @@ class TypePromotionAlphaWiderModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
|
||||
def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand())
|
||||
|
||||
|
||||
|
|
|
@ -1531,24 +1531,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
}
|
||||
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
|
||||
if (!mul.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
mul.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
AtenMulTensorOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(mul.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
return b.create<arith::MulFOp>(loc, lhs, rhs);
|
||||
} else {
|
||||
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
||||
}
|
||||
return b.create<arith::MulFOp>(loc, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
||||
if (!div.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
AtenDivTensorOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(div.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType>())
|
||||
div.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
return b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1]);
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
return b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
}
|
||||
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
|
||||
if (!pow.getType()
|
||||
|
|
Loading…
Reference in New Issue