diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 430541db7..51d2ac141 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -135,6 +135,7 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) { return buildNormalCdf(b, loc, x, zero, one); } + template static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter, Value payloadArg, Operation *op) { @@ -1051,28 +1052,35 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto ltScalar = dyn_cast(op)) { - Type dtype = cast(ltScalar.getSelf().getType()).getDtype(); - Value otherPromoted = - convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); - + Type dtype = ltScalar.getSelf().getType().cast().getDtype(); // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share // a lot of code that can be refactored. - if (isa(dtype)) + if (isa(dtype)){ + Value otherPromoted = + convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); + return b.create(loc, arith::CmpFPredicate::ULT, payloadArgs[0], otherPromoted); + } if (IntegerType intType = dyn_cast(dtype)) { if (!operands[1].getType().isa()) { // TODO: Promote tensor operand from integer to float. - ltScalar.emitError( - "unimplemented: type promotion from tensor to scalar"); - return nullptr; + Value payloadArgPromoted = convertScalarToDtype(b, loc, payloadArgs[0], operands[1].getType()); + Value otherPromoted = operands[1]; + return b.create(loc, arith::CmpFPredicate::ULT, + payloadArgPromoted, otherPromoted); } - if (intType.isUnsigned()) - return b.create(loc, arith::CmpIPredicate::ult, - payloadArgs[0], otherPromoted); - if (intType.isSigned()) - return b.create(loc, arith::CmpIPredicate::slt, - payloadArgs[0], otherPromoted); + else { + Value otherPromoted = + convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType()); + + if (intType.isUnsigned()) + return b.create(loc, arith::CmpIPredicate::ult, + payloadArgs[0], otherPromoted); + if (intType.isSigned()) + return b.create(loc, arith::CmpIPredicate::slt, + payloadArgs[0], otherPromoted); + } } ltScalar.emitError("unimplemented: dtype isn't supported."); return nullptr;