Add TorchToSCF pass.

1. Add TorchToSCF pass.
2. Convert prim.If and prim.If.yield.
pull/239/head
Yi Zhang 2021-06-18 19:40:40 +00:00 committed by Sean Silva
parent 5ad144c4fe
commit 45f2edfc7a
8 changed files with 208 additions and 0 deletions

View File

@ -20,6 +20,11 @@ def ConvertTorchToStd : Pass<"convert-torch-to-std", "FuncOp"> {
let constructor = "mlir::NPCOMP::createConvertTorchToStdPass()"; let constructor = "mlir::NPCOMP::createConvertTorchToStdPass()";
} }
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "FuncOp"> {
let summary = "Convert recognized Torch ops to SCF ops";
let constructor = "mlir::NPCOMP::createConvertTorchToSCFPass()";
}
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> { def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
let summary = "Convert recognized Torch ops to Linalg ops"; let summary = "Convert recognized Torch ops to Linalg ops";
let description = [{ let description = [{

View File

@ -0,0 +1,20 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
#define NPCOMP_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToSCFPass();
}
} // namespace mlir
#endif // NPCOMP_CONVERSION_TORCHTOSCF_TORCHTOSCF_H

View File

@ -1,4 +1,5 @@
add_subdirectory(TorchToLinalg) add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToStd) add_subdirectory(TorchToStd)
add_subdirectory(BasicpyToStd) add_subdirectory(BasicpyToStd)
add_subdirectory(NumpyToTCF) add_subdirectory(NumpyToTCF)

View File

@ -14,6 +14,7 @@
#include "npcomp/Conversion/TCFToStd/TCFToStd.h" #include "npcomp/Conversion/TCFToStd/TCFToStd.h"
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h" #include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h" #include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
#include "npcomp/Conversion/TorchToStd/TorchToStd.h" #include "npcomp/Conversion/TorchToStd/TorchToStd.h"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,19 @@
add_npcomp_conversion_library(NPCOMPTorchToSCF
TorchToSCF.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToSCF
DEPENDS
NPCOMPConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRSCF
MLIRStandard
NPCOMPTorchDialect
)

View File

@ -0,0 +1,93 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
namespace {
class ConvertTorchPrimIfYieldOp : public OpConversionPattern<PrimIfYieldOp> {
public:
using OpConversionPattern<PrimIfYieldOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PrimIfYieldOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
return success();
}
};
} // namespace
namespace {
class ConvertTorchPrimIfOp : public OpConversionPattern<PrimIfOp> {
public:
using OpConversionPattern<PrimIfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(PrimIfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type, 1> newResultTypes;
if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
newResultTypes)))
return rewriter.notifyMatchFailure(op,
"could not convert PrimIfOp outputs");
auto scfIf = rewriter.create<scf::IfOp>(
op->getLoc(), newResultTypes, operands[0], /*withElseRegion=*/true);
auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
rewriter.eraseBlock(&dstRegion.back());
};
inlineIfCase(op.thenRegion(), scfIf.thenRegion());
inlineIfCase(op.elseRegion(), scfIf.elseRegion());
rewriter.replaceOp(op, scfIf.getResults());
return success();
}
};
} // namespace
namespace {
class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context);
target.addIllegalOp<PrimIfOp>();
patterns.add<ConvertTorchPrimIfOp>(typeConverter, context);
target.addIllegalOp<PrimIfYieldOp>();
patterns.add<ConvertTorchPrimIfYieldOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertTorchToSCFPass() {
return std::make_unique<ConvertTorchToSCF>();
}

View File

@ -13,6 +13,7 @@
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h" #include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h" #include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
#include "npcomp/Conversion/TorchToStd/TorchToStd.h" #include "npcomp/Conversion/TorchToStd/TorchToStd.h"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -150,6 +151,8 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
// TODO: Improve torch op canonicalizations. // TODO: Improve torch op canonicalizations.
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass()); pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
// Lower to linalg + guards which is the input to codegen backends. // Lower to linalg + guards which is the input to codegen backends.
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass()); pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());

View File

@ -0,0 +1,66 @@
// RUN: npcomp-opt <%s -convert-torch-to-scf | FileCheck %s
// CHECK-LABEL: func @torch.prim.if(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 2
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.to_i1 %[[VAL_0]]
// CHECK: %[[VAL_4:.*]] = scf.if %[[VAL_3]] -> (i64) {
// CHECK: %[[VAL_5:.*]] = torch.to_i64 %[[VAL_1]]
// CHECK: scf.yield %[[VAL_5]] : i64
// CHECK: } else {
// CHECK: %[[VAL_6:.*]] = torch.to_i64 %[[VAL_2]]
// CHECK: scf.yield %[[VAL_6]] : i64
// CHECK: }
// CHECK: %[[VAL_7:.*]] = torch.from_i64 %[[VAL_8:.*]]
// CHECK: return %[[VAL_7]] : !torch.int
func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%0 = torch.prim.If %arg0 -> (!torch.int) {
torch.prim.If.yield %int2 : !torch.int
} else {
torch.prim.If.yield %int1 : !torch.int
}
return %0 : !torch.int
}
// CHECK-LABEL: func @aten.prim.if$nested(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.bool) -> !torch.int {
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.constant.int 4
// CHECK: %[[VAL_5:.*]] = torch.to_i1 %[[VAL_0]]
// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_5]] -> (i64) {
// CHECK: %[[VAL_7:.*]] = torch.to_i1 %[[VAL_1]]
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (i64) {
// CHECK: %[[VAL_9:.*]] = torch.to_i64 %[[VAL_2]]
// CHECK: scf.yield %[[VAL_9]] : i64
// CHECK: } else {
// CHECK: %[[VAL_10:.*]] = torch.to_i64 %[[VAL_3]]
// CHECK: scf.yield %[[VAL_10]] : i64
// CHECK: }
// CHECK: scf.yield %[[VAL_11:.*]] : i64
// CHECK: } else {
// CHECK: %[[VAL_12:.*]] = torch.to_i64 %[[VAL_4]]
// CHECK: scf.yield %[[VAL_12]] : i64
// CHECK: }
// CHECK: %[[VAL_13:.*]] = torch.from_i64 %[[VAL_14:.*]]
// CHECK: return %[[VAL_13]] : !torch.int
func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int {
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%0 = torch.prim.If %arg0 -> (!torch.int) {
%1 = torch.prim.If %arg1 -> (!torch.int) {
torch.prim.If.yield %int2 : !torch.int
} else {
torch.prim.If.yield %int3 : !torch.int
}
torch.prim.If.yield %1 : !torch.int
} else {
torch.prim.If.yield %int4 : !torch.int
}
return %0 : !torch.int
}