fix lt for tensor arg

fix_lt
dan 2024-05-01 19:17:43 +00:00
parent db6721084a
commit 85801144ca
1 changed files with 22 additions and 14 deletions

View File

@ -135,6 +135,7 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
return buildNormalCdf(b, loc, x, zero, one);
}
template <typename MathOpTy>
static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
Value payloadArg, Operation *op) {
@ -1051,28 +1052,35 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype();
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
Type dtype = ltScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
// a lot of code that can be refactored.
if (isa<mlir::FloatType>(dtype))
if (isa<mlir::FloatType>(dtype)){
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgs[0], otherPromoted);
}
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float.
ltScalar.emitError(
"unimplemented: type promotion from tensor to scalar");
return nullptr;
Value payloadArgPromoted = convertScalarToDtype(b, loc, payloadArgs[0], operands[1].getType());
Value otherPromoted = operands[1];
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgPromoted, otherPromoted);
}
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
payloadArgs[0], otherPromoted);
else {
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
payloadArgs[0], otherPromoted);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
payloadArgs[0], otherPromoted);
}
}
ltScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr;