mirror of https://github.com/llvm/torch-mlir
Add TorchToSCF pass.
1. Add TorchToSCF pass. 2. Convert prim.If and prim.If.yield.pull/239/head
parent
5ad144c4fe
commit
45f2edfc7a
|
@ -20,6 +20,11 @@ def ConvertTorchToStd : Pass<"convert-torch-to-std", "FuncOp"> {
|
|||
let constructor = "mlir::NPCOMP::createConvertTorchToStdPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "FuncOp"> {
|
||||
let summary = "Convert recognized Torch ops to SCF ops";
|
||||
let constructor = "mlir::NPCOMP::createConvertTorchToSCFPass()";
|
||||
}
|
||||
|
||||
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
|
||||
let summary = "Convert recognized Torch ops to Linalg ops";
|
||||
let description = [{
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
||||
#define NPCOMP_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToSCFPass();
|
||||
}
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
|
@ -1,4 +1,5 @@
|
|||
add_subdirectory(TorchToLinalg)
|
||||
add_subdirectory(TorchToSCF)
|
||||
add_subdirectory(TorchToStd)
|
||||
add_subdirectory(BasicpyToStd)
|
||||
add_subdirectory(NumpyToTCF)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "npcomp/Conversion/TCFToStd/TCFToStd.h"
|
||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
add_npcomp_conversion_library(NPCOMPTorchToSCF
|
||||
TorchToSCF.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToSCF
|
||||
|
||||
DEPENDS
|
||||
NPCOMPConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRSCF
|
||||
MLIRStandard
|
||||
NPCOMPTorchDialect
|
||||
)
|
|
@ -0,0 +1,93 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::Torch;
|
||||
|
||||
namespace {
|
||||
class ConvertTorchPrimIfYieldOp : public OpConversionPattern<PrimIfYieldOp> {
|
||||
public:
|
||||
using OpConversionPattern<PrimIfYieldOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(PrimIfYieldOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertTorchPrimIfOp : public OpConversionPattern<PrimIfOp> {
|
||||
public:
|
||||
using OpConversionPattern<PrimIfOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(PrimIfOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
SmallVector<Type, 1> newResultTypes;
|
||||
if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
|
||||
newResultTypes)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"could not convert PrimIfOp outputs");
|
||||
auto scfIf = rewriter.create<scf::IfOp>(
|
||||
op->getLoc(), newResultTypes, operands[0], /*withElseRegion=*/true);
|
||||
auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
|
||||
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
|
||||
rewriter.eraseBlock(&dstRegion.back());
|
||||
};
|
||||
inlineIfCase(op.thenRegion(), scfIf.thenRegion());
|
||||
inlineIfCase(op.elseRegion(), scfIf.elseRegion());
|
||||
rewriter.replaceOp(op, scfIf.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<scf::SCFDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
target.addIllegalOp<PrimIfOp>();
|
||||
patterns.add<ConvertTorchPrimIfOp>(typeConverter, context);
|
||||
target.addIllegalOp<PrimIfYieldOp>();
|
||||
patterns.add<ConvertTorchPrimIfYieldOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::createConvertTorchToSCFPass() {
|
||||
return std::make_unique<ConvertTorchToSCF>();
|
||||
}
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/Transforms/Passes.h"
|
||||
#include "npcomp/Backend/Common/Passes.h"
|
||||
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
||||
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -150,6 +151,8 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
|||
// TODO: Improve torch op canonicalizations.
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
||||
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
|
||||
|
||||
// Lower to linalg + guards which is the input to codegen backends.
|
||||
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
||||
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
// RUN: npcomp-opt <%s -convert-torch-to-scf | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @torch.prim.if(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
|
||||
// CHECK: %[[VAL_1:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_3:.*]] = torch.to_i1 %[[VAL_0]]
|
||||
// CHECK: %[[VAL_4:.*]] = scf.if %[[VAL_3]] -> (i64) {
|
||||
// CHECK: %[[VAL_5:.*]] = torch.to_i64 %[[VAL_1]]
|
||||
// CHECK: scf.yield %[[VAL_5]] : i64
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[VAL_6:.*]] = torch.to_i64 %[[VAL_2]]
|
||||
// CHECK: scf.yield %[[VAL_6]] : i64
|
||||
// CHECK: }
|
||||
// CHECK: %[[VAL_7:.*]] = torch.from_i64 %[[VAL_8:.*]]
|
||||
// CHECK: return %[[VAL_7]] : !torch.int
|
||||
func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
|
||||
%int2 = torch.constant.int 2
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.prim.If %arg0 -> (!torch.int) {
|
||||
torch.prim.If.yield %int2 : !torch.int
|
||||
} else {
|
||||
torch.prim.If.yield %int1 : !torch.int
|
||||
}
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @aten.prim.if$nested(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.bool) -> !torch.int {
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[VAL_5:.*]] = torch.to_i1 %[[VAL_0]]
|
||||
// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_5]] -> (i64) {
|
||||
// CHECK: %[[VAL_7:.*]] = torch.to_i1 %[[VAL_1]]
|
||||
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (i64) {
|
||||
// CHECK: %[[VAL_9:.*]] = torch.to_i64 %[[VAL_2]]
|
||||
// CHECK: scf.yield %[[VAL_9]] : i64
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[VAL_10:.*]] = torch.to_i64 %[[VAL_3]]
|
||||
// CHECK: scf.yield %[[VAL_10]] : i64
|
||||
// CHECK: }
|
||||
// CHECK: scf.yield %[[VAL_11:.*]] : i64
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[VAL_12:.*]] = torch.to_i64 %[[VAL_4]]
|
||||
// CHECK: scf.yield %[[VAL_12]] : i64
|
||||
// CHECK: }
|
||||
// CHECK: %[[VAL_13:.*]] = torch.from_i64 %[[VAL_14:.*]]
|
||||
// CHECK: return %[[VAL_13]] : !torch.int
|
||||
func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int {
|
||||
%int2 = torch.constant.int 2
|
||||
%int3 = torch.constant.int 3
|
||||
%int4 = torch.constant.int 4
|
||||
%0 = torch.prim.If %arg0 -> (!torch.int) {
|
||||
%1 = torch.prim.If %arg1 -> (!torch.int) {
|
||||
torch.prim.If.yield %int2 : !torch.int
|
||||
} else {
|
||||
torch.prim.If.yield %int3 : !torch.int
|
||||
}
|
||||
torch.prim.If.yield %1 : !torch.int
|
||||
} else {
|
||||
torch.prim.If.yield %int4 : !torch.int
|
||||
}
|
||||
return %0 : !torch.int
|
||||
}
|
Loading…
Reference in New Issue