Add extf-trunc f32-f64-f32 ellision (#3579)

Torch has all scalars represented as i64 and f64 types which results in
extraneous trunc-extf commands. We can rework this by elliding
widen-narrow cases away.
pull/3516/head
Rob Suderman 2024-07-31 16:50:00 -07:00 committed by GitHub
parent 7b2902f6e2
commit 7f475e174e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 0 deletions

View File

@ -9,10 +9,12 @@
#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
@ -27,6 +29,25 @@ using namespace mlir::torch::TorchConversion;
namespace {
// TODO: Consider upstreaming this to an `arith::ExtFOp` folder:
struct ExtFTruncFPattern : public OpRewritePattern<arith::TruncFOp> {
ExtFTruncFPattern(MLIRContext *context) : OpRewritePattern(context) {}
LogicalResult matchAndRewrite(arith::TruncFOp truncf,
PatternRewriter &rewriter) const override {
Value operand = truncf.getOperand();
auto extf = operand.getDefiningOp<arith::ExtFOp>();
if (!extf)
return failure();
auto parentOperand = extf.getOperand();
if (truncf.getType() != parentOperand.getType())
return failure();
rewriter.replaceOp(truncf, parentOperand);
return success();
}
};
void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target) {
@ -209,6 +230,11 @@ struct FinalizingBackendTypeConversionPass
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
RewritePatternSet greedyPatterns(context);
greedyPatterns.insert<ExtFTruncFPattern>(context);
if (failed(applyPatternsAndFoldGreedily(func, std::move(greedyPatterns))))
signalPassFailure();
// Drop attributes that are no longer used after conversion out of Torch.
stripTorchAttrs(func);
}

View File

@ -83,3 +83,13 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
"test.sink"(%0) : (!torch.vtensor<[],f32>) -> ()
return
}
// -----
// CHECK-LABEL: @extfTruncf
func.func @extfTruncf(%arg0: f32) -> f32 {
%f64 = arith.extf %arg0 : f32 to f64
%f32 = arith.truncf %f64 : f64 to f32
// CHECK: return %arg0
return %f32 : f32
}