mirror of https://github.com/llvm/torch-mlir
[Linalg] Refactor compare scalar op (#3294)
parent
c4b28e8d9f
commit
0f0f57c960
|
@ -195,6 +195,50 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
|
|||
llvm_unreachable("unimplemented: op type not supported");
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
|
||||
Value lhs, Value rhs) {
|
||||
static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
|
||||
std::is_same<OpTy, AtenLeScalarOp>() ||
|
||||
std::is_same<OpTy, AtenEqScalarOp>() ||
|
||||
std::is_same<OpTy, AtenNeScalarOp>() ||
|
||||
std::is_same<OpTy, AtenGtScalarOp>() ||
|
||||
std::is_same<OpTy, AtenGeScalarOp>(),
|
||||
"unimplemented: op type not supported");
|
||||
|
||||
Type lhsDtype = lhs.getType();
|
||||
Type rhsDtype = rhs.getType();
|
||||
Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
|
||||
Value otherPromoted = convertScalarToDtype(b, loc, rhs, lhsDtype);
|
||||
|
||||
if (isa<mlir::IntegerType>(elementalType) &&
|
||||
!isa<mlir::IntegerType>(rhsDtype)) {
|
||||
// TODO: Promote tensor args from integer to float.
|
||||
op.emitError("unimplemented: type promotion from tensor to scalar.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
|
||||
return createLessThan(b, loc, elementalType, lhs, otherPromoted);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
|
||||
return createLessThanOrEqual(b, loc, elementalType, lhs, otherPromoted);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
|
||||
return createGreaterThan(b, loc, elementalType, lhs, otherPromoted);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
|
||||
return createGreaterThanOrEqual(b, loc, elementalType, lhs, otherPromoted);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
|
||||
return createEqual(b, loc, elementalType, lhs, otherPromoted);
|
||||
}
|
||||
if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
|
||||
return createNotEqual(b, loc, elementalType, lhs, otherPromoted);
|
||||
}
|
||||
llvm_unreachable("unimplemented: op type not supported");
|
||||
}
|
||||
|
||||
template <arith::CmpIPredicate predicate>
|
||||
static LogicalResult
|
||||
createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
|
||||
|
@ -959,151 +1003,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
|
||||
Type dtype = cast<BaseTensorType>(gtScalar.getSelf().getType()).getDtype();
|
||||
|
||||
// TODO: `gtTensor` and `gtScalar` share similar code and can be called from
|
||||
// one static function.
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
if (isa<mlir::FloatType>(dtype))
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor args from integer to float.
|
||||
gtScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (intType.isSigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
|
||||
payloadArgs[0], otherPromoted);
|
||||
}
|
||||
gtScalar.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
|
||||
Type dtype = cast<BaseTensorType>(geScalar.getSelf().getType()).getDtype();
|
||||
|
||||
// TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that
|
||||
// can be refactored.
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
if (isa<mlir::FloatType>(dtype))
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor args from integer to float.
|
||||
geScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (intType.isSigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
|
||||
payloadArgs[0], otherPromoted);
|
||||
}
|
||||
geScalar.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
|
||||
Type dtype = cast<BaseTensorType>(eqScalar.getSelf().getType()).getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
if (isa<mlir::IntegerType>(dtype)) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor operand from integer to float.
|
||||
eqScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
|
||||
return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
|
||||
Type dtype = cast<BaseTensorType>(neScalar.getSelf().getType()).getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
if (isa<mlir::IntegerType>(dtype)) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor operand from integer to float.
|
||||
neScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
|
||||
return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
||||
Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
|
||||
// a lot of code that can be refactored.
|
||||
if (isa<mlir::FloatType>(dtype))
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor operand from integer to float.
|
||||
ltScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar");
|
||||
return nullptr;
|
||||
}
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (intType.isSigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
payloadArgs[0], otherPromoted);
|
||||
}
|
||||
ltScalar.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
|
||||
Type dtype = cast<BaseTensorType>(leScalar.getSelf().getType()).getDtype();
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
// TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
|
||||
// that can be refactored.
|
||||
if (isa<mlir::FloatType>(dtype))
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor operand from integer to float.
|
||||
leScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar");
|
||||
return nullptr;
|
||||
}
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
|
||||
payloadArgs[0], otherPromoted);
|
||||
if (intType.isSigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
|
||||
payloadArgs[0], otherPromoted);
|
||||
}
|
||||
leScalar.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
|
||||
}
|
||||
|
||||
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
|
||||
|
|
Loading…
Reference in New Issue