mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `onnx.sigmoid` for integer inputs/outputs (#2914)
Sample compilation crashes due to sigmoid with integer inputs/outputs. This fix avoids crashing but still experiences an error.pull/2922/head
parent
7a0d0e954b
commit
d65925a8b4
|
@ -1615,9 +1615,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
|
||||
llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
|
||||
for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
|
||||
if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) {
|
||||
if (operandTy.getDimSize(i) != resultTy.getDimSize(i))
|
||||
intermediateShape[i] = -1;
|
||||
}
|
||||
if (intermediateShape[i] == ShapedType::kDynamic)
|
||||
intermediateShape[i] = Torch::kUnknownSize;
|
||||
}
|
||||
auto intermediateType = Torch::ValueTensorType::get(
|
||||
context, intermediateShape, resultTorchType.getOptionalDtype());
|
||||
|
|
|
@ -128,16 +128,20 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
|
|||
}
|
||||
|
||||
template <typename MathOpTy>
|
||||
static Value
|
||||
createCalculationForMathOpWithDtypeConversion(OpBuilder &b,
|
||||
const TypeConverter *converter,
|
||||
Value payloadArg, Operation *op) {
|
||||
Type dtype = converter->convertType(op->getResult(0).getType())
|
||||
.template cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
|
||||
Value payloadArg, Operation *op) {
|
||||
Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
|
||||
Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
|
||||
Type outTy =
|
||||
cast<RankedTensorType>(converter->convertType(op->getResult(0).getType()))
|
||||
.getElementType();
|
||||
Type computeTy = outTy;
|
||||
if (isa<IntegerType>(computeTy))
|
||||
computeTy = b.getF32Type();
|
||||
Location loc = op->getLoc();
|
||||
Value arg = convertScalarToDtype(b, loc, payloadArg, dtype);
|
||||
return b.create<MathOpTy>(loc, arg);
|
||||
Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy);
|
||||
auto newOp = b.create<MathOpTy>(loc, arg);
|
||||
return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
|
@ -217,92 +221,70 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
if (isa<AtenCeilOp>(op))
|
||||
return b.create<math::CeilOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenExpOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::ExpOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::ExpOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenExpm1Op>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::ExpM1Op>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::ExpM1Op>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenLogOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::LogOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::LogOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenLog2Op>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::Log2Op>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenLog10Op>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::Log10Op>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::Log10Op>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenLog1pOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::Log1pOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::Log1pOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenErfOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::ErfOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::ErfOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenSqrtOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::SqrtOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::SqrtOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenRsqrtOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::RsqrtOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::RsqrtOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenNegOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<arith::NegFOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenSinOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::SinOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::SinOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenSinhOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::SinhOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::SinhOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenAsinOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::AsinOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::AsinOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenAsinhOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::AsinhOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::AsinhOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenCosOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::CosOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::CosOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenCoshOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::CoshOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::CoshOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenAcosOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::AcosOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::AcosOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenAcoshOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::AcoshOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::AcoshOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenTanOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::TanOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::TanOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenTanhOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::TanhOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenAtanOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::AtanOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::AtanOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (isa<AtenAtanhOp>(op)) {
|
||||
return createCalculationForMathOpWithDtypeConversion<math::AtanhOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
return createFpOpWithDtype<math::AtanhOp>(b, converter, payloadArgs[0], op);
|
||||
}
|
||||
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
|
||||
int64_t memoryFormat;
|
||||
|
@ -453,13 +435,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return createEqual(b, loc, abs.getType(), abs, infinity);
|
||||
}
|
||||
if (isa<AtenSigmoidOp>(op)) {
|
||||
auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
|
||||
b, converter, payloadArgs[0], op);
|
||||
Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
|
||||
Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
|
||||
Type outTy = cast<RankedTensorType>(
|
||||
converter->convertType(op->getResult(0).getType()))
|
||||
.getElementType();
|
||||
Type computeTy = outTy;
|
||||
if (isa<IntegerType>(computeTy))
|
||||
computeTy = b.getF32Type();
|
||||
|
||||
Value arg = payloadArgs[0];
|
||||
arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy);
|
||||
auto negate = b.create<arith::NegFOp>(loc, arg);
|
||||
auto one =
|
||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
|
||||
auto exp = b.create<math::ExpOp>(loc, negate);
|
||||
auto added = b.create<arith::AddFOp>(loc, exp, one);
|
||||
return b.create<arith::DivFOp>(loc, one, added);
|
||||
auto div = b.create<arith::DivFOp>(loc, one, added);
|
||||
outTy.dump();
|
||||
return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
|
||||
}
|
||||
if (auto relu = dyn_cast<AtenReluOp>(op)) {
|
||||
if (!relu.getType()
|
||||
|
|
|
@ -2165,6 +2165,9 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceMaxKeepDimReturnBoth_basic",
|
||||
"ReduceMaxNegativeDim_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
|
||||
# Failure - onnx traces differently
|
||||
"ElementwiseSigmoidIntModule_basic",
|
||||
|
||||
# Failure - unknown
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
|
@ -2192,7 +2195,6 @@ ONNX_XFAIL_SET = {
|
|||
}
|
||||
|
||||
ONNX_CRASHING_SET = {
|
||||
"ElementwiseSigmoidIntModule_basic",
|
||||
"FlipModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"MoveDimIntNegativeIndexModule_basic",
|
||||
|
|
Loading…
Reference in New Issue