fix lt for tensor arg

fix_lt
dan 2024-05-01 19:17:43 +00:00
parent db6721084a
commit 85801144ca
1 changed file with 22 additions and 14 deletions

View File

@ -135,6 +135,7 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
return buildNormalCdf(b, loc, x, zero, one); return buildNormalCdf(b, loc, x, zero, one);
} }
template <typename MathOpTy> template <typename MathOpTy>
static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter, static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
Value payloadArg, Operation *op) { Value payloadArg, Operation *op) {
@ -1051,22 +1052,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
} }
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) { if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype(); Type dtype = ltScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
// a lot of code that can be refactored.
if (isa<mlir::FloatType>(dtype)){
Value otherPromoted = Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
// a lot of code that can be refactored.
if (isa<mlir::FloatType>(dtype))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgs[0], otherPromoted); payloadArgs[0], otherPromoted);
}
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) { if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
if (!operands[1].getType().isa<mlir::IntegerType>()) { if (!operands[1].getType().isa<mlir::IntegerType>()) {
// TODO: Promote tensor operand from integer to float. // TODO: Promote tensor operand from integer to float.
ltScalar.emitError( Value payloadArgPromoted = convertScalarToDtype(b, loc, payloadArgs[0], operands[1].getType());
"unimplemented: type promotion from tensor to scalar"); Value otherPromoted = operands[1];
return nullptr; return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
payloadArgPromoted, otherPromoted);
} }
else {
Value otherPromoted =
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
if (intType.isUnsigned()) if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
payloadArgs[0], otherPromoted); payloadArgs[0], otherPromoted);
@ -1074,6 +1081,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
payloadArgs[0], otherPromoted); payloadArgs[0], otherPromoted);
} }
}
ltScalar.emitError("unimplemented: dtype isn't supported."); ltScalar.emitError("unimplemented: dtype isn't supported.");
return nullptr; return nullptr;
} }