mirror of https://github.com/llvm/torch-mlir
[RefBackend] Use new upstream SCF type conversions.
parent
4f2aa12d1a
commit
b6ae53b312
|
@ -22,6 +22,7 @@ add_npcomp_library(NPCOMPRefBackend
|
|||
MLIRIR
|
||||
MLIRLinalg
|
||||
MLIRSCFToStandard
|
||||
MLIRSCFTransforms
|
||||
MLIRShapeToStandard
|
||||
MLIRStandard
|
||||
MLIRStandardOpsTransforms
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "npcomp/RefBackend/RefBackend.h"
|
||||
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
|
@ -27,56 +28,7 @@ using namespace mlir::NPCOMP;
|
|||
// conversion about them.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// This is a type conversion similar to CallOpSignatureConversion.
|
||||
class LowerIfOpTypes : public OpConversionPattern<scf::IfOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::IfOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
SmallVector<Type, 6> newResultTypes;
|
||||
for (auto type : op.getResultTypes()) {
|
||||
Type newType = typeConverter->convertType(type);
|
||||
if (!newType)
|
||||
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||
newResultTypes.push_back(newType);
|
||||
}
|
||||
rewriter.updateRootInPlace(op, [&] {
|
||||
for (auto t : llvm::zip(op.getResults(), newResultTypes))
|
||||
std::get<0>(t).setType(std::get<1>(t));
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// This is a type conversion similar to CallOpSignatureConversion.
|
||||
class LowerForOpTypes : public OpConversionPattern<scf::ForOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::ForOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
SmallVector<Type, 6> newResultTypes;
|
||||
for (auto type : op.getResultTypes()) {
|
||||
Type newType = typeConverter->convertType(type);
|
||||
if (!newType)
|
||||
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||
newResultTypes.push_back(newType);
|
||||
}
|
||||
rewriter.updateRootInPlace(op, [&] {
|
||||
for (auto t : llvm::zip(op.getResults(), newResultTypes))
|
||||
std::get<0>(t).setType(std::get<1>(t));
|
||||
auto bodyArgs = op.getBody()->getArguments();
|
||||
for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes))
|
||||
std::get<0>(t).setType(std::get<1>(t));
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// This is a type conversion similar to CallOpSignatureConversion.
|
||||
|
@ -155,9 +107,9 @@ class LowerStructuralToMemref
|
|||
typeConverter.isLegal(&op.getBody());
|
||||
});
|
||||
|
||||
scf::populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
|
||||
patterns, target);
|
||||
patterns.insert<LowerSelectOpTypes>(typeConverter, context);
|
||||
patterns.insert<LowerIfOpTypes>(typeConverter, context);
|
||||
patterns.insert<LowerForOpTypes>(typeConverter, context);
|
||||
patterns.insert<LowerTensorToMemrefOp>(typeConverter, context);
|
||||
patterns.insert<LowerTensorLoadOp>(typeConverter, context);
|
||||
target.addIllegalOp<TensorToMemrefOp>();
|
||||
|
|
Loading…
Reference in New Issue