mirror of https://github.com/llvm/torch-mlir
fix lt for tensor arg
parent
db6721084a
commit
85801144ca
|
@ -135,6 +135,7 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
|
||||||
return buildNormalCdf(b, loc, x, zero, one);
|
return buildNormalCdf(b, loc, x, zero, one);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename MathOpTy>
|
template <typename MathOpTy>
|
||||||
static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
|
static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
|
||||||
Value payloadArg, Operation *op) {
|
Value payloadArg, Operation *op) {
|
||||||
|
@ -1051,22 +1052,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
||||||
Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype();
|
Type dtype = ltScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
||||||
|
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
|
||||||
|
// a lot of code that can be refactored.
|
||||||
|
if (isa<mlir::FloatType>(dtype)){
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
// TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
|
|
||||||
// a lot of code that can be refactored.
|
|
||||||
if (isa<mlir::FloatType>(dtype))
|
|
||||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||||
payloadArgs[0], otherPromoted);
|
payloadArgs[0], otherPromoted);
|
||||||
|
}
|
||||||
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
|
||||||
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
if (!operands[1].getType().isa<mlir::IntegerType>()) {
|
||||||
// TODO: Promote tensor operand from integer to float.
|
// TODO: Promote tensor operand from integer to float.
|
||||||
ltScalar.emitError(
|
Value payloadArgPromoted = convertScalarToDtype(b, loc, payloadArgs[0], operands[1].getType());
|
||||||
"unimplemented: type promotion from tensor to scalar");
|
Value otherPromoted = operands[1];
|
||||||
return nullptr;
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||||
|
payloadArgPromoted, otherPromoted);
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
Value otherPromoted =
|
||||||
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
if (intType.isUnsigned())
|
if (intType.isUnsigned())
|
||||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
||||||
payloadArgs[0], otherPromoted);
|
payloadArgs[0], otherPromoted);
|
||||||
|
@ -1074,6 +1081,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||||
payloadArgs[0], otherPromoted);
|
payloadArgs[0], otherPromoted);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
ltScalar.emitError("unimplemented: dtype isn't supported.");
|
ltScalar.emitError("unimplemented: dtype isn't supported.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue