mirror of https://github.com/llvm/torch-mlir
NFC: Re-organize ATen directory structure and fix warnings.
* Still some more work to do on the Transforms tree to bring it in line with the others (will do that as I add things).pull/92/head
parent
d09300886a
commit
9618c2dbf7
|
@ -25,10 +25,10 @@
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenOpReport.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenOpReport.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenPasses.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenPasses.h"
|
||||||
#include "npcomp/Dialect/ATen/LivenessReport.h"
|
#include "npcomp/Dialect/ATen/Transforms/LivenessReport.h"
|
||||||
|
|
||||||
#include "init_python_bindings.h"
|
#include "init_python_bindings.h"
|
||||||
|
|
||||||
|
|
|
@ -1,17 +1,2 @@
|
||||||
include_directories(${PROJECT_SOURCE_DIR}/dialect)
|
add_subdirectory(IR)
|
||||||
|
add_subdirectory(Transforms)
|
||||||
add_mlir_dialect(ATen aten)
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ATen.td)
|
|
||||||
mlir_tablegen(ATenEnums.h.inc -gen-enum-decls)
|
|
||||||
mlir_tablegen(ATenEnums.cpp.inc -gen-enum-defs)
|
|
||||||
add_public_tablegen_target(MLIRATenEnumsIncGen)
|
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ATenOpInterface.td)
|
|
||||||
mlir_tablegen(ATenOpInterfaces.h.inc -gen-op-interface-decls)
|
|
||||||
mlir_tablegen(ATenOpInterfaces.cpp.inc -gen-op-interface-defs)
|
|
||||||
add_public_tablegen_target(MLIRATenOpInterfacesIncGen)
|
|
||||||
add_dependencies(mlir-generic-headers MLIRATenOpInterfacesIncGen)
|
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ATenToStd.td)
|
|
||||||
mlir_tablegen(ATenToStd.cpp.inc -gen-rewriters)
|
|
||||||
add_public_tablegen_target(MLIRATenToStdIncGen)
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -6,8 +6,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_ATEN_DIALECT_H
|
#ifndef NPCOMP_DIALECT_ATEN_IR_DIALECT_H
|
||||||
#define NPCOMP_DIALECT_ATEN_DIALECT_H
|
#define NPCOMP_DIALECT_ATEN_IR_DIALECT_H
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
@ -54,7 +54,7 @@ namespace {
|
||||||
// Return the tensor volume (i.e., the number of elements) of the given shaped
|
// Return the tensor volume (i.e., the number of elements) of the given shaped
|
||||||
// type. If the type does not have a rank, return 1. If the type doesn't
|
// type. If the type does not have a rank, return 1. If the type doesn't
|
||||||
// have a static shape, return 0.
|
// have a static shape, return 0.
|
||||||
uint64_t getTensorVolume(const ShapedType ty) {
|
inline uint64_t getTensorVolume(const ShapedType ty) {
|
||||||
if (!ty.hasRank())
|
if (!ty.hasRank())
|
||||||
return 1;
|
return 1;
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ uint64_t getTensorVolume(const ShapedType ty) {
|
||||||
// If the type doesn't have a shape, return 1. If the type is shaped, but
|
// If the type doesn't have a shape, return 1. If the type is shaped, but
|
||||||
// does not have a rank, return 1. If the type is shaped, but doesn't have a
|
// does not have a rank, return 1. If the type is shaped, but doesn't have a
|
||||||
// static shape, return 0.
|
// static shape, return 0.
|
||||||
uint64_t getTensorVolume(const Type ty) {
|
inline uint64_t getTensorVolume(const Type ty) {
|
||||||
if (auto t = ty.dyn_cast<ShapedType>()) {
|
if (auto t = ty.dyn_cast<ShapedType>()) {
|
||||||
return getTensorVolume(t);
|
return getTensorVolume(t);
|
||||||
} else {
|
} else {
|
||||||
|
@ -84,12 +84,12 @@ uint64_t getTensorVolume(const Type ty) {
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenOpInterfaces.h"
|
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.h"
|
||||||
|
|
||||||
// include TableGen generated Op definitions
|
// include TableGen generated Op definitions
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/ATen/ATen.h.inc"
|
#include "npcomp/Dialect/ATen/IR/ATenOps.h.inc"
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h.inc"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h.inc"
|
||||||
|
|
||||||
#endif
|
#endif // NPCOMP_DIALECT_ATEN_IR_DIALECT_H
|
|
@ -0,0 +1,40 @@
|
||||||
|
//===- ATenDialect.td --------------------------------------*- tablegen -*-===//
|
||||||
|
//
|
||||||
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_DIALECT_ATEN_IR_ATEN_DIALECT
|
||||||
|
#define NPCOMP_DIALECT_ATEN_IR_ATEN_DIALECT
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect definition
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// The ATenDialect models 'A Tensor library' from Pytorch. The intention
|
||||||
|
/// is to provide an abstraction which is isomorphic with datastructures
|
||||||
|
/// returned from the pytorch jit, enabling integration with Pytorch models.
|
||||||
|
/// Most of the actual operation definitions in tablegen are themselves
|
||||||
|
/// generated from C APIs exported by Pytorch.
|
||||||
|
def ATen_Dialect : Dialect {
|
||||||
|
let name = "aten";
|
||||||
|
let cppNamespace = "::mlir::NPCOMP::aten";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect types
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ATen_ListType : DialectType<ATen_Dialect,
|
||||||
|
CPred<"$_self.isa<::mlir::NPCOMP::aten::ATenListType>()">, "ATen List">,
|
||||||
|
BuildableType<"$_builder.getType<::mlir::NPCOMP::aten::ATenListType()"> {
|
||||||
|
let typeDescription = [{
|
||||||
|
A variadic list of arguments in ATen.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // NPCOMP_DIALECT_ATEN_IR_ATEN_DIALECT
|
|
@ -8,8 +8,8 @@
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
#ifndef ATEN_OP_INTERFACES
|
#ifndef NPCOMP_DIALECT_ATEN_IR_ATEN_OP_INTERFACES
|
||||||
#define ATEN_OP_INTERFACES
|
#define NPCOMP_DIALECT_ATEN_IR_ATEN_OP_INTERFACES
|
||||||
|
|
||||||
def StatisticsOpInterface : OpInterface<"StatisticsOpInterface"> {
|
def StatisticsOpInterface : OpInterface<"StatisticsOpInterface"> {
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -67,4 +67,4 @@ def AnyScalar : TypeConstraint<Or<[AnySignlessInteger.predicate,
|
||||||
"scalar">;
|
"scalar">;
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif // NPCOMP_DIALECT_ATEN_IR_ATEN_OP_INTERFACES
|
|
@ -6,15 +6,15 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_ATEN_OPINTERFACES_H
|
#ifndef NPCOMP_DIALECT_ATEN_IR_OPINTERFACES_H
|
||||||
#define NPCOMP_DIALECT_ATEN_OPINTERFACES_H
|
#define NPCOMP_DIALECT_ATEN_IR_OPINTERFACES_H
|
||||||
|
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
#include "npcomp/Dialect/ATen/ATenOpInterfaces.h.inc"
|
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.h.inc"
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif
|
#endif // NPCOMP_DIALECT_ATEN_IR_OPINTERFACES_H
|
|
@ -6,10 +6,10 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_ATEN_OPSTATISTICSUTILS_H
|
#ifndef NPCOMP_DIALECT_ATEN_IR_OPSTATISTICSUTILS_H
|
||||||
#define NPCOMP_DIALECT_ATEN_OPSTATISTICSUTILS_H
|
#define NPCOMP_DIALECT_ATEN_IR_OPSTATISTICSUTILS_H
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
|
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
|
@ -36,7 +36,6 @@ std::map<std::string, uint64_t> getConv2dStatistics(T *o, uint64_t groups) {
|
||||||
TensorType biasTy = o->bias().getType().template cast<TensorType>();
|
TensorType biasTy = o->bias().getType().template cast<TensorType>();
|
||||||
|
|
||||||
uint64_t ofm_volume = getTensorVolume(resultTy);
|
uint64_t ofm_volume = getTensorVolume(resultTy);
|
||||||
uint64_t ofm_depth = resultTy.getShape()[1];
|
|
||||||
|
|
||||||
uint64_t ifm_depth = inputTy.getShape()[1];
|
uint64_t ifm_depth = inputTy.getShape()[1];
|
||||||
uint64_t kernel_height = weightTy.getShape()[2];
|
uint64_t kernel_height = weightTy.getShape()[2];
|
||||||
|
@ -142,27 +141,20 @@ uint64_t getConv2dOperandTransferVolume(T *o, unsigned int idx, bool read) {
|
||||||
float filter_height = weightTy.getShape()[3];
|
float filter_height = weightTy.getShape()[3];
|
||||||
|
|
||||||
float batch_sw = inputTy.getShape()[0];
|
float batch_sw = inputTy.getShape()[0];
|
||||||
float ifm_depth_sw = inputTy.getShape()[1];
|
|
||||||
float ih = inputTy.getShape()[2];
|
float ih = inputTy.getShape()[2];
|
||||||
float iw = inputTy.getShape()[3];
|
|
||||||
|
|
||||||
float ofm_depth_sw = resultTy.getShape()[1];
|
float ofm_depth_sw = resultTy.getShape()[1];
|
||||||
|
|
||||||
const float batch_hw = 4;
|
const float batch_hw = 4;
|
||||||
const float ifm_depth_hw = 32;
|
|
||||||
const float ofm_depth_hw = 32;
|
const float ofm_depth_hw = 32;
|
||||||
|
|
||||||
const float ifm_tile_height = 4;
|
const float ifm_tile_height = 4;
|
||||||
const float ifm_tile_width = 4;
|
|
||||||
const float ofm_tile_height = 4;
|
|
||||||
const float ofm_tile_width = 4;
|
|
||||||
|
|
||||||
float ifm_aperture = ifm_tile_height - ceilf(filter_height / 2.0f);
|
float ifm_aperture = ifm_tile_height - ceilf(filter_height / 2.0f);
|
||||||
float ifm_overlap = ceilf(filter_height / 2.0f);
|
float ifm_overlap = ceilf(filter_height / 2.0f);
|
||||||
|
|
||||||
float bl = ceilf(batch_sw / batch_hw);
|
float bl = ceilf(batch_sw / batch_hw);
|
||||||
float ol = ceilf(ofm_depth_sw / ofm_depth_hw);
|
float ol = ceilf(ofm_depth_sw / ofm_depth_hw);
|
||||||
float il = ceilf(ifm_depth_sw / ifm_depth_hw);
|
|
||||||
|
|
||||||
float ifm_overhead = 1.0f;
|
float ifm_overhead = 1.0f;
|
||||||
float weight_overhead = 1.0f;
|
float weight_overhead = 1.0f;
|
||||||
|
@ -274,4 +266,4 @@ std::map<std::string, uint64_t> getReLUOpStatistics(T op) {
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif
|
#endif // NPCOMP_DIALECT_ATEN_IR_OPSTATISTICSUTILS_H
|
|
@ -6,39 +6,13 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
#ifndef NPCOMP_DIALECT_ATEN_IR_ATEN_OPS
|
||||||
|
#define NPCOMP_DIALECT_ATEN_IR_ATEN_OPS
|
||||||
|
|
||||||
#ifndef ATEN_OPS
|
include "npcomp/Dialect/ATen/IR/ATenDialect.td"
|
||||||
#define ATEN_OPS
|
include "npcomp/Dialect/ATen/IR/ATenOpInterface.td"
|
||||||
|
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "npcomp/Dialect/ATen/ATenOpInterface.td"
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Dialect definition
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
/// The ATenDialect models 'A Tensor library' from Pytorch. The intention
|
|
||||||
/// is to provide an abstraction which is isomorphic with datastructures
|
|
||||||
/// returned from the pytorch jit, enabling integration with Pytorch models.
|
|
||||||
/// Most of the actual operation definitions in tablegen are themselves
|
|
||||||
/// generated from C APIs exported by Pytorch.
|
|
||||||
def ATen_Dialect : Dialect {
|
|
||||||
let name = "aten";
|
|
||||||
let cppNamespace = "::mlir::NPCOMP::aten";
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Dialect types
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def ATen_ListType : DialectType<ATen_Dialect,
|
|
||||||
CPred<"$_self.isa<::mlir::NPCOMP::aten::ATenListType>()">, "ATen List">,
|
|
||||||
BuildableType<"$_builder.getType<::mlir::NPCOMP::aten::ATenListType()"> {
|
|
||||||
let typeDescription = [{
|
|
||||||
A variadic list of arguments in ATen.
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: convert to "let results =" style
|
// TODO: convert to "let results =" style
|
||||||
// TODO: Rename prefix from "aten" to "ATen" for consistency.
|
// TODO: Rename prefix from "aten" to "ATen" for consistency.
|
||||||
|
@ -48,7 +22,7 @@ class aten_Op<string mnemonic, list<OpTrait> traits = [StatisticsOpInterface]> :
|
||||||
|
|
||||||
|
|
||||||
// Most ops are automatically generated from pytorch specs.
|
// Most ops are automatically generated from pytorch specs.
|
||||||
include "npcomp/Dialect/ATen/ATenOps.td"
|
include "npcomp/Dialect/ATen/IR/GeneratedATenOps.td"
|
||||||
|
|
||||||
|
|
||||||
def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, StatisticsOpInterface]>,
|
def aten_BatchNormOp: aten_Op<"batch_norm", [NoSideEffect, StatisticsOpInterface]>,
|
||||||
|
@ -179,4 +153,4 @@ def aten_TypeCastOp : aten_Op<"type_cast", [NoSideEffect]>,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif // NPCOMP_DIALECT_ATEN_IR_ATEN_OPS
|
|
@ -0,0 +1,19 @@
|
||||||
|
set(dialect_namespace aten)
|
||||||
|
set(LLVM_TARGET_DEFINITIONS ATenOps.td)
|
||||||
|
|
||||||
|
mlir_tablegen(ATenOps.h.inc -gen-op-decls)
|
||||||
|
mlir_tablegen(ATenOps.cpp.inc -gen-op-defs)
|
||||||
|
mlir_tablegen(ATenDialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
|
||||||
|
mlir_tablegen(ATenEnums.h.inc -gen-enum-decls)
|
||||||
|
mlir_tablegen(ATenEnums.cpp.inc -gen-enum-defs)
|
||||||
|
add_public_tablegen_target(MLIRATenIncGen)
|
||||||
|
add_dependencies(mlir-headers MLIRATenIncGen)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS ATenOpInterface.td)
|
||||||
|
mlir_tablegen(ATenOpInterfaces.h.inc -gen-op-interface-decls)
|
||||||
|
mlir_tablegen(ATenOpInterfaces.cpp.inc -gen-op-interface-defs)
|
||||||
|
add_public_tablegen_target(MLIRATenOpInterfacesIncGen)
|
||||||
|
add_dependencies(mlir-generic-headers MLIRATenOpInterfacesIncGen)
|
||||||
|
|
||||||
|
add_mlir_doc(ATenDialect -gen-dialect-doc ATenDialect ATen/)
|
||||||
|
add_mlir_doc(ATenOps -gen-op-doc ATenOps ATen/)
|
|
@ -7,8 +7,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef ATEN_OP_DEFS
|
#ifndef NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS
|
||||||
#define ATEN_OP_DEFS
|
#define NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS
|
||||||
|
|
||||||
def aten_AddOp: aten_Op<"add", [NoSideEffect, StatisticsOpInterface]>,
|
def aten_AddOp: aten_Op<"add", [NoSideEffect, StatisticsOpInterface]>,
|
||||||
Results<(outs AnyTensor)> {
|
Results<(outs AnyTensor)> {
|
||||||
|
@ -730,4 +730,4 @@ def aten_MaxPool2dWithIndicesBackwardOp: aten_Op<"max_pool2d_with_indices_backwa
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif // NPCOMP_DIALECT_ATEN_IR_GENERATED_ATEN_OPS
|
|
@ -11,6 +11,8 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class Pass;
|
class Pass;
|
||||||
} // namespace mlir
|
} // namespace mlir
|
|
@ -9,10 +9,10 @@
|
||||||
#ifndef NPCOMP_DIALECT_ATEN_PASSES_H
|
#ifndef NPCOMP_DIALECT_ATEN_PASSES_H
|
||||||
#define NPCOMP_DIALECT_ATEN_PASSES_H
|
#define NPCOMP_DIALECT_ATEN_PASSES_H
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenLayerNamePass.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenLayerNamePass.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenLoweringPass.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenLoweringPass.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenOpReport.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenOpReport.h"
|
||||||
#include "npcomp/Dialect/ATen/ReturnEliminationPass.h"
|
#include "npcomp/Dialect/ATen/Transforms/ReturnEliminationPass.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
|
@ -6,19 +6,11 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifdef MLIR_ATEN_TO_STD_TD
|
#ifndef MLIR_ATEN_TO_STD_TD
|
||||||
#else
|
|
||||||
#define MLIR_ATEN_TO_STD_TD
|
#define MLIR_ATEN_TO_STD_TD
|
||||||
|
|
||||||
#ifdef STANDARD_OPS
|
|
||||||
#else
|
|
||||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
#endif // STANDARD_OPS
|
include "npcomp/Dialect/ATen/IR/ATenOps.td"
|
||||||
|
|
||||||
#ifdef ATEN_OPS
|
|
||||||
#else
|
|
||||||
include "ATen.td"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// The pytorch convolution operator has 9 arguments, but we only have a jit
|
// The pytorch convolution operator has 9 arguments, but we only have a jit
|
||||||
// library that supports the first six at the moment.
|
// library that supports the first six at the moment.
|
|
@ -0,0 +1,3 @@
|
||||||
|
set(LLVM_TARGET_DEFINITIONS ATenToStd.td)
|
||||||
|
mlir_tablegen(ATenToStd.cpp.inc -gen-rewriters)
|
||||||
|
add_public_tablegen_target(MLIRATenToStdIncGen)
|
|
@ -11,6 +11,9 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "mlir/IR/Module.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
namespace aten {
|
namespace aten {
|
|
@ -1,25 +1,2 @@
|
||||||
add_npcomp_dialect_library(NPCOMPATenDialect
|
add_subdirectory(IR)
|
||||||
ATenDialect.cpp
|
add_subdirectory(Transforms)
|
||||||
ATenDialectOpStats.cpp
|
|
||||||
ATenPasses.cpp
|
|
||||||
ATenLayerNamePass.cpp
|
|
||||||
ATenLoweringPass.cpp
|
|
||||||
ATenOpReport.cpp
|
|
||||||
ATenToStd.cpp
|
|
||||||
LivenessReport.cpp
|
|
||||||
ReturnEliminationPass.cpp
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
|
||||||
${PROJECT_SOURCE_DIR}/dialect/include
|
|
||||||
${PROJECT_BINARY_DIR}/dialect/include
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
MLIRATenIncGen
|
|
||||||
MLIRATenEnumsIncGen
|
|
||||||
MLIRATenOpInterfacesIncGen
|
|
||||||
MLIRATenToStdIncGen
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
MLIRPass
|
|
||||||
MLIRTransformUtils
|
|
||||||
)
|
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -56,8 +56,6 @@ mlir::Type ATenListType::getElementType() {
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type ATenDialect::parseType(DialectAsmParser &parser) const {
|
mlir::Type ATenDialect::parseType(DialectAsmParser &parser) const {
|
||||||
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
|
|
||||||
|
|
||||||
// All types start with an identifier that we switch on.
|
// All types start with an identifier that we switch on.
|
||||||
StringRef typeNameSpelling;
|
StringRef typeNameSpelling;
|
||||||
if (failed(parser.parseKeyword(&typeNameSpelling)))
|
if (failed(parser.parseKeyword(&typeNameSpelling)))
|
||||||
|
@ -99,11 +97,11 @@ void ATenDialect::initialize() {
|
||||||
addTypes<ATenListType>();
|
addTypes<ATenListType>();
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "npcomp/Dialect/ATen/ATen.cpp.inc"
|
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/ATen/ATen.cpp.inc"
|
#include "npcomp/Dialect/ATen/IR/ATenOps.cpp.inc"
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenOpInterfaces.cpp.inc"
|
#include "npcomp/Dialect/ATen/IR/ATenOpInterfaces.cpp.inc"
|
|
@ -6,8 +6,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenOpStatisticsUtils.h"
|
#include "npcomp/Dialect/ATen/IR/ATenOpStatisticsUtils.h"
|
||||||
|
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
add_npcomp_dialect_library(NPCOMPATenDialect
|
||||||
|
ATenDialect.cpp
|
||||||
|
ATenDialectOpStats.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/ATen
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
MLIRATenIncGen
|
||||||
|
#MLIRATenEnumsIncGen
|
||||||
|
MLIRATenOpInterfacesIncGen
|
||||||
|
#MLIRATenToStdIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRPass
|
||||||
|
MLIRTransformUtils
|
||||||
|
)
|
|
@ -6,8 +6,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenLayerNamePass.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenLayerNamePass.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
|
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
|
@ -6,9 +6,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenLoweringPass.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenLoweringPass.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenToStd.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenToStd.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Affine/EDSC/Builders.h"
|
#include "mlir/Dialect/Affine/EDSC/Builders.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
@ -70,7 +70,7 @@ static Value typeCast(PatternRewriter &builder, Value val, Type destTy) {
|
||||||
/// unknown shape.
|
/// unknown shape.
|
||||||
static MemRefType getShapeErasedMemRefType(MemRefType type) {
|
static MemRefType getShapeErasedMemRefType(MemRefType type) {
|
||||||
std::vector<int64_t> shape = type.getShape();
|
std::vector<int64_t> shape = type.getShape();
|
||||||
for (int i = 0; i < shape.size(); i++) {
|
for (size_t i = 0, e = shape.size(); i < e; i++) {
|
||||||
shape[i] = -1;
|
shape[i] = -1;
|
||||||
}
|
}
|
||||||
return MemRefType::get(shape, type.getElementType(), type.getAffineMaps(),
|
return MemRefType::get(shape, type.getElementType(), type.getAffineMaps(),
|
||||||
|
@ -120,27 +120,6 @@ static std::string getFullyMangledType(const Type ty) {
|
||||||
return ret.str();
|
return ret.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mangle the argument shapes into the function name. This is impractical for
|
|
||||||
// a library-based implementation, since each different shape has to be
|
|
||||||
// implemented by a different function. The function name is constructed
|
|
||||||
// from the prefix, the mangled result types, the mangled operand types.
|
|
||||||
// Types are mangled in a way that encodes the full shape information.
|
|
||||||
static std::string getFullyMangledFuncName(std::string prefix,
|
|
||||||
FunctionType fnTy) {
|
|
||||||
std::string sep = "_";
|
|
||||||
|
|
||||||
ArrayRef<Type> resultTy = fnTy.getResults();
|
|
||||||
ArrayRef<Type> operTy = fnTy.getInputs();
|
|
||||||
|
|
||||||
std::string ret = prefix + "_AtenAcapOp_";
|
|
||||||
for (const Type t : resultTy)
|
|
||||||
ret = ret + sep + getFullyMangledType(t);
|
|
||||||
for (const Type t : operTy)
|
|
||||||
ret = ret + sep + getFullyMangledType(t);
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mangle the argument ranks into the function name.
|
// Mangle the argument ranks into the function name.
|
||||||
// TODO: Currently only supports MemRef, Float, Integer, and AtenList (poorly)
|
// TODO: Currently only supports MemRef, Float, Integer, and AtenList (poorly)
|
||||||
static std::string getSimplyMangledType(const Type ty) {
|
static std::string getSimplyMangledType(const Type ty) {
|
||||||
|
@ -192,15 +171,6 @@ static std::string getSimplyMangledFuncName(std::string prefix,
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
static std::string getSimplyMangledFuncName(std::string prefix,
|
|
||||||
FunctionType fnTy) {
|
|
||||||
|
|
||||||
return getSimplyMangledFuncName(prefix, fnTy.getInputs(), fnTy.getResults());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string getMangledFuncName(std::string prefix, FunctionType fnTy) {
|
|
||||||
return getSimplyMangledFuncName(prefix, fnTy);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string getMangledFuncName(std::string prefix, ArrayRef<Type> opTys,
|
std::string getMangledFuncName(std::string prefix, ArrayRef<Type> opTys,
|
||||||
ArrayRef<Type> retTys) {
|
ArrayRef<Type> retTys) {
|
||||||
|
@ -254,13 +224,10 @@ public:
|
||||||
Value result = rewriter.create<AllocOp>(loc, memRefResultTy);
|
Value result = rewriter.create<AllocOp>(loc, memRefResultTy);
|
||||||
Value lhs = memRefTypeCast(rewriter, operands[0]);
|
Value lhs = memRefTypeCast(rewriter, operands[0]);
|
||||||
Value rhs = memRefTypeCast(rewriter, operands[1]);
|
Value rhs = memRefTypeCast(rewriter, operands[1]);
|
||||||
auto indexType = IndexType::get(op->getContext());
|
|
||||||
|
|
||||||
using namespace edsc;
|
using namespace edsc;
|
||||||
|
|
||||||
ScopedContext scope(rewriter, loc);
|
ScopedContext scope(rewriter, loc);
|
||||||
Value zero = intrinsics::std_constant_index(0);
|
Value zero = intrinsics::std_constant_index(0);
|
||||||
Value one = intrinsics::std_constant_index(1);
|
|
||||||
MemRefBoundsCapture vRes(result), vLHS(lhs), vRHS(rhs);
|
MemRefBoundsCapture vRes(result), vLHS(lhs), vRHS(rhs);
|
||||||
StdIndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
|
StdIndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
|
||||||
Value M(vRes.ub(0));
|
Value M(vRes.ub(0));
|
||||||
|
@ -345,8 +312,6 @@ LogicalResult rewriteWithVoidFunctionCallExplicit(
|
||||||
TensorType tensorResultTy = t.cast<TensorType>();
|
TensorType tensorResultTy = t.cast<TensorType>();
|
||||||
MemRefType memRefResultTy = mlir::MemRefType::get(
|
MemRefType memRefResultTy = mlir::MemRefType::get(
|
||||||
tensorResultTy.getShape(), tensorResultTy.getElementType(), {}, 0);
|
tensorResultTy.getShape(), tensorResultTy.getElementType(), {}, 0);
|
||||||
MemRefType erasedMemRefResultTy =
|
|
||||||
getShapeErasedMemRefType(memRefResultTy);
|
|
||||||
retTys.push_back(memRefResultTy);
|
retTys.push_back(memRefResultTy);
|
||||||
|
|
||||||
// assume memRefResultTy has known shape, so we don't need any
|
// assume memRefResultTy has known shape, so we don't need any
|
||||||
|
@ -367,8 +332,7 @@ LogicalResult rewriteWithVoidFunctionCallExplicit(
|
||||||
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
|
FuncOp funcOp = getATenFn(op->getParentOfType<ModuleOp>(),
|
||||||
mangledFunctionName, newOps, empty);
|
mangledFunctionName, newOps, empty);
|
||||||
|
|
||||||
auto new_call =
|
callOperation(empty, rewriter.getSymbolRefAttr(funcOp), newOps);
|
||||||
callOperation(empty, rewriter.getSymbolRefAttr(funcOp), newOps);
|
|
||||||
|
|
||||||
rewriter.replaceOp(op, newResults);
|
rewriter.replaceOp(op, newResults);
|
||||||
return success();
|
return success();
|
||||||
|
@ -442,8 +406,6 @@ public:
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
edsc::ScopedContext scope(rewriter, loc);
|
edsc::ScopedContext scope(rewriter, loc);
|
||||||
|
|
||||||
auto constOp = cast<mlir::NPCOMP::aten::ConstantOp>(op);
|
|
||||||
|
|
||||||
Value result = op->getResult(0);
|
Value result = op->getResult(0);
|
||||||
Type t = result.getType();
|
Type t = result.getType();
|
||||||
if (t.isa<IntegerType>()) {
|
if (t.isa<IntegerType>()) {
|
|
@ -6,8 +6,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenOpReport.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenOpReport.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
|
||||||
|
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
@ -15,6 +14,7 @@
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -25,10 +25,6 @@ using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::string getAsString(std::map<std::string, uint64_t> &m, std::string &e) {
|
|
||||||
return m.count(e) ? std::to_string(m[e]) : " ";
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Query operations through the StatisticsOpInterface and print the result
|
/// Query operations through the StatisticsOpInterface and print the result
|
||||||
/// in a human-readable way. This replicates the functionality in various
|
/// in a human-readable way. This replicates the functionality in various
|
||||||
/// network analysis tools and is a stepping stone toward using the information
|
/// network analysis tools and is a stepping stone toward using the information
|
|
@ -6,7 +6,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenPasses.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenPasses.h"
|
||||||
|
|
||||||
using namespace mlir::NPCOMP::aten;
|
using namespace mlir::NPCOMP::aten;
|
||||||
|
|
|
@ -6,16 +6,17 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenToStd.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenToStd.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// import patterns
|
// import patterns
|
||||||
#include "npcomp/Dialect/ATen/ATenToStd.cpp.inc"
|
#include "npcomp/Dialect/ATen/Transforms/ATenToStd.cpp.inc"
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
|
@ -0,0 +1,12 @@
|
||||||
|
add_npcomp_conversion_library(NPCOMPATenPasses
|
||||||
|
ATenPasses.cpp
|
||||||
|
ATenLayerNamePass.cpp
|
||||||
|
ATenLoweringPass.cpp
|
||||||
|
ATenOpReport.cpp
|
||||||
|
ATenToStd.cpp
|
||||||
|
LivenessReport.cpp
|
||||||
|
ReturnEliminationPass.cpp
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
MLIRATenToStdIncGen
|
||||||
|
)
|
|
@ -6,7 +6,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/Transforms/LivenessReport.h"
|
||||||
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
|
|
||||||
#include "mlir/Analysis/Liveness.h"
|
#include "mlir/Analysis/Liveness.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -15,8 +16,6 @@
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/LivenessReport.h"
|
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -25,28 +24,6 @@
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
uint64_t getTensorVolume(const ShapedType ty) {
|
|
||||||
|
|
||||||
if (!ty.hasRank())
|
|
||||||
return 1;
|
|
||||||
|
|
||||||
uint64_t volume = 1;
|
|
||||||
for (auto &d : ty.getShape())
|
|
||||||
volume *= d;
|
|
||||||
return volume;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t getTensorVolume(const Type ty) {
|
|
||||||
if (auto t = ty.dyn_cast<ShapedType>()) {
|
|
||||||
return getTensorVolume(t);
|
|
||||||
} else {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
namespace aten {
|
namespace aten {
|
||||||
|
@ -72,9 +49,6 @@ std::string LivenessReport::generateTextReport() {
|
||||||
std::string LivenessReport::emitJSONReport() {
|
std::string LivenessReport::emitJSONReport() {
|
||||||
resolveLiveness();
|
resolveLiveness();
|
||||||
llvm::json::Object top;
|
llvm::json::Object top;
|
||||||
auto context = module.getContext();
|
|
||||||
auto loc = mlir::UnknownLoc::get(context);
|
|
||||||
|
|
||||||
auto graph = module.lookupSymbol<mlir::FuncOp>("graph");
|
auto graph = module.lookupSymbol<mlir::FuncOp>("graph");
|
||||||
|
|
||||||
std::map<Operation *, std::vector<Value>> liveAt;
|
std::map<Operation *, std::vector<Value>> liveAt;
|
||||||
|
@ -117,7 +91,6 @@ std::string LivenessReport::emitJSONReport() {
|
||||||
if (v.getDefiningOp()) {
|
if (v.getDefiningOp()) {
|
||||||
if (auto a =
|
if (auto a =
|
||||||
v.getDefiningOp()->getAttrOfType<StringAttr>("layer_name")) {
|
v.getDefiningOp()->getAttrOfType<StringAttr>("layer_name")) {
|
||||||
auto definingOp = v.getDefiningOp();
|
|
||||||
auto ld = layerDetail.getInteger(a.getValue().str());
|
auto ld = layerDetail.getInteger(a.getValue().str());
|
||||||
if (ld)
|
if (ld)
|
||||||
layerDetail[a.getValue().str()] = *ld + vol;
|
layerDetail[a.getValue().str()] = *ld + vol;
|
|
@ -6,8 +6,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ReturnEliminationPass.h"
|
#include "npcomp/Dialect/ATen/Transforms/ReturnEliminationPass.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
|
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
@ -78,8 +78,8 @@ public:
|
||||||
newCallArgs.push_back(valueMap[v]);
|
newCallArgs.push_back(valueMap[v]);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto newCallOp = builder->create<CallOp>(op->getLoc(), newFnName,
|
builder->create<CallOp>(op->getLoc(), newFnName, ArrayRef<Type>{},
|
||||||
ArrayRef<Type>{}, newCallArgs);
|
newCallArgs);
|
||||||
erasedOps.insert(op);
|
erasedOps.insert(op);
|
||||||
auto fn = module.lookupSymbol<FuncOp>(callOp.callee());
|
auto fn = module.lookupSymbol<FuncOp>(callOp.callee());
|
||||||
if (fn && fn.use_empty())
|
if (fn && fn.use_empty())
|
||||||
|
@ -105,7 +105,6 @@ public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
|
||||||
auto module = getOperation();
|
auto module = getOperation();
|
||||||
auto context = module.getContext();
|
|
||||||
|
|
||||||
// check that a function called "graph" exists
|
// check that a function called "graph" exists
|
||||||
auto graph = module.lookupSymbol<mlir::FuncOp>("graph");
|
auto graph = module.lookupSymbol<mlir::FuncOp>("graph");
|
|
@ -8,8 +8,8 @@
|
||||||
|
|
||||||
#include "npcomp/InitAll.h"
|
#include "npcomp/InitAll.h"
|
||||||
|
|
||||||
#include "npcomp/Dialect/ATen/ATenDialect.h"
|
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
|
||||||
#include "npcomp/Dialect/ATen/ATenPasses.h"
|
#include "npcomp/Dialect/ATen/Transforms/ATenPasses.h"
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||||
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
|
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||||
|
|
|
@ -28,8 +28,6 @@ using namespace mlir::NPCOMP;
|
||||||
// conversion about them.
|
// conversion about them.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// This is a type conversion similar to CallOpSignatureConversion.
|
// This is a type conversion similar to CallOpSignatureConversion.
|
||||||
class LowerSelectOpTypes : public OpConversionPattern<SelectOp> {
|
class LowerSelectOpTypes : public OpConversionPattern<SelectOp> {
|
||||||
|
|
Loading…
Reference in New Issue