mirror of https://github.com/llvm/torch-mlir
Add extf-trunc f32-f64-f32 ellision (#3579)
Torch has all scalars represented as i64 and f64 types which results in extraneous trunc-extf commands. We can rework this by elliding widen-narrow cases away.pull/3516/head
parent
7b2902f6e2
commit
7f475e174e
|
@ -9,10 +9,12 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
|
||||
|
@ -27,6 +29,25 @@ using namespace mlir::torch::TorchConversion;
|
|||
|
||||
namespace {
|
||||
|
||||
// TODO: Consider upstreaming this to an `arith::ExtFOp` folder:
|
||||
struct ExtFTruncFPattern : public OpRewritePattern<arith::TruncFOp> {
|
||||
ExtFTruncFPattern(MLIRContext *context) : OpRewritePattern(context) {}
|
||||
LogicalResult matchAndRewrite(arith::TruncFOp truncf,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value operand = truncf.getOperand();
|
||||
auto extf = operand.getDefiningOp<arith::ExtFOp>();
|
||||
if (!extf)
|
||||
return failure();
|
||||
|
||||
auto parentOperand = extf.getOperand();
|
||||
if (truncf.getType() != parentOperand.getType())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(truncf, parentOperand);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
|
@ -209,6 +230,11 @@ struct FinalizingBackendTypeConversionPass
|
|||
if (failed(applyFullConversion(func, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
|
||||
RewritePatternSet greedyPatterns(context);
|
||||
greedyPatterns.insert<ExtFTruncFPattern>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(func, std::move(greedyPatterns))))
|
||||
signalPassFailure();
|
||||
|
||||
// Drop attributes that are no longer used after conversion out of Torch.
|
||||
stripTorchAttrs(func);
|
||||
}
|
||||
|
|
|
@ -83,3 +83,13 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
|
|||
"test.sink"(%0) : (!torch.vtensor<[],f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @extfTruncf
|
||||
func.func @extfTruncf(%arg0: f32) -> f32 {
|
||||
%f64 = arith.extf %arg0 : f32 to f64
|
||||
%f32 = arith.truncf %f64 : f64 to f32
|
||||
// CHECK: return %arg0
|
||||
return %f32 : f32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue