mirror of https://github.com/llvm/torch-mlir
Enable -Werror in lib/ and LTC. (#2841)
Required some massaging of LTC to make it warning clean, and I had to manually disable some warnings on the generated source files (which we don't control). The project is warning clean now. The `-Werror` flag is disabled by default as we can't control everywhere people will try to build/install. The CI enables it via -DTORCH_MLIR_ENABLE_WERROR_FLAG=ON.pull/2847/head
parent
943164d797
commit
7301aa80fd
|
@ -31,6 +31,7 @@ include(CMakeDependentOption)
|
|||
# Project options
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
option(TORCH_MLIR_ENABLE_WERROR_FLAG "Enable `-Werror` flag on supported directories, treat error as warning" OFF)
|
||||
option(TORCH_MLIR_USE_INSTALLED_PYTORCH "If depending on PyTorch use it as installed in the current Python environment" ON)
|
||||
|
||||
option(TORCH_MLIR_ENABLE_REFBACKEND "Enable reference backend" ON)
|
||||
|
@ -53,6 +54,14 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI
|
|||
|
||||
option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)
|
||||
|
||||
macro(torch_mlir_enable_werror)
|
||||
if(TORCH_MLIR_ENABLE_WERROR_FLAG)
|
||||
if(NOT MSVC)
|
||||
add_compile_options(-Werror)
|
||||
endif()
|
||||
endif()
|
||||
endmacro()
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Configure out-of-tree vs in-tree build
|
||||
#-------------------------------------------------------------------------------
|
||||
|
|
|
@ -42,6 +42,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
|
|||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DPython3_EXECUTABLE="$(which python)" \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DTORCH_MLIR_ENABLE_WERROR_FLAG=ON \
|
||||
-DCMAKE_INSTALL_PREFIX="$install_dir" \
|
||||
-DCMAKE_INSTALL_LIBDIR=lib \
|
||||
-DLLVM_ENABLE_PROJECTS=mlir \
|
||||
|
|
|
@ -23,7 +23,7 @@ extern "C" {
|
|||
MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context);
|
||||
|
||||
/** Registers all passes for symbolic access with the global registry. */
|
||||
MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses();
|
||||
MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ MLIR_CAPI_EXPORTED MlirType
|
|||
torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
|
||||
|
||||
/// Gets the !torch.nn.Module typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.optional type.
|
||||
|
@ -53,7 +53,7 @@ MLIR_CAPI_EXPORTED MlirType
|
|||
torchMlirTorchOptionalTypeGetContained(MlirType containedType);
|
||||
|
||||
/// Gets the !torch.optional typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.tuple<T1, T2, T3> type.
|
||||
|
@ -75,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t,
|
|||
intptr_t pos);
|
||||
|
||||
/// Gets the !torch.tuple typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.union<T1, T2, T3> type.
|
||||
|
@ -97,7 +97,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t,
|
|||
intptr_t pos);
|
||||
|
||||
/// Gets the !torch.union typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.list<T> type.
|
||||
|
@ -113,7 +113,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t);
|
||||
|
||||
/// Gets the !torch.list typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.Device type.
|
||||
|
@ -126,7 +126,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.device typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.Generator type.
|
||||
|
@ -139,7 +139,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.generator typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.bool type.
|
||||
|
@ -152,7 +152,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.bool typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.int type.
|
||||
|
@ -165,7 +165,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.int typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.float type.
|
||||
|
@ -178,7 +178,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.float typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.LinearParams type.
|
||||
|
@ -192,7 +192,7 @@ MLIR_CAPI_EXPORTED MlirType
|
|||
torchMlirTorchLinearParamsTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.linearparams typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.qint8 type.
|
||||
|
@ -205,7 +205,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.qint8 typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.quint8 type.
|
||||
|
@ -218,7 +218,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.quint8 typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.tensor type.
|
||||
|
@ -266,7 +266,7 @@ MLIR_CAPI_EXPORTED MlirType
|
|||
torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);
|
||||
|
||||
/// Gets the !torch.tensor typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.vtensor type.
|
||||
|
@ -312,7 +312,7 @@ torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t);
|
||||
|
||||
/// Gets the !torch.vtensor typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.none type.
|
||||
|
@ -325,7 +325,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.none typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.str type.
|
||||
|
@ -338,7 +338,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.str typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.any type.
|
||||
|
@ -351,7 +351,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.any typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.number type.
|
||||
|
@ -364,7 +364,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
|
||||
|
||||
/// Gets the !torch.number typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID(void);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !torch.dict type.
|
||||
|
@ -387,7 +387,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t);
|
|||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t);
|
||||
|
||||
/// Gets the !torch.dict typeid.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
torch_mlir_enable_werror()
|
||||
|
||||
add_subdirectory(CAPI)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
|
|
|
@ -673,7 +673,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
if (operands.size() == 1) {
|
||||
if (noop_with_empty_axes == 0) {
|
||||
MLIRContext *context = binder.op->getContext();
|
||||
auto rank =
|
||||
int rank =
|
||||
data.getType().cast<Torch::ValueTensorType>().getSizes().size();
|
||||
SmallVector<Value, 1> dims;
|
||||
for (int i = 0; i < rank; i++) {
|
||||
|
|
|
@ -2,11 +2,21 @@
|
|||
# Setup PyTorch/LTC
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
torch_mlir_enable_werror()
|
||||
|
||||
set(LTC_GENERATED
|
||||
generated/LazyNativeFunctions.cpp
|
||||
generated/RegisterLazy.cpp
|
||||
generated/shape_inference.cpp
|
||||
)
|
||||
|
||||
# The auto generated files trigger some warnings we can't do anything about.
|
||||
if(NOT MSVC)
|
||||
set_source_files_properties(${LTC_GENERATED}
|
||||
PROPERTIES COMPILE_FLAGS "-Wno-sign-compare -Wno-unused-function"
|
||||
)
|
||||
endif()
|
||||
|
||||
set(LTC_BACKEND_DEPENDS
|
||||
mlir_lowering_context.cpp
|
||||
mlir_native_functions.cpp
|
||||
|
|
|
@ -24,7 +24,7 @@ std::string DimensionNode::ToString() const { return "DimensionNode"; }
|
|||
SizeNode::SizeNode(Value input, size_t dim)
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::size")}, {input},
|
||||
MHash(dim)),
|
||||
dim_(dim){};
|
||||
dim_(dim) {}
|
||||
|
||||
int64_t SizeNode::getStaticValue() const {
|
||||
return dynamic_cast<const TorchMlirNode *>(operand(0).node)
|
||||
|
@ -35,7 +35,7 @@ int64_t SizeNode::getStaticValue() const {
|
|||
std::string SizeNode::ToString() const { return "SizeNode"; }
|
||||
|
||||
SizeAdd::SizeAdd(Value a, Value b)
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){};
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}) {}
|
||||
|
||||
int64_t SizeAdd::getStaticValue() const {
|
||||
return dynamic_cast<const DimensionNode *>(operand(0).node)
|
||||
|
@ -46,7 +46,7 @@ int64_t SizeAdd::getStaticValue() const {
|
|||
std::string SizeAdd::ToString() const { return "SizeAdd"; }
|
||||
|
||||
SizeMul::SizeMul(Value a, Value b)
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){};
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}) {}
|
||||
|
||||
int64_t SizeMul::getStaticValue() const {
|
||||
return dynamic_cast<const DimensionNode *>(operand(0).node)
|
||||
|
@ -57,7 +57,7 @@ int64_t SizeMul::getStaticValue() const {
|
|||
std::string SizeMul::ToString() const { return "SizeMul"; }
|
||||
|
||||
SizeDiv::SizeDiv(Value a, Value b)
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}){};
|
||||
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}) {}
|
||||
|
||||
int64_t SizeDiv::getStaticValue() const {
|
||||
TORCH_CHECK(
|
||||
|
|
|
@ -150,15 +150,14 @@ public:
|
|||
|
||||
protected:
|
||||
size_t num_parameters_;
|
||||
std::unordered_map<int, std::string> parameters_map_;
|
||||
std::vector<std::string> parameter_names_;
|
||||
std::vector<Shape> parameter_shapes_;
|
||||
Shape result_shape_;
|
||||
|
||||
MlirModule module_op_;
|
||||
MlirContext mlir_context_;
|
||||
std::shared_ptr<torch::jit::Graph> graph_;
|
||||
InputOutputAliases input_output_aliases_;
|
||||
std::unordered_map<int, std::string> parameters_map_;
|
||||
std::vector<std::string> parameter_names_;
|
||||
std::vector<Shape> parameter_shapes_;
|
||||
Shape result_shape_;
|
||||
};
|
||||
|
||||
} // namespace lazy
|
||||
|
|
|
@ -67,7 +67,7 @@ c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor> &tensor) {
|
|||
return c10::nullopt;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
|
||||
[[maybe_unused]] std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
|
||||
std::vector<at::Tensor> outs;
|
||||
outs.reserve(t_list.size());
|
||||
for (const auto &tensor : t_list) {
|
||||
|
@ -92,7 +92,7 @@ namespace lazy {
|
|||
|
||||
namespace {
|
||||
|
||||
at::Tensor
|
||||
[[maybe_unused]] at::Tensor
|
||||
CreateLtcTensor(const at::Tensor &tensor,
|
||||
const c10::optional<torch::lazy::BackendDevice> &device) {
|
||||
if (tensor.defined() && device) {
|
||||
|
@ -102,7 +102,7 @@ CreateLtcTensor(const at::Tensor &tensor,
|
|||
return tensor;
|
||||
}
|
||||
|
||||
c10::optional<torch::lazy::BackendDevice>
|
||||
[[maybe_unused]] c10::optional<torch::lazy::BackendDevice>
|
||||
GetLtcDevice(const c10::optional<c10::Device> &device) {
|
||||
if (!device) {
|
||||
return c10::nullopt;
|
||||
|
@ -334,7 +334,7 @@ at::Tensor LazyNativeFunctions::_to_copy(
|
|||
std::move(node), lazy_self->GetDevice()));
|
||||
return result;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
at::Tensor LazyNativeFunctions::_unsafe_view(const at::Tensor &self,
|
||||
at::IntArrayRef size) {
|
||||
|
|
|
@ -14,8 +14,8 @@ static T GetEnv(const std::string &name, const T &default_value = T(0)) {
|
|||
return T(std::atoi(env));
|
||||
}
|
||||
|
||||
static std::string GetEnvString(const std::string &name,
|
||||
const std::string &default_value) {
|
||||
[[maybe_unused]] static std::string
|
||||
GetEnvString(const std::string &name, const std::string &default_value) {
|
||||
const char *env = std::getenv(name.c_str());
|
||||
if (!env) {
|
||||
return default_value;
|
||||
|
@ -23,7 +23,7 @@ static std::string GetEnvString(const std::string &name,
|
|||
return std::string(env);
|
||||
}
|
||||
|
||||
static bool GetEnvBool(const char *name, bool defval) {
|
||||
[[maybe_unused]] static bool GetEnvBool(const char *name, bool defval) {
|
||||
const char *env = std::getenv(name);
|
||||
if (env == nullptr) {
|
||||
return defval;
|
||||
|
|
Loading…
Reference in New Issue