mirror of https://github.com/llvm/torch-mlir
parent
c9c9b68d1f
commit
b0cb49ca93
|
@ -784,6 +784,7 @@ def AddCDivModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class DropoutModule(torch.nn.Module):
|
class DropoutModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -363,6 +363,8 @@ def RsubModule_noalpha_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseMulScalarModule(torch.nn.Module):
|
class ElementwiseMulScalarModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -379,6 +381,51 @@ class ElementwiseMulScalarModule(torch.nn.Module):
|
||||||
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseMulTensorFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
([-1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.mul(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseMulTensorFloatModule())
|
||||||
|
def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.rand(4),
|
||||||
|
tu.rand(4).type(torch.float64))
|
||||||
|
|
||||||
|
class ElementwiseMulTensorIntModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int32, True),
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.mul(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseMulTensorIntModule())
|
||||||
|
def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
torch.randint(10, [4]).type(torch.int32),
|
||||||
|
torch.randint(10, [4]))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
class ElementwiseLogModule(torch.nn.Module):
|
class ElementwiseLogModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -553,7 +600,32 @@ class ElementwiseDivScalarModule(torch.nn.Module):
|
||||||
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseDivTensorFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
([-1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.div(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(
|
||||||
|
module_factory=lambda: ElementwiseDivTensorFloatModule())
|
||||||
|
def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.rand(4),
|
||||||
|
tu.rand(4).type(torch.float64))
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseAndIntegerModule(torch.nn.Module):
|
class ElementwiseAndIntegerModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -573,3 +645,5 @@ class ElementwiseAndIntegerModule(torch.nn.Module):
|
||||||
def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
|
def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
|
module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
|
||||||
torch.randint(-10, 10, (3, 4)))
|
torch.randint(-10, 10, (3, 4)))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -111,3 +111,5 @@ class TypePromotionAlphaWiderModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
|
@register_test_case(module_factory=lambda: TypePromotionAlphaWiderModule())
|
||||||
def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
|
def TypePromotionAlphaWiderModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(4), tu.rand())
|
module.forward(tu.rand(4), tu.rand())
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1531,24 +1531,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
|
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
|
||||||
if (!mul.getType()
|
AtenMulTensorOp::Adaptor adaptor(operands);
|
||||||
.cast<ValueTensorType>()
|
Type dtype = converter->convertType(mul.getType())
|
||||||
.getDtype()
|
.cast<RankedTensorType>()
|
||||||
.isa<mlir::FloatType>()) {
|
.getElementType();
|
||||||
mul.emitError("unimplemented: non-floating point dtype");
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
return nullptr;
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
if (dtype.isa<mlir::FloatType>()) {
|
||||||
|
return b.create<arith::MulFOp>(loc, lhs, rhs);
|
||||||
|
} else {
|
||||||
|
return b.create<arith::MulIOp>(loc, lhs, rhs);
|
||||||
}
|
}
|
||||||
return b.create<arith::MulFOp>(loc, payloadArgs[0], payloadArgs[1]);
|
|
||||||
}
|
}
|
||||||
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
||||||
if (!div.getType()
|
AtenDivTensorOp::Adaptor adaptor(operands);
|
||||||
.cast<ValueTensorType>()
|
Type dtype = converter->convertType(div.getType())
|
||||||
.getDtype()
|
.cast<RankedTensorType>()
|
||||||
.isa<mlir::FloatType>()) {
|
.getElementType();
|
||||||
|
if (!dtype.isa<mlir::FloatType>())
|
||||||
div.emitError("unimplemented: non-floating point dtype");
|
div.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
}
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
return b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1]);
|
return b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||||
}
|
}
|
||||||
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
|
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
|
||||||
if (!pow.getType()
|
if (!pow.getType()
|
||||||
|
|
Loading…
Reference in New Issue