[onnx] Fix `onnx.sigmoid` for integer inputs/outputs (#2914)

A sample compilation crashes due to `sigmoid` with integer inputs/outputs. This fix avoids the crash, but the case still fails: `ElementwiseSigmoidIntModule_basic` moves from the ONNX crashing set to the expected-failure set.
Rob Suderman 2024-02-16 13:35:25 -08:00 committed by GitHub
parent 7a0d0e954b
commit d65925a8b4
3 changed files with 56 additions and 59 deletions
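
For context (this sketch is not part of the patch): the new lowering performs the floating-point math in f32 whenever the converted element type is an integer, then converts the result back to the integer result dtype. Below is a minimal, standalone C++ illustration of that numeric contract for sigmoid; the names and values are purely illustrative.

// Plain C++, not MLIR: promote the integer operand to f32, evaluate sigmoid
// in floating point, then convert back to the integer result dtype, mirroring
// what createFpOpWithDtype and the AtenSigmoidOp payload in the diff do
// element-wise.
#include <cmath>
#include <cstdint>
#include <cstdio>

static int32_t sigmoidAsI32(int32_t x) {
  float xf = static_cast<float>(x);        // integer input dtype -> f32 compute type
  float y = 1.0f / (1.0f + std::exp(-xf)); // sigmoid evaluated in f32
  return static_cast<int32_t>(y);          // f32 -> integer result dtype
}

int main() {
  for (int32_t x : {-3, 0, 3, 20})
    std::printf("sigmoid(%d) as i32 = %d\n", static_cast<int>(x),
                static_cast<int>(sigmoidAsI32(x)));
  // Converting back to i32 collapses nearly every value to 0 or 1; the e2e
  // test for this case still fails (the xfail comment below says
  // "onnx traces differently").
  return 0;
}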


@@ -1615,9 +1615,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
         for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
-          if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) {
+          if (operandTy.getDimSize(i) != resultTy.getDimSize(i))
             intermediateShape[i] = -1;
-          }
+          if (intermediateShape[i] == ShapedType::kDynamic)
+            intermediateShape[i] = Torch::kUnknownSize;
         }
         auto intermediateType = Torch::ValueTensorType::get(
             context, intermediateShape, resultTorchType.getOptionalDtype());


@@ -128,16 +128,20 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
 }
 
 template <typename MathOpTy>
-static Value
-createCalculationForMathOpWithDtypeConversion(OpBuilder &b,
-                                              const TypeConverter *converter,
-                                              Value payloadArg, Operation *op) {
-  Type dtype = converter->convertType(op->getResult(0).getType())
-                   .template cast<RankedTensorType>()
-                   .getElementType();
+static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
+                                 Value payloadArg, Operation *op) {
+  Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
+  Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
+  Type outTy =
+      cast<RankedTensorType>(converter->convertType(op->getResult(0).getType()))
+          .getElementType();
+  Type computeTy = outTy;
+  if (isa<IntegerType>(computeTy))
+    computeTy = b.getF32Type();
   Location loc = op->getLoc();
-  Value arg = convertScalarToDtype(b, loc, payloadArg, dtype);
-  return b.create<MathOpTy>(loc, arg);
+  Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy);
+  auto newOp = b.create<MathOpTy>(loc, arg);
+  return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
 }
 
 template <typename OpTy>
@@ -217,92 +221,70 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   if (isa<AtenCeilOp>(op))
     return b.create<math::CeilOp>(loc, payloadArgs[0]);
   if (isa<AtenExpOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::ExpOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::ExpOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenExpm1Op>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::ExpM1Op>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::ExpM1Op>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenLogOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::LogOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::LogOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenLog2Op>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::Log2Op>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenLog10Op>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::Log10Op>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::Log10Op>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenLog1pOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::Log1pOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::Log1pOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenErfOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::ErfOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::ErfOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenSqrtOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::SqrtOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::SqrtOp>(b, converter, payloadArgs[0], op);
  }
   if (isa<AtenRsqrtOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::RsqrtOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::RsqrtOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenNegOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<arith::NegFOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenSinOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::SinOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::SinOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenSinhOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::SinhOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::SinhOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAsinOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AsinOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AsinOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAsinhOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AsinhOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AsinhOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenCosOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::CosOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::CosOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenCoshOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::CoshOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::CoshOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAcosOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AcosOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AcosOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAcoshOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AcoshOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AcoshOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenTanOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::TanOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::TanOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenTanhOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::TanhOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAtanOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AtanOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AtanOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAtanhOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AtanhOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AtanhOp>(b, converter, payloadArgs[0], op);
   }
   if (auto clone = dyn_cast<AtenCloneOp>(op)) {
     int64_t memoryFormat;
@@ -453,13 +435,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return createEqual(b, loc, abs.getType(), abs, infinity);
   }
   if (isa<AtenSigmoidOp>(op)) {
-    auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
-        b, converter, payloadArgs[0], op);
+    Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
+    Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
+    Type outTy = cast<RankedTensorType>(
+                     converter->convertType(op->getResult(0).getType()))
+                     .getElementType();
+    Type computeTy = outTy;
+    if (isa<IntegerType>(computeTy))
+      computeTy = b.getF32Type();
+    Value arg = payloadArgs[0];
+    arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy);
+    auto negate = b.create<arith::NegFOp>(loc, arg);
     auto one =
         b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
     auto exp = b.create<math::ExpOp>(loc, negate);
     auto added = b.create<arith::AddFOp>(loc, exp, one);
-    return b.create<arith::DivFOp>(loc, one, added);
+    auto div = b.create<arith::DivFOp>(loc, one, added);
+    outTy.dump();
+    return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
   }
   if (auto relu = dyn_cast<AtenReluOp>(op)) {
     if (!relu.getType()


@@ -2165,6 +2165,9 @@ ONNX_XFAIL_SET = {
     "ReduceMaxKeepDimReturnBoth_basic",
     "ReduceMaxNegativeDim_basic",
     "ViewSizeFromOtherTensor_basic",
+    # Failure - onnx traces differently
+    "ElementwiseSigmoidIntModule_basic",
     # Failure - unknown
     "ChunkListUnpackUneven_Module_basic",
@@ -2192,7 +2195,6 @@ ONNX_XFAIL_SET = {
 }
 
 ONNX_CRASHING_SET = {
-    "ElementwiseSigmoidIntModule_basic",
     "FlipModule_basic",
     "IndexTensorNegativeIndexModule_basic",
     "MoveDimIntNegativeIndexModule_basic",