[onnx] Fix `onnx.sigmoid` for integer inputs/outputs (#2914)

Sample compilation crashes due to sigmoid with integer inputs/outputs.
This fix avoids crashing but still experiences an error.
pull/2922/head
Rob Suderman 2024-02-16 13:35:25 -08:00 committed by GitHub
parent 7a0d0e954b
commit d65925a8b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 56 additions and 59 deletions

View File

@ -1615,9 +1615,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape()); llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
for (int i = 0, s = operandTy.getRank(); i < s; ++i) { for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) { if (operandTy.getDimSize(i) != resultTy.getDimSize(i))
intermediateShape[i] = -1; intermediateShape[i] = -1;
} if (intermediateShape[i] == ShapedType::kDynamic)
intermediateShape[i] = Torch::kUnknownSize;
} }
auto intermediateType = Torch::ValueTensorType::get( auto intermediateType = Torch::ValueTensorType::get(
context, intermediateShape, resultTorchType.getOptionalDtype()); context, intermediateShape, resultTorchType.getOptionalDtype());

View File

@ -128,16 +128,20 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
} }
template <typename MathOpTy> template <typename MathOpTy>
static Value static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
createCalculationForMathOpWithDtypeConversion(OpBuilder &b, Value payloadArg, Operation *op) {
const TypeConverter *converter, Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
Value payloadArg, Operation *op) { Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
Type dtype = converter->convertType(op->getResult(0).getType()) Type outTy =
.template cast<RankedTensorType>() cast<RankedTensorType>(converter->convertType(op->getResult(0).getType()))
.getElementType(); .getElementType();
Type computeTy = outTy;
if (isa<IntegerType>(computeTy))
computeTy = b.getF32Type();
Location loc = op->getLoc(); Location loc = op->getLoc();
Value arg = convertScalarToDtype(b, loc, payloadArg, dtype); Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy);
return b.create<MathOpTy>(loc, arg); auto newOp = b.create<MathOpTy>(loc, arg);
return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
} }
template <typename OpTy> template <typename OpTy>
@ -217,92 +221,70 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (isa<AtenCeilOp>(op)) if (isa<AtenCeilOp>(op))
return b.create<math::CeilOp>(loc, payloadArgs[0]); return b.create<math::CeilOp>(loc, payloadArgs[0]);
if (isa<AtenExpOp>(op)) { if (isa<AtenExpOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::ExpOp>( return createFpOpWithDtype<math::ExpOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenExpm1Op>(op)) { if (isa<AtenExpm1Op>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::ExpM1Op>( return createFpOpWithDtype<math::ExpM1Op>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenLogOp>(op)) { if (isa<AtenLogOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::LogOp>( return createFpOpWithDtype<math::LogOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenLog2Op>(op)) { if (isa<AtenLog2Op>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::Log2Op>( return createFpOpWithDtype<math::Log2Op>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenLog10Op>(op)) { if (isa<AtenLog10Op>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::Log10Op>( return createFpOpWithDtype<math::Log10Op>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenLog1pOp>(op)) { if (isa<AtenLog1pOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::Log1pOp>( return createFpOpWithDtype<math::Log1pOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenErfOp>(op)) { if (isa<AtenErfOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::ErfOp>( return createFpOpWithDtype<math::ErfOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenSqrtOp>(op)) { if (isa<AtenSqrtOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::SqrtOp>( return createFpOpWithDtype<math::SqrtOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenRsqrtOp>(op)) { if (isa<AtenRsqrtOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::RsqrtOp>( return createFpOpWithDtype<math::RsqrtOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenNegOp>(op)) { if (isa<AtenNegOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<arith::NegFOp>( return createFpOpWithDtype<arith::NegFOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenSinOp>(op)) { if (isa<AtenSinOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::SinOp>( return createFpOpWithDtype<math::SinOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenSinhOp>(op)) { if (isa<AtenSinhOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::SinhOp>( return createFpOpWithDtype<math::SinhOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenAsinOp>(op)) { if (isa<AtenAsinOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::AsinOp>( return createFpOpWithDtype<math::AsinOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenAsinhOp>(op)) { if (isa<AtenAsinhOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::AsinhOp>( return createFpOpWithDtype<math::AsinhOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenCosOp>(op)) { if (isa<AtenCosOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::CosOp>( return createFpOpWithDtype<math::CosOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenCoshOp>(op)) { if (isa<AtenCoshOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::CoshOp>( return createFpOpWithDtype<math::CoshOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenAcosOp>(op)) { if (isa<AtenAcosOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::AcosOp>( return createFpOpWithDtype<math::AcosOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenAcoshOp>(op)) { if (isa<AtenAcoshOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::AcoshOp>( return createFpOpWithDtype<math::AcoshOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenTanOp>(op)) { if (isa<AtenTanOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::TanOp>( return createFpOpWithDtype<math::TanOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenTanhOp>(op)) { if (isa<AtenTanhOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::TanhOp>( return createFpOpWithDtype<math::TanhOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenAtanOp>(op)) { if (isa<AtenAtanOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::AtanOp>( return createFpOpWithDtype<math::AtanOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (isa<AtenAtanhOp>(op)) { if (isa<AtenAtanhOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::AtanhOp>( return createFpOpWithDtype<math::AtanhOp>(b, converter, payloadArgs[0], op);
b, converter, payloadArgs[0], op);
} }
if (auto clone = dyn_cast<AtenCloneOp>(op)) { if (auto clone = dyn_cast<AtenCloneOp>(op)) {
int64_t memoryFormat; int64_t memoryFormat;
@ -453,13 +435,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return createEqual(b, loc, abs.getType(), abs, infinity); return createEqual(b, loc, abs.getType(), abs, infinity);
} }
if (isa<AtenSigmoidOp>(op)) { if (isa<AtenSigmoidOp>(op)) {
auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>( Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
b, converter, payloadArgs[0], op); Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
Type outTy = cast<RankedTensorType>(
converter->convertType(op->getResult(0).getType()))
.getElementType();
Type computeTy = outTy;
if (isa<IntegerType>(computeTy))
computeTy = b.getF32Type();
Value arg = payloadArgs[0];
arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy);
auto negate = b.create<arith::NegFOp>(loc, arg);
auto one = auto one =
b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1)); b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
auto exp = b.create<math::ExpOp>(loc, negate); auto exp = b.create<math::ExpOp>(loc, negate);
auto added = b.create<arith::AddFOp>(loc, exp, one); auto added = b.create<arith::AddFOp>(loc, exp, one);
return b.create<arith::DivFOp>(loc, one, added); auto div = b.create<arith::DivFOp>(loc, one, added);
outTy.dump();
return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
} }
if (auto relu = dyn_cast<AtenReluOp>(op)) { if (auto relu = dyn_cast<AtenReluOp>(op)) {
if (!relu.getType() if (!relu.getType()

View File

@ -2165,6 +2165,9 @@ ONNX_XFAIL_SET = {
"ReduceMaxKeepDimReturnBoth_basic", "ReduceMaxKeepDimReturnBoth_basic",
"ReduceMaxNegativeDim_basic", "ReduceMaxNegativeDim_basic",
"ViewSizeFromOtherTensor_basic", "ViewSizeFromOtherTensor_basic",
# Failure - onnx traces differently
"ElementwiseSigmoidIntModule_basic",
# Failure - unknown # Failure - unknown
"ChunkListUnpackUneven_Module_basic", "ChunkListUnpackUneven_Module_basic",
@ -2192,7 +2195,6 @@ ONNX_XFAIL_SET = {
} }
ONNX_CRASHING_SET = { ONNX_CRASHING_SET = {
"ElementwiseSigmoidIntModule_basic",
"FlipModule_basic", "FlipModule_basic",
"IndexTensorNegativeIndexModule_basic", "IndexTensorNegativeIndexModule_basic",
"MoveDimIntNegativeIndexModule_basic", "MoveDimIntNegativeIndexModule_basic",