mirror of https://github.com/llvm/torch-mlir
[Stablehlo] support unary ops which promote to floating point (#3209)
* promote input to output element-type when lowering to stablehlo, so that it could satisfy stablehlo's type constraints. * split promote-to-fp unary ops from fp-only unary ops.pull/3192/head
parent
797e4cd395
commit
1f8123b5f0
|
@ -217,6 +217,37 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// These legalizations are for unary ops with promoting to floating point
|
||||
// datatypes.
|
||||
namespace {
|
||||
template <typename AtenOpT, typename StablehloOpT>
|
||||
class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<TensorType>();
|
||||
if (!selfTy)
|
||||
return op.emitError("only Tensor types supported in StableHLO");
|
||||
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
|
||||
if (resultTy.getElementType().template isa<mlir::FloatType>()) {
|
||||
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
|
||||
return success();
|
||||
} else {
|
||||
return op.emitError(
|
||||
"only result to be floating-point datatype legalization supported");
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// aten.ones & aten.zeros
|
||||
// Ref: Error checking based on the Torch to TOSA lowering
|
||||
namespace {
|
||||
|
@ -1888,23 +1919,29 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
|
||||
context)
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLog1pOp, stablehlo::Log1pOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenTanhOp, stablehlo::TanhOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenSinOp, stablehlo::SineOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenAsinOp, chlo::AsinOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenAcosOp, chlo::AcosOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenAtanOp, chlo::AtanOp);
|
||||
#undef INSERT_UNARY_FPONLY_PATTERN
|
||||
|
||||
#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, StablehloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryPromoteToFPOp<AtenOp, StablehloOp>>( \
|
||||
typeConverter, context)
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, stablehlo::LogOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLog1pOp, stablehlo::Log1pOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, stablehlo::ExpOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp);
|
||||
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp);
|
||||
#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN
|
||||
|
||||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
||||
|
|
|
@ -620,17 +620,14 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"DiagonalModule_with_offset",
|
||||
"DivFloatModule_basic",
|
||||
"DivIntModule_basic",
|
||||
"ElementwiseAcosIntModule_basic",
|
||||
"ElementwiseAcoshIntModule_basic",
|
||||
"ElementwiseAcoshModule_basic",
|
||||
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
|
||||
"ElementwiseAsinIntModule_basic",
|
||||
"ElementwiseAsinhIntModule_basic",
|
||||
"ElementwiseAsinhModule_basic",
|
||||
"ElementwiseAtan2FloatIntModule_basic",
|
||||
"ElementwiseAtan2TensorFloatModule_basic",
|
||||
"ElementwiseAtan2TensorIntModule_basic",
|
||||
"ElementwiseAtanTensorIntModule_basic",
|
||||
"ElementwiseAtanhIntModule_basic",
|
||||
"ElementwiseAtanhModule_basic",
|
||||
"ElementwiseBitwiseLeftShiftInt32Module_basic",
|
||||
|
@ -639,7 +636,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ElementwiseBitwiseRightShiftInt32Module_basic",
|
||||
"ElementwiseBitwiseRightShiftInt64Module_basic",
|
||||
"ElementwiseBitwiseRightShiftInt8Module_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseCoshIntModule_basic",
|
||||
"ElementwiseCoshModule_basic",
|
||||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
|
@ -649,22 +645,16 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ElementwiseLog10Module_basic",
|
||||
"ElementwiseLog2IntModule_basic",
|
||||
"ElementwiseLog2Module_basic",
|
||||
"ElementwiseLogIntModule_basic",
|
||||
"ElementwiseLogitModule_basic",
|
||||
"ElementwiseMulTensorComplexModule_basic",
|
||||
"ElementwisePowScalarModule_basic",
|
||||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"ElementwiseReciprocalIntModule_basic",
|
||||
"ElementwiseRsqrtIntModule_basic",
|
||||
"ElementwiseSigmoidIntModule_basic",
|
||||
"ElementwiseSinIntModule_basic",
|
||||
"ElementwiseSqrtIntModule_basic",
|
||||
"ElementwiseTanIntModule_basic",
|
||||
"ElementwiseTanModule_basic",
|
||||
"ElementwiseTernaryModule_basic",
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
"ElementwiseUnaryIntModule_basic",
|
||||
"EmptyModule_uint8",
|
||||
"EqIntModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
|
@ -1464,6 +1454,17 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseAcosModule_basic",
|
||||
"ElementwiseAsinModule_basic",
|
||||
"ElementwiseAtanTensorFloatModule_basic",
|
||||
"ElementwiseAcosIntModule_basic",
|
||||
"ElementwiseAsinIntModule_basic",
|
||||
"ElementwiseAtanTensorIntModule_basic",
|
||||
"ElementwiseCosIntModule_basic",
|
||||
"ElementwiseExpIntModule_basic",
|
||||
"ElementwiseLogIntModule_basic",
|
||||
"ElementwiseRsqrtIntModule_basic",
|
||||
"ElementwiseSigmoidIntModule_basic",
|
||||
"ElementwiseSinIntModule_basic",
|
||||
"ElementwiseSqrtIntModule_basic",
|
||||
"ElementwiseUnaryIntModule_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_CRASHING_SET = {
|
||||
|
|
Loading…
Reference in New Issue