mirror of https://github.com/llvm/torch-mlir
[Stablehlo] support unary ops which promote to floating point (#3209)
* Promote the input to the output element type when lowering to StableHLO, so that the generated op satisfies StableHLO's type constraints.
* Split the promote-to-fp unary ops from the fp-only unary ops.
parent 797e4cd395
commit 1f8123b5f0
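For context on the first bullet: PyTorch's elementwise unary math ops return a floating-point tensor even for integer inputs, so the StableHLO op has to be fed an operand whose element type already matches that floating-point result type. A minimal PyTorch sketch of the promotion behaviour the new pattern mirrors (float32 assumes the default float dtype):

import torch

# Unary math ops such as sin/log/sqrt compute in floating point: an integer
# input is promoted to the default float dtype rather than kept as int64.
x = torch.arange(1, 5, dtype=torch.int64)
y = torch.sin(x)
print(x.dtype, y.dtype)  # torch.int64 torch.float32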
@@ -217,6 +217,37 @@ public:
 };
 } // namespace
+
+// These legalizations are for unary ops with promoting to floating point
+// datatypes.
+namespace {
+template <typename AtenOpT, typename StablehloOpT>
+class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value self = adaptor.getSelf();
+    auto selfTy = self.getType().cast<TensorType>();
+    if (!selfTy)
+      return op.emitError("only Tensor types supported in StableHLO");
+    auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
+                        ->convertType(op.getType())
+                        .template cast<TensorType>();
+
+    if (resultTy.getElementType().template isa<mlir::FloatType>()) {
+      Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
+      rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
+      return success();
+    } else {
+      return op.emitError(
+          "only result to be floating-point datatype legalization supported");
+    }
+  }
+};
+} // namespace
 
 // aten.ones & aten.zeros
 // Ref: Error checking based on the Torch to TOSA lowering
 namespace {
@@ -1888,23 +1919,29 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   target.addIllegalOp<AtenOp>();                                             \
   patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
                                                               context)
-  INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenLog1pOp, stablehlo::Log1pOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenAsinOp, chlo::AsinOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenAcosOp, chlo::AcosOp);
-  INSERT_UNARY_FPONLY_PATTERN(AtenAtanOp, chlo::AtanOp);
 #undef INSERT_UNARY_FPONLY_PATTERN
 
+#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, StablehloOp)               \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenUnaryPromoteToFPOp<AtenOp, StablehloOp>>(           \
+      typeConverter, context)
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, stablehlo::LogOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLog1pOp, stablehlo::Log1pOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, stablehlo::ExpOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp);
+  INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp);
+#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN
+
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal)                         \
   target.addIllegalOp<AtenOp>();                                              \
   patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter,     \
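The ops registered with INSERT_UNARY_PROMOTE_TO_FP_PATTERN above are exactly the unary ops whose results are floating point regardless of the input dtype. A small sanity-check sketch in PyTorch, not part of the diff; the op list is copied from the registrations above, and float32 assumes the default float dtype:

import torch

# Every op moved to the promote-to-fp bucket yields a float result for an
# integer input, which is why the lowering now inserts a type promotion.
x = torch.arange(1, 5, dtype=torch.int32)
promote_to_fp_ops = (
    torch.log, torch.log1p, torch.exp, torch.sqrt, torch.rsqrt,
    torch.sigmoid, torch.tanh, torch.sin, torch.cos,
    torch.asin, torch.acos, torch.atan,
)
for fn in promote_to_fp_ops:
    assert fn(x).dtype == torch.float32, fn.__name__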
@@ -620,17 +620,14 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "DiagonalModule_with_offset",
     "DivFloatModule_basic",
     "DivIntModule_basic",
-    "ElementwiseAcosIntModule_basic",
     "ElementwiseAcoshIntModule_basic",
     "ElementwiseAcoshModule_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
-    "ElementwiseAsinIntModule_basic",
     "ElementwiseAsinhIntModule_basic",
     "ElementwiseAsinhModule_basic",
     "ElementwiseAtan2FloatIntModule_basic",
     "ElementwiseAtan2TensorFloatModule_basic",
     "ElementwiseAtan2TensorIntModule_basic",
-    "ElementwiseAtanTensorIntModule_basic",
     "ElementwiseAtanhIntModule_basic",
     "ElementwiseAtanhModule_basic",
     "ElementwiseBitwiseLeftShiftInt32Module_basic",
@@ -639,7 +636,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ElementwiseBitwiseRightShiftInt32Module_basic",
     "ElementwiseBitwiseRightShiftInt64Module_basic",
     "ElementwiseBitwiseRightShiftInt8Module_basic",
-    "ElementwiseCosIntModule_basic",
     "ElementwiseCoshIntModule_basic",
     "ElementwiseCoshModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
@@ -649,22 +645,16 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ElementwiseLog10Module_basic",
     "ElementwiseLog2IntModule_basic",
     "ElementwiseLog2Module_basic",
-    "ElementwiseLogIntModule_basic",
     "ElementwiseLogitModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
     "ElementwisePowScalarModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorUIntModule_basic",
     "ElementwiseReciprocalIntModule_basic",
-    "ElementwiseRsqrtIntModule_basic",
-    "ElementwiseSigmoidIntModule_basic",
-    "ElementwiseSinIntModule_basic",
-    "ElementwiseSqrtIntModule_basic",
     "ElementwiseTanIntModule_basic",
     "ElementwiseTanModule_basic",
     "ElementwiseTernaryModule_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
-    "ElementwiseUnaryIntModule_basic",
     "EmptyModule_uint8",
     "EqIntModule_basic",
     "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
@@ -1464,6 +1454,17 @@ STABLEHLO_PASS_SET = {
     "ElementwiseAcosModule_basic",
     "ElementwiseAsinModule_basic",
     "ElementwiseAtanTensorFloatModule_basic",
+    "ElementwiseAcosIntModule_basic",
+    "ElementwiseAsinIntModule_basic",
+    "ElementwiseAtanTensorIntModule_basic",
+    "ElementwiseCosIntModule_basic",
+    "ElementwiseExpIntModule_basic",
+    "ElementwiseLogIntModule_basic",
+    "ElementwiseRsqrtIntModule_basic",
+    "ElementwiseSigmoidIntModule_basic",
+    "ElementwiseSinIntModule_basic",
+    "ElementwiseSqrtIntModule_basic",
+    "ElementwiseUnaryIntModule_basic",
 }
 
 STABLEHLO_CRASHING_SET = {
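The test names moved from FX_IMPORTER_STABLEHLO_XFAIL_SET into STABLEHLO_PASS_SET are the integer-input variants of the promoted unary ops. For reference, a hedged sketch of what such an e2e test looks like in this suite, modelled on ElementwiseSinIntModule; the decorator and helper names (export, annotate_args, register_test_case, TestUtils) are the suite's usual ones and are quoted from memory, not from this diff:

import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class ElementwiseSinIntModule(torch.nn.Module):
    @export
    @annotate_args([None, ([-1, -1], torch.int32, True)])
    def forward(self, a):
        # Integer input, floating-point result: exercises the new
        # promote-to-fp lowering path when compiled to StableHLO.
        return torch.sin(a)


@register_test_case(module_factory=lambda: ElementwiseSinIntModule())
def ElementwiseSinIntModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))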