mirror of https://github.com/llvm/torch-mlir
Add extf-trunc f32-f64-f32 elision (#3579)
Torch represents all scalars as i64 and f64 types, which results in extraneous extf/truncf op pairs. We can rework this by eliding the widen-narrow cases away.
parent 7b2902f6e2
commit 7f475e174e
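To illustrate what the commit message describes, here is a minimal sketch of the kind of IR this change targets. The function name, value names, and the `test.sink` consumer are hypothetical, chosen only to show the widen/narrow round trip; with the new rewrite the `arith.extf`/`arith.truncf` pair collapses and the original f32 value is used directly.

```mlir
// Before: an f32 value is widened to f64 (Torch's scalar width) and then
// immediately narrowed back to f32 at the point of use.
func.func @scalar_round_trip(%x: f32) {
  %wide = arith.extf %x : f32 to f64
  %narrow = arith.truncf %wide : f64 to f32
  "test.sink"(%narrow) : (f32) -> ()
  return
}

// After the rewrite: the extf/truncf pair is elided and %x feeds the use directly.
func.func @scalar_round_trip(%x: f32) {
  "test.sink"(%x) : (f32) -> ()
  return
}
```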
@@ -9,10 +9,12 @@
 #include "PassDetail.h"

+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

@@ -27,6 +29,25 @@ using namespace mlir::torch::TorchConversion;

 namespace {

+// TODO: Consider upstreaming this to an `arith::ExtFOp` folder:
+struct ExtFTruncFPattern : public OpRewritePattern<arith::TruncFOp> {
+  ExtFTruncFPattern(MLIRContext *context) : OpRewritePattern(context) {}
+  LogicalResult matchAndRewrite(arith::TruncFOp truncf,
+                                PatternRewriter &rewriter) const override {
+    Value operand = truncf.getOperand();
+    auto extf = operand.getDefiningOp<arith::ExtFOp>();
+    if (!extf)
+      return failure();
+
+    auto parentOperand = extf.getOperand();
+    if (truncf.getType() != parentOperand.getType())
+      return failure();
+
+    rewriter.replaceOp(truncf, parentOperand);
+    return success();
+  }
+};
+
 void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns,
                                                ConversionTarget &target) {
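A note on the `truncf.getType() != parentOperand.getType()` guard above: the pattern only fires when the trunc narrows back to exactly the type that was widened. A hypothetical mixed-width case (names invented for illustration) that the pattern deliberately leaves alone:

```mlir
// f16 -> f64 -> f32: the truncf result type (f32) differs from the value that
// was widened (f16), so simply replacing the truncf with %arg0 would change
// types. ExtFTruncFPattern returns failure() here and the ops are kept.
func.func @mixed_width(%arg0: f16) -> f32 {
  %wide = arith.extf %arg0 : f16 to f64
  %narrow = arith.truncf %wide : f64 to f32
  return %narrow : f32
}
```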
@@ -209,6 +230,11 @@ struct FinalizingBackendTypeConversionPass
     if (failed(applyFullConversion(func, target, std::move(patterns))))
       signalPassFailure();

+    RewritePatternSet greedyPatterns(context);
+    greedyPatterns.insert<ExtFTruncFPattern>(context);
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(greedyPatterns))))
+      signalPassFailure();
+
     // Drop attributes that are no longer used after conversion out of Torch.
     stripTorchAttrs(func);
   }
@@ -83,3 +83,13 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: tensor<f32>) {
   "test.sink"(%0) : (!torch.vtensor<[],f32>) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @extfTruncf
+func.func @extfTruncf(%arg0: f32) -> f32 {
+  %f64 = arith.extf %arg0 : f32 to f64
+  %f32 = arith.truncf %f64 : f64 to f32
+  // CHECK: return %arg0
+  return %f32 : f32
+}
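The last hunk adds a positive test case to the lit test for the finalizing conversion. For context, a split-input-file test like this would typically be driven by a FileCheck RUN line along the following lines; the exact pass flag and options shown here are an assumption about the existing test harness, not part of this diff:

```mlir
// RUN: torch-mlir-opt %s -torch-finalizing-backend-type-conversion -split-input-file | FileCheck %s
```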