mirror of https://github.com/llvm/torch-mlir
[onnx] Fix `onnx.sigmoid` for integer inputs/outputs (#2914)
Sample compilation crashes due to `sigmoid` with integer inputs/outputs. This fix avoids the crash but still results in an error.
pull/2922/head
parent 7a0d0e954b
commit d65925a8b4
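The idea behind the patch is to compute floating-point elementwise ops (sigmoid, exp, log, ...) in f32 whenever the tensor element type is an integer, then convert the result back to the requested output dtype. A minimal standalone C++ sketch of that promote/compute/demote pattern for a single sigmoid element follows; sigmoidOnInt8 is a hypothetical illustration only, not code from this commit.

// Sketch of the promote -> compute -> demote pattern, per element.
#include <cmath>
#include <cstdint>
#include <iostream>

// Hypothetical helper: integer input is promoted to f32 (the compute type),
// the sigmoid is evaluated in float, and the result is converted back to the
// integer output dtype, mirroring the convertScalarToDtype round trip.
int8_t sigmoidOnInt8(int8_t x) {
  float promoted = static_cast<float>(x);        // int -> f32 compute type
  float y = 1.0f / (1.0f + std::exp(-promoted)); // sigmoid in f32
  return static_cast<int8_t>(y);                 // f32 -> integer output dtype
}

int main() {
  for (int8_t v : {-4, 0, 4})
    std::cout << int(v) << " -> " << int(sigmoidOnInt8(v)) << "\n";
  return 0;
}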
@@ -1615,9 +1615,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
         llvm::SmallVector<int64_t> intermediateShape(operandTy.getShape());
         for (int i = 0, s = operandTy.getRank(); i < s; ++i) {
-          if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) {
+          if (operandTy.getDimSize(i) != resultTy.getDimSize(i))
             intermediateShape[i] = -1;
-          }
+          if (intermediateShape[i] == ShapedType::kDynamic)
+            intermediateShape[i] = Torch::kUnknownSize;
         }
         auto intermediateType = Torch::ValueTensorType::get(
             context, intermediateShape, resultTorchType.getOptionalDtype());
@@ -128,16 +128,20 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
 }
 
 template <typename MathOpTy>
-static Value
-createCalculationForMathOpWithDtypeConversion(OpBuilder &b,
-                                              const TypeConverter *converter,
-                                              Value payloadArg, Operation *op) {
-  Type dtype = converter->convertType(op->getResult(0).getType())
-                   .template cast<RankedTensorType>()
-                   .getElementType();
+static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
+                                 Value payloadArg, Operation *op) {
+  Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
+  Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
+  Type outTy =
+      cast<RankedTensorType>(converter->convertType(op->getResult(0).getType()))
+          .getElementType();
+  Type computeTy = outTy;
+  if (isa<IntegerType>(computeTy))
+    computeTy = b.getF32Type();
   Location loc = op->getLoc();
-  Value arg = convertScalarToDtype(b, loc, payloadArg, dtype);
-  return b.create<MathOpTy>(loc, arg);
+  Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy);
+  auto newOp = b.create<MathOpTy>(loc, arg);
+  return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
 }
 
 template <typename OpTy>
@@ -217,92 +221,70 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   if (isa<AtenCeilOp>(op))
     return b.create<math::CeilOp>(loc, payloadArgs[0]);
   if (isa<AtenExpOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::ExpOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::ExpOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenExpm1Op>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::ExpM1Op>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::ExpM1Op>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenLogOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::LogOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::LogOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenLog2Op>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::Log2Op>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenLog10Op>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::Log10Op>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::Log10Op>(b, converter, payloadArgs[0], op);
  }
   if (isa<AtenLog1pOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::Log1pOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::Log1pOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenErfOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::ErfOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::ErfOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenSqrtOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::SqrtOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::SqrtOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenRsqrtOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::RsqrtOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::RsqrtOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenNegOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<arith::NegFOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenSinOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::SinOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::SinOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenSinhOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::SinhOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::SinhOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAsinOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AsinOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AsinOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAsinhOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AsinhOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AsinhOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenCosOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::CosOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::CosOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenCoshOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::CoshOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::CoshOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAcosOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AcosOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AcosOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAcoshOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AcoshOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AcoshOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenTanOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::TanOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::TanOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenTanhOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::TanhOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAtanOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AtanOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AtanOp>(b, converter, payloadArgs[0], op);
   }
   if (isa<AtenAtanhOp>(op)) {
-    return createCalculationForMathOpWithDtypeConversion<math::AtanhOp>(
-        b, converter, payloadArgs[0], op);
+    return createFpOpWithDtype<math::AtanhOp>(b, converter, payloadArgs[0], op);
   }
   if (auto clone = dyn_cast<AtenCloneOp>(op)) {
     int64_t memoryFormat;
@@ -453,13 +435,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return createEqual(b, loc, abs.getType(), abs, infinity);
   }
   if (isa<AtenSigmoidOp>(op)) {
-    auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
-        b, converter, payloadArgs[0], op);
+    Type inTTy = cast<ValueTensorType>(op->getOperand(0).getType()).getDtype();
+    Type outTTy = cast<ValueTensorType>(op->getResult(0).getType()).getDtype();
+    Type outTy = cast<RankedTensorType>(
+                     converter->convertType(op->getResult(0).getType()))
+                     .getElementType();
+    Type computeTy = outTy;
+    if (isa<IntegerType>(computeTy))
+      computeTy = b.getF32Type();
+
+    Value arg = payloadArgs[0];
+    arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy);
+    auto negate = b.create<arith::NegFOp>(loc, arg);
     auto one =
         b.create<arith::ConstantOp>(loc, FloatAttr::get(negate.getType(), 1));
     auto exp = b.create<math::ExpOp>(loc, negate);
     auto added = b.create<arith::AddFOp>(loc, exp, one);
-    return b.create<arith::DivFOp>(loc, one, added);
+    auto div = b.create<arith::DivFOp>(loc, one, added);
+    outTy.dump();
+    return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
   }
   if (auto relu = dyn_cast<AtenReluOp>(op)) {
     if (!relu.getType()
@@ -2166,6 +2166,9 @@ ONNX_XFAIL_SET = {
     "ReduceMaxNegativeDim_basic",
     "ViewSizeFromOtherTensor_basic",
 
+    # Failure - onnx traces differently
+    "ElementwiseSigmoidIntModule_basic",
+
     # Failure - unknown
     "ChunkListUnpackUneven_Module_basic",
     "ChunkListUnpack_Module_basic",
@@ -2192,7 +2195,6 @@ ONNX_XFAIL_SET = {
 }
 
 ONNX_CRASHING_SET = {
-    "ElementwiseSigmoidIntModule_basic",
     "FlipModule_basic",
     "IndexTensorNegativeIndexModule_basic",
     "MoveDimIntNegativeIndexModule_basic",