mirror of https://github.com/llvm/torch-mlir
[Bazel] Resolve circular dependency and add targets for conversion to MLProgram dialect (#1694)
A circular dependency was introduced inpull/1699/heade7edcc62fd
. Specifically, the `makeShapeLLVMCompatible` and `makeShapeTorchCompatible` utilities were being called from `lib/Dialect/Torch/IR/TorchTypes.cpp` and `lib/Dialect/Torch/IR/TorchOps.cpp` defined under the `:TorchMLIRTorchDialect` bazel target, leading it to take a dependency on `:TorchMLIRConversionUtils` which already depends on `:TorchMLIRTorchDialect`, hence creating a circular dependency. This commit resolves the same by moving said utilities from `lib/Conversion/Utils/Utils.cpp` to `lib/Dialect/Torch/Utils/Utils.cpp`. Please LMK if there's a better way to fix this and I will update the code. This commit also adds the required targets to support building the new conversions from Torch to ML Program dialect that was introduced inf416953600
. Bazel build GHA triggered manually to verify: https://github.com/sjain-stanford/torch-mlir/actions/runs/3645944517
parent
a54b334578
commit
f8a2592905
|
@ -86,12 +86,6 @@ Value convertScalarToDtype(
|
|||
OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||
llvm::Optional<Type> srcOriginalDtype = llvm::None);
|
||||
|
||||
// Return the number of elements of a tensor if the shape is static; otherwise,
|
||||
// return -1.
|
||||
int64_t getNumberOfElements(RankedTensorType inputType);
|
||||
|
||||
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
|
||||
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -55,6 +55,12 @@ bool isViewLikeOp(Operation *op);
|
|||
Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
|
||||
float value, Type dtype);
|
||||
|
||||
// Return the number of elements of a tensor if the shape is static; otherwise,
|
||||
// return -1.
|
||||
int64_t getNumberOfElements(RankedTensorType inputType);
|
||||
|
||||
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
|
||||
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
|
||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
|
|
|
@ -316,41 +316,6 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||
}
|
||||
|
||||
// Return the number of elements of a tensor if the shape is static; otherwise,
|
||||
// return -1.
|
||||
int64_t getNumberOfElements(RankedTensorType inputType) {
|
||||
if (!inputType.hasStaticShape())
|
||||
return -1;
|
||||
SmallVector<int64_t> inputShape =
|
||||
makeShapeTorchCompatible(inputType.getShape());
|
||||
int64_t numel = 1;
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||
numel *= inputShape[i];
|
||||
return numel;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> updatedShape(shape);
|
||||
int64_t kDynamic = ShapedType::kDynamic;
|
||||
for (unsigned i = 0; i < shape.size(); i++) {
|
||||
assert(shape[i] >= 0 || shape[i] == kUnknownSize);
|
||||
if (shape[i] == kUnknownSize)
|
||||
updatedShape[i] = kDynamic;
|
||||
}
|
||||
return updatedShape;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> updatedShape(shape);
|
||||
int64_t kDynamic = ShapedType::kDynamic;
|
||||
for (unsigned i = 0; i < shape.size(); i++) {
|
||||
assert(shape[i] >= 0 || shape[i] == kDynamic);
|
||||
if (shape[i] == kDynamic)
|
||||
updatedShape[i] = kUnknownSize;
|
||||
}
|
||||
return updatedShape;
|
||||
}
|
||||
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
|
|
@ -178,3 +178,38 @@ Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
|
|||
llvm::report_fatal_error(
|
||||
"unhandled type for getConstantWithGivenDtypeAndValue");
|
||||
}
|
||||
|
||||
// Return the number of elements of a tensor if the shape is static; otherwise,
|
||||
// return -1.
|
||||
int64_t Torch::getNumberOfElements(RankedTensorType inputType) {
|
||||
if (!inputType.hasStaticShape())
|
||||
return -1;
|
||||
SmallVector<int64_t> inputShape =
|
||||
makeShapeTorchCompatible(inputType.getShape());
|
||||
int64_t numel = 1;
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||
numel *= inputShape[i];
|
||||
return numel;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> Torch::makeShapeLLVMCompatible(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> updatedShape(shape);
|
||||
int64_t kDynamic = ShapedType::kDynamic;
|
||||
for (unsigned i = 0; i < shape.size(); i++) {
|
||||
assert(shape[i] >= 0 || shape[i] == kUnknownSize);
|
||||
if (shape[i] == kUnknownSize)
|
||||
updatedShape[i] = kDynamic;
|
||||
}
|
||||
return updatedShape;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> updatedShape(shape);
|
||||
int64_t kDynamic = ShapedType::kDynamic;
|
||||
for (unsigned i = 0; i < shape.size(); i++) {
|
||||
assert(shape[i] >= 0 || shape[i] == kDynamic);
|
||||
if (shape[i] == kDynamic)
|
||||
updatedShape[i] = kUnknownSize;
|
||||
}
|
||||
return updatedShape;
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
|
|
@ -397,6 +397,27 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "TorchMLIRTorchConversionToMLProgram",
|
||||
srcs = glob([
|
||||
"lib/Conversion/*.h",
|
||||
"lib/Conversion/TorchConversionToMLProgram/*.cpp",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"include/torch-mlir/Conversion/TorchConversionToMLProgram/*.h",
|
||||
]),
|
||||
strip_include_prefix = "include",
|
||||
deps = [
|
||||
":TorchMLIRConversionPassesIncGen",
|
||||
":TorchMLIRConversionUtils",
|
||||
":TorchMLIRTorchBackendTypeConversion",
|
||||
":TorchMLIRTorchConversionDialect",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:LinalgDialect",
|
||||
"@llvm-project//mlir:MLProgramDialect",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "TorchMLIRTorchToTMTensor",
|
||||
srcs = glob([
|
||||
|
@ -450,6 +471,7 @@ cc_library(
|
|||
":TorchMLIRTorchToSCF",
|
||||
":TorchMLIRTorchToTMTensor",
|
||||
":TorchMLIRTorchToTosa",
|
||||
":TorchMLIRTorchConversionToMLProgram",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -473,6 +495,7 @@ cc_library(
|
|||
":TorchMLIRTorchToSCF",
|
||||
":TorchMLIRTorchToTMTensor",
|
||||
":TorchMLIRTorchToTosa",
|
||||
":TorchMLIRTorchConversionToMLProgram",
|
||||
"@llvm-project//mlir:ConversionPasses",
|
||||
"@llvm-project//mlir:FuncDialect",
|
||||
"@llvm-project//mlir:LinalgDialect",
|
||||
|
@ -758,6 +781,7 @@ cc_library(
|
|||
":TorchMLIRTorchConversionDialect",
|
||||
"@llvm-project//mlir:ArithTransforms",
|
||||
"@llvm-project//mlir:LinalgDialect",
|
||||
"@llvm-project//mlir:MLProgramDialect",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:MathTransforms",
|
||||
"@llvm-project//mlir:MemRefDialect",
|
||||
|
|
Loading…
Reference in New Issue