[Linalg] Promote type for compare tensor op (#3416)

pull/3421/head
penguin_wwy 2024-06-05 07:05:39 +08:00 committed by GitHub
parent 661be2d5b0
commit d59d0b6e5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 76 additions and 75 deletions

View File

@ -149,59 +149,18 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
}
template <typename OpTy>
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
std::is_same<OpTy, AtenLeTensorOp>() ||
std::is_same<OpTy, AtenGtTensorOp>() ||
std::is_same<OpTy, AtenGeTensorOp>() ||
std::is_same<OpTy, AtenEqTensorOp>() ||
std::is_same<OpTy, AtenNeTensorOp>(),
"unimplemented: op type not supported");
Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
op.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}
Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenNeTensorOp>()) {
return createNotEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
}
template <class T, class... Ts>
struct is_any_same : std::disjunction<std::is_same<T, Ts>...> {};
template <typename OpTy>
static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
std::is_same<OpTy, AtenLeScalarOp>() ||
std::is_same<OpTy, AtenEqScalarOp>() ||
std::is_same<OpTy, AtenNeScalarOp>() ||
std::is_same<OpTy, AtenGtScalarOp>() ||
std::is_same<OpTy, AtenGeScalarOp>(),
"unimplemented: op type not supported");
static Value createCompareOp(OpBuilder &b, Location loc, OpTy op, Value lhs,
Value rhs) {
static_assert(
is_any_same<OpTy, AtenLtScalarOp, AtenLeScalarOp, AtenEqScalarOp,
AtenNeScalarOp, AtenGtScalarOp, AtenGeScalarOp,
AtenLtTensorOp, AtenLeTensorOp, AtenGtTensorOp,
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp>(),
"unimplemented: op type not supported");
Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();
@ -229,22 +188,22 @@ static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
return nullptr;
}
if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenLtScalarOp, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenLeScalarOp, AtenLeTensorOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenGtScalarOp, AtenGtTensorOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenGeScalarOp, AtenGeTensorOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenEqScalarOp, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenNeScalarOp, AtenNeTensorOp>()) {
return createNotEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
@ -892,28 +851,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::Atan2Op>(loc, lhs, rhs);
}
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto neTensor = dyn_cast<AtenNeTensorOp>(op)) {
return createCompareTensorOp(b, loc, neTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, neTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
AtenDivTensorOp::Adaptor adaptor(operands);
@ -996,27 +949,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
}
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, geScalar, payloadArgs[0], operands[1]);
}
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
}
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, neScalar, payloadArgs[0], operands[1]);
}
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
}
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, leScalar, payloadArgs[0], operands[1]);
}
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {

View File

@ -27,6 +27,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"InterpolateDynamicModule_scales_recompute_bilinear",
"ElementwiseFloatTensorGtIntTensorModule_basic",
}
LINALG_CRASHING_SET = {
@ -2707,6 +2708,7 @@ ONNX_XFAIL_SET = {
"ElementwiseTanIntModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
@ -3786,6 +3788,7 @@ ONNX_TOSA_XFAIL_SET = {
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",

View File

@ -599,6 +599,51 @@ def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
class ElementwiseIntTensorLtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1], torch.int64, True),
([-1], torch.float64, True),
]
)
def forward(self, x, y):
return torch.lt(x, y)
@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule())
def ElementwiseIntTensorLtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.rand(5, high=10).to(torch.float64))
class ElementwiseFloatTensorGtIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
([-1], torch.int32, True),
]
)
def forward(self, x, y):
return torch.gt(x, y)
@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule())
def ElementwiseFloatTensorGtIntTensorModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(3, 5, high=10).to(torch.float32),
tu.randint(5, high=10, dtype=torch.int32),
)
# ==============================================================================