mirror of https://github.com/llvm/torch-mlir
Add TorchToSCF pass.
1. Add TorchToSCF pass. 2. Convert prim.If and prim.If.yield.pull/239/head
parent
5ad144c4fe
commit
45f2edfc7a
|
@ -20,6 +20,11 @@ def ConvertTorchToStd : Pass<"convert-torch-to-std", "FuncOp"> {
|
||||||
let constructor = "mlir::NPCOMP::createConvertTorchToStdPass()";
|
let constructor = "mlir::NPCOMP::createConvertTorchToStdPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "FuncOp"> {
|
||||||
|
let summary = "Convert recognized Torch ops to SCF ops";
|
||||||
|
let constructor = "mlir::NPCOMP::createConvertTorchToSCFPass()";
|
||||||
|
}
|
||||||
|
|
||||||
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
|
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "FuncOp"> {
|
||||||
let summary = "Convert recognized Torch ops to Linalg ops";
|
let summary = "Convert recognized Torch ops to Linalg ops";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
//===------------------------------------------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
||||||
|
#define NPCOMP_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace NPCOMP {
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createConvertTorchToSCFPass();
|
||||||
|
}
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // NPCOMP_CONVERSION_TORCHTOSCF_TORCHTOSCF_H
|
|
@ -1,4 +1,5 @@
|
||||||
add_subdirectory(TorchToLinalg)
|
add_subdirectory(TorchToLinalg)
|
||||||
|
add_subdirectory(TorchToSCF)
|
||||||
add_subdirectory(TorchToStd)
|
add_subdirectory(TorchToStd)
|
||||||
add_subdirectory(BasicpyToStd)
|
add_subdirectory(BasicpyToStd)
|
||||||
add_subdirectory(NumpyToTCF)
|
add_subdirectory(NumpyToTCF)
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
#include "npcomp/Conversion/TCFToStd/TCFToStd.h"
|
#include "npcomp/Conversion/TCFToStd/TCFToStd.h"
|
||||||
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
#include "npcomp/Conversion/TCFToTCP/TCFToTCP.h"
|
||||||
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
|
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
add_npcomp_conversion_library(NPCOMPTorchToSCF
|
||||||
|
TorchToSCF.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/TorchToSCF
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
NPCOMPConversionPassIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRPass
|
||||||
|
MLIRSCF
|
||||||
|
MLIRStandard
|
||||||
|
NPCOMPTorchDialect
|
||||||
|
)
|
|
@ -0,0 +1,93 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
|
|
||||||
|
#include "../PassDetail.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "npcomp/Dialect/Torch/Transforms/BackendTypeConversion.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
using namespace mlir::NPCOMP::Torch;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertTorchPrimIfYieldOp : public OpConversionPattern<PrimIfYieldOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<PrimIfYieldOp>::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(PrimIfYieldOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertTorchPrimIfOp : public OpConversionPattern<PrimIfOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<PrimIfOp>::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(PrimIfOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<Type, 1> newResultTypes;
|
||||||
|
if (failed(getTypeConverter()->convertTypes(op.getResultTypes(),
|
||||||
|
newResultTypes)))
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"could not convert PrimIfOp outputs");
|
||||||
|
auto scfIf = rewriter.create<scf::IfOp>(
|
||||||
|
op->getLoc(), newResultTypes, operands[0], /*withElseRegion=*/true);
|
||||||
|
auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
|
||||||
|
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
|
||||||
|
rewriter.eraseBlock(&dstRegion.back());
|
||||||
|
};
|
||||||
|
inlineIfCase(op.thenRegion(), scfIf.thenRegion());
|
||||||
|
inlineIfCase(op.elseRegion(), scfIf.elseRegion());
|
||||||
|
rewriter.replaceOp(op, scfIf.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertTorchToSCF : public ConvertTorchToSCFBase<ConvertTorchToSCF> {
|
||||||
|
public:
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<scf::SCFDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
ConversionTarget target(*context);
|
||||||
|
target.addLegalDialect<Torch::TorchDialect, scf::SCFDialect>();
|
||||||
|
|
||||||
|
TypeConverter typeConverter;
|
||||||
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
|
setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
target.addIllegalOp<PrimIfOp>();
|
||||||
|
patterns.add<ConvertTorchPrimIfOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<PrimIfYieldOp>();
|
||||||
|
patterns.add<ConvertTorchPrimIfYieldOp>(typeConverter, context);
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
|
std::move(patterns))))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
|
mlir::NPCOMP::createConvertTorchToSCFPass() {
|
||||||
|
return std::make_unique<ConvertTorchToSCF>();
|
||||||
|
}
|
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "npcomp/Backend/Common/Passes.h"
|
#include "npcomp/Backend/Common/Passes.h"
|
||||||
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "npcomp/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
|
#include "npcomp/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
#include "npcomp/Conversion/TorchToStd/TorchToStd.h"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -150,6 +151,8 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
||||||
// TODO: Improve torch op canonicalizations.
|
// TODO: Improve torch op canonicalizations.
|
||||||
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
||||||
|
|
||||||
|
pm.addNestedPass<FuncOp>(createConvertTorchToSCFPass());
|
||||||
|
|
||||||
// Lower to linalg + guards which is the input to codegen backends.
|
// Lower to linalg + guards which is the input to codegen backends.
|
||||||
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
// RUN: npcomp-opt <%s -convert-torch-to-scf | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.prim.if(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.to_i1 %[[VAL_0]]
|
||||||
|
// CHECK: %[[VAL_4:.*]] = scf.if %[[VAL_3]] -> (i64) {
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.to_i64 %[[VAL_1]]
|
||||||
|
// CHECK: scf.yield %[[VAL_5]] : i64
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch.to_i64 %[[VAL_2]]
|
||||||
|
// CHECK: scf.yield %[[VAL_6]] : i64
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_7:.*]] = torch.from_i64 %[[VAL_8:.*]]
|
||||||
|
// CHECK: return %[[VAL_7]] : !torch.int
|
||||||
|
func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.prim.If %arg0 -> (!torch.int) {
|
||||||
|
torch.prim.If.yield %int2 : !torch.int
|
||||||
|
} else {
|
||||||
|
torch.prim.If.yield %int1 : !torch.int
|
||||||
|
}
|
||||||
|
return %0 : !torch.int
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @aten.prim.if$nested(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool,
|
||||||
|
// CHECK-SAME: %[[VAL_1:.*]]: !torch.bool) -> !torch.int {
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.to_i1 %[[VAL_0]]
|
||||||
|
// CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_5]] -> (i64) {
|
||||||
|
// CHECK: %[[VAL_7:.*]] = torch.to_i1 %[[VAL_1]]
|
||||||
|
// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (i64) {
|
||||||
|
// CHECK: %[[VAL_9:.*]] = torch.to_i64 %[[VAL_2]]
|
||||||
|
// CHECK: scf.yield %[[VAL_9]] : i64
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_10:.*]] = torch.to_i64 %[[VAL_3]]
|
||||||
|
// CHECK: scf.yield %[[VAL_10]] : i64
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: scf.yield %[[VAL_11:.*]] : i64
|
||||||
|
// CHECK: } else {
|
||||||
|
// CHECK: %[[VAL_12:.*]] = torch.to_i64 %[[VAL_4]]
|
||||||
|
// CHECK: scf.yield %[[VAL_12]] : i64
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %[[VAL_13:.*]] = torch.from_i64 %[[VAL_14:.*]]
|
||||||
|
// CHECK: return %[[VAL_13]] : !torch.int
|
||||||
|
func @aten.prim.if$nested(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.int {
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%int4 = torch.constant.int 4
|
||||||
|
%0 = torch.prim.If %arg0 -> (!torch.int) {
|
||||||
|
%1 = torch.prim.If %arg1 -> (!torch.int) {
|
||||||
|
torch.prim.If.yield %int2 : !torch.int
|
||||||
|
} else {
|
||||||
|
torch.prim.If.yield %int3 : !torch.int
|
||||||
|
}
|
||||||
|
torch.prim.If.yield %1 : !torch.int
|
||||||
|
} else {
|
||||||
|
torch.prim.If.yield %int4 : !torch.int
|
||||||
|
}
|
||||||
|
return %0 : !torch.int
|
||||||
|
}
|
Loading…
Reference in New Issue