mirror of https://github.com/llvm/torch-mlir
fix lt for tensor arg
parent
db6721084a
commit
85801144ca
|
@ -135,6 +135,7 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
|
|||
return buildNormalCdf(b, loc, x, zero, one);
|
||||
}
|
||||
|
||||
|
||||
template <typename MathOpTy>
|
||||
static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
|
||||
Value payloadArg, Operation *op) {
|
||||
|
@ -1051,22 +1052,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
|
||||
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
||||
Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype();
|
||||
Type dtype = ltScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
|
||||
// a lot of code that can be refactored.
|
||||
if (isa<mlir::FloatType>(dtype)){
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
|
||||
// a lot of code that can be refactored.
|
||||
if (isa<mlir::FloatType>(dtype))
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
payloadArgs[0], otherPromoted);
|
||||
}
|
||||
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||
// TODO: Promote tensor operand from integer to float.
|
||||
ltScalar.emitError(
|
||||
"unimplemented: type promotion from tensor to scalar");
|
||||
return nullptr;
|
||||
Value payloadArgPromoted = convertScalarToDtype(b, loc, payloadArgs[0], operands[1].getType());
|
||||
Value otherPromoted = operands[1];
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
payloadArgPromoted, otherPromoted);
|
||||
}
|
||||
else {
|
||||
Value otherPromoted =
|
||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
||||
payloadArgs[0], otherPromoted);
|
||||
|
@ -1074,6 +1081,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
payloadArgs[0], otherPromoted);
|
||||
}
|
||||
}
|
||||
ltScalar.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue