mirror of https://github.com/llvm/torch-mlir
[Linalg] Promote type for compare tensor op (#3416)
parent
661be2d5b0
commit
d59d0b6e5a
|
@ -149,58 +149,17 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
|
|||
return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
|
||||
Value lhs, Value rhs) {
|
||||
static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
|
||||
std::is_same<OpTy, AtenLeTensorOp>() ||
|
||||
std::is_same<OpTy, AtenGtTensorOp>() ||
|
||||
std::is_same<OpTy, AtenGeTensorOp>() ||
|
||||
std::is_same<OpTy, AtenEqTensorOp>() ||
|
||||
std::is_same<OpTy, AtenNeTensorOp>(),
|
||||
"unimplemented: op type not supported");
|
||||
|
||||
Type lhsDtype = lhs.getType();
|
||||
Type rhsDtype = rhs.getType();
|
||||
|
||||
// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
|
||||
// to be handled.
|
||||
if (lhsDtype != rhsDtype) {
|
||||
op.emitError("unimplemented: lhs and rhs dtype must be same");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
|
||||
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
|
||||
return createLessThan(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
|
||||
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
|
||||
return createGreaterThan(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
|
||||
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
|
||||
return createEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenNeTensorOp>()) {
|
||||
return createNotEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
llvm_unreachable("unimplemented: op type not supported");
|
||||
}
|
||||
template <class T, class... Ts>
|
||||
struct is_any_same : std::disjunction<std::is_same<T, Ts>...> {};
|
||||
|
||||
template <typename OpTy>
|
||||
static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
|
||||
Value lhs, Value rhs) {
|
||||
static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
|
||||
std::is_same<OpTy, AtenLeScalarOp>() ||
|
||||
std::is_same<OpTy, AtenEqScalarOp>() ||
|
||||
std::is_same<OpTy, AtenNeScalarOp>() ||
|
||||
std::is_same<OpTy, AtenGtScalarOp>() ||
|
||||
std::is_same<OpTy, AtenGeScalarOp>(),
|
||||
static Value createCompareOp(OpBuilder &b, Location loc, OpTy op, Value lhs,
|
||||
Value rhs) {
|
||||
static_assert(
|
||||
is_any_same<OpTy, AtenLtScalarOp, AtenLeScalarOp, AtenEqScalarOp,
|
||||
AtenNeScalarOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenLtTensorOp, AtenLeTensorOp, AtenGtTensorOp,
|
||||
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp>(),
|
||||
"unimplemented: op type not supported");
|
||||
|
||||
Type lhsDtype = lhs.getType();
|
||||
|
@ -229,22 +188,22 @@ static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
|
||||
if constexpr (is_any_same<OpTy, AtenLtScalarOp, AtenLtTensorOp>()) {
|
||||
return createLessThan(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
|
||||
if constexpr (is_any_same<OpTy, AtenLeScalarOp, AtenLeTensorOp>()) {
|
||||
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
|
||||
if constexpr (is_any_same<OpTy, AtenGtScalarOp, AtenGtTensorOp>()) {
|
||||
return createGreaterThan(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
|
||||
if constexpr (is_any_same<OpTy, AtenGeScalarOp, AtenGeTensorOp>()) {
|
||||
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
|
||||
if constexpr (is_any_same<OpTy, AtenEqScalarOp, AtenEqTensorOp>()) {
|
||||
return createEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
|
||||
if constexpr (is_any_same<OpTy, AtenNeScalarOp, AtenNeTensorOp>()) {
|
||||
return createNotEqual(b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
llvm_unreachable("unimplemented: op type not supported");
|
||||
|
@ -892,28 +851,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<math::Atan2Op>(loc, lhs, rhs);
|
||||
}
|
||||
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
|
||||
return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
return createCompareOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
|
||||
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
return createCompareOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
|
||||
return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
return createCompareOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
|
||||
return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
return createCompareOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
|
||||
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
return createCompareOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto neTensor = dyn_cast<AtenNeTensorOp>(op)) {
|
||||
return createCompareTensorOp(b, loc, neTensor, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
return createCompareOp(b, loc, neTensor, payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
||||
AtenDivTensorOp::Adaptor adaptor(operands);
|
||||
|
@ -996,27 +949,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
|
||||
return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
|
||||
return createCompareOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
|
||||
return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
|
||||
return createCompareOp(b, loc, geScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
|
||||
return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
|
||||
return createCompareOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
|
||||
return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
|
||||
return createCompareOp(b, loc, neScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
||||
return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
|
||||
return createCompareOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
|
||||
return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
|
||||
return createCompareOp(b, loc, leScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
|
||||
|
|
|
@ -27,6 +27,7 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
|||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||
}
|
||||
|
||||
LINALG_CRASHING_SET = {
|
||||
|
@ -2707,6 +2708,7 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseTanIntModule_basic",
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
"ElementwiseUnaryIntModule_basic",
|
||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||
"MaskedFillTensorFloatValueModule_basic",
|
||||
"NativeDropoutTrainModule_basic",
|
||||
"NativeDropoutTrainStaticShapeModule_basic",
|
||||
|
@ -3786,6 +3788,7 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"ElementwiseExpm1IntModule_basic",
|
||||
"ElementwiseExpm1Module_basic",
|
||||
"ElementwiseFlattenBroadcastModule_basic",
|
||||
"ElementwiseFloatTensorGtIntTensorModule_basic",
|
||||
"ElementwiseFmodTensor_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_Float_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
|
|
|
@ -599,6 +599,51 @@ def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))
|
||||
|
||||
|
||||
class ElementwiseIntTensorLtFloatTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.float64, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule())
|
||||
def ElementwiseIntTensorLtFloatTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 5, high=10), tu.rand(5, high=10).to(torch.float64))
|
||||
|
||||
|
||||
class ElementwiseFloatTensorGtIntTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule())
|
||||
def ElementwiseFloatTensorGtIntTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.rand(3, 5, high=10).to(torch.float32),
|
||||
tu.randint(5, high=10, dtype=torch.int32),
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue