[Linalg] Refactor compare scalar op (#3294)

penguin_wwy 2024-05-09 10:40:19 +08:00 committed by GitHub
parent c4b28e8d9f
commit 0f0f57c960
1 changed file with 50 additions and 130 deletions

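For context on the change below: the new createCompareScalarOp helper folds the six nearly identical scalar-comparison code paths into a single template, selecting the comparison to emit at compile time via std::is_same and if constexpr. What follows is a minimal, self-contained sketch of that dispatch idiom only; the LtOp/GtOp/EqOp tag types and the compareScalar function are hypothetical stand-ins and are not part of this commit.

// Sketch of the compile-time dispatch idiom used by createCompareScalarOp.
#include <iostream>
#include <type_traits>

struct LtOp {}; // stands in for AtenLtScalarOp
struct GtOp {}; // stands in for AtenGtScalarOp
struct EqOp {}; // stands in for AtenEqScalarOp

template <typename OpTy>
bool compareScalar(int lhs, int rhs) {
  // Reject unsupported op types at compile time, mirroring the helper's
  // static_assert.
  static_assert(std::is_same<OpTy, LtOp>() || std::is_same<OpTy, GtOp>() ||
                    std::is_same<OpTy, EqOp>(),
                "unimplemented: op type not supported");
  // Each instantiation keeps only the branch whose condition is true; the
  // others are discarded, so no runtime dispatch is emitted.
  if constexpr (std::is_same<OpTy, LtOp>())
    return lhs < rhs;
  if constexpr (std::is_same<OpTy, GtOp>())
    return lhs > rhs;
  if constexpr (std::is_same<OpTy, EqOp>())
    return lhs == rhs;
  return false; // not reached: the static_assert limits OpTy to the cases above
}

int main() {
  std::cout << compareScalar<LtOp>(1, 2) << compareScalar<GtOp>(1, 2)
            << compareScalar<EqOp>(2, 2) << "\n"; // prints 101
  return 0;
}

With the templated helper in place, each call site in the elementwise payload builder reduces to a single return createCompareScalarOp(...) line, as the second hunk of the diff below shows.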
@@ -195,6 +195,50 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
llvm_unreachable("unimplemented: op type not supported");
}
template <typename OpTy>
static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
std::is_same<OpTy, AtenLeScalarOp>() ||
std::is_same<OpTy, AtenEqScalarOp>() ||
std::is_same<OpTy, AtenNeScalarOp>() ||
std::is_same<OpTy, AtenGtScalarOp>() ||
std::is_same<OpTy, AtenGeScalarOp>(),
"unimplemented: op type not supported");
Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();
Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
Value otherPromoted = convertScalarToDtype(b, loc, rhs, lhsDtype);
if (isa<mlir::IntegerType>(elementalType) &&
!isa<mlir::IntegerType>(rhsDtype)) {
// TODO: Promote tensor args from integer to float.
op.emitError("unimplemented: type promotion from tensor to scalar.");
return nullptr;
}
if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
return createLessThan(b, loc, elementalType, lhs, otherPromoted);
}
if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, otherPromoted);
}
if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, otherPromoted);
}
if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, otherPromoted);
}
if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
return createEqual(b, loc, elementalType, lhs, otherPromoted);
}
if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
return createNotEqual(b, loc, elementalType, lhs, otherPromoted);
}
llvm_unreachable("unimplemented: op type not supported");
}
template <arith::CmpIPredicate predicate>
static LogicalResult
createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
@@ -959,151 +1003,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
Type dtype = cast<BaseTensorType>(gtScalar.getSelf().getType()).getDtype();
// TODO: `gtTensor` and `gtScalar` share similar code and can be called from
// one static function.
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor args from integer to float.
gtScalar.emitError(
"unimplemented: type promotion from tensor to scalar.");
return nullptr;
}
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
payloadArgs[0], otherPromoted);
}
gtScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr;
return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
}
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
Type dtype = cast<BaseTensorType>(geScalar.getSelf().getType()).getDtype();
// TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that
// can be refactored.
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor args from integer to float.
geScalar.emitError(
"unimplemented: type promotion from tensor to scalar.");
return nullptr;
}
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
payloadArgs[0], otherPromoted);
}
geScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr;
return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
}
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
Type dtype = cast<BaseTensorType>(eqScalar.getSelf().getType()).getDtype();
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (isa<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
eqScalar.emitError(
"unimplemented: type promotion from tensor to scalar");
return nullptr;
}
}
return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
}
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
Type dtype = cast<BaseTensorType>(neScalar.getSelf().getType()).getDtype();
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (isa<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
neScalar.emitError(
"unimplemented: type promotion from tensor to scalar");
return nullptr;
}
}
return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
}
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype();
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
// a lot of code that can be refactored.
if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
ltScalar.emitError(
"unimplemented: type promotion from tensor to scalar");
return nullptr;
}
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
payloadArgs[0], otherPromoted);
}
ltScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr;
return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
}
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
Type dtype = cast<BaseTensorType>(leScalar.getSelf().getType()).getDtype();
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
// that can be refactored.
if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
payloadArgs[0], otherPromoted);
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
leScalar.emitError(
"unimplemented: type promotion from tensor to scalar");
return nullptr;
}
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
payloadArgs[0], otherPromoted);
}
leScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr;
return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
}
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {