mirror of https://github.com/llvm/torch-mlir
[Bazel] Resolve circular dependency and add targets for conversion to MLProgram dialect (#1694)
A circular dependency was introduced inpull/1699/heade7edcc62fd
. Specifically, the `makeShapeLLVMCompatible` and `makeShapeTorchCompatible` utilities were being called from `lib/Dialect/Torch/IR/TorchTypes.cpp` and `lib/Dialect/Torch/IR/TorchOps.cpp` defined under the `:TorchMLIRTorchDialect` bazel target, leading it to take a dependency on `:TorchMLIRConversionUtils` which already depends on `:TorchMLIRTorchDialect`, hence creating a circular dependency. This commit resolves the same by moving said utilities from `lib/Conversion/Utils/Utils.cpp` to `lib/Dialect/Torch/Utils/Utils.cpp`. Please LMK if there's a better way to fix this and I will update the code. This commit also adds the required targets to support building the new conversions from Torch to ML Program dialect that was introduced inf416953600
. Bazel build GHA triggered manually to verify: https://github.com/sjain-stanford/torch-mlir/actions/runs/3645944517
parent
a54b334578
commit
f8a2592905
|
@ -86,12 +86,6 @@ Value convertScalarToDtype(
|
||||||
OpBuilder &b, Location loc, Value scalar, Type dtype,
|
OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
llvm::Optional<Type> srcOriginalDtype = llvm::None);
|
llvm::Optional<Type> srcOriginalDtype = llvm::None);
|
||||||
|
|
||||||
// Return the number of elements of a tensor if the shape is static; otherwise,
|
|
||||||
// return -1.
|
|
||||||
int64_t getNumberOfElements(RankedTensorType inputType);
|
|
||||||
|
|
||||||
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
|
|
||||||
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -55,6 +55,12 @@ bool isViewLikeOp(Operation *op);
|
||||||
Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
|
Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
|
||||||
float value, Type dtype);
|
float value, Type dtype);
|
||||||
|
|
||||||
|
// Return the number of elements of a tensor if the shape is static; otherwise,
|
||||||
|
// return -1.
|
||||||
|
int64_t getNumberOfElements(RankedTensorType inputType);
|
||||||
|
|
||||||
|
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
|
||||||
|
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -10,7 +10,6 @@
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
|
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
|
|
@ -316,41 +316,6 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
||||||
llvm_unreachable("convertScalarToDtype should handle all the types");
|
llvm_unreachable("convertScalarToDtype should handle all the types");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the number of elements of a tensor if the shape is static; otherwise,
|
|
||||||
// return -1.
|
|
||||||
int64_t getNumberOfElements(RankedTensorType inputType) {
|
|
||||||
if (!inputType.hasStaticShape())
|
|
||||||
return -1;
|
|
||||||
SmallVector<int64_t> inputShape =
|
|
||||||
makeShapeTorchCompatible(inputType.getShape());
|
|
||||||
int64_t numel = 1;
|
|
||||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
|
||||||
numel *= inputShape[i];
|
|
||||||
return numel;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape) {
|
|
||||||
SmallVector<int64_t> updatedShape(shape);
|
|
||||||
int64_t kDynamic = ShapedType::kDynamic;
|
|
||||||
for (unsigned i = 0; i < shape.size(); i++) {
|
|
||||||
assert(shape[i] >= 0 || shape[i] == kUnknownSize);
|
|
||||||
if (shape[i] == kUnknownSize)
|
|
||||||
updatedShape[i] = kDynamic;
|
|
||||||
}
|
|
||||||
return updatedShape;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
|
|
||||||
SmallVector<int64_t> updatedShape(shape);
|
|
||||||
int64_t kDynamic = ShapedType::kDynamic;
|
|
||||||
for (unsigned i = 0; i < shape.size(); i++) {
|
|
||||||
assert(shape[i] >= 0 || shape[i] == kDynamic);
|
|
||||||
if (shape[i] == kDynamic)
|
|
||||||
updatedShape[i] = kUnknownSize;
|
|
||||||
}
|
|
||||||
return updatedShape;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
|
@ -178,3 +178,38 @@ Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
|
||||||
llvm::report_fatal_error(
|
llvm::report_fatal_error(
|
||||||
"unhandled type for getConstantWithGivenDtypeAndValue");
|
"unhandled type for getConstantWithGivenDtypeAndValue");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the number of elements of a tensor if the shape is static; otherwise,
|
||||||
|
// return -1.
|
||||||
|
int64_t Torch::getNumberOfElements(RankedTensorType inputType) {
|
||||||
|
if (!inputType.hasStaticShape())
|
||||||
|
return -1;
|
||||||
|
SmallVector<int64_t> inputShape =
|
||||||
|
makeShapeTorchCompatible(inputType.getShape());
|
||||||
|
int64_t numel = 1;
|
||||||
|
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||||
|
numel *= inputShape[i];
|
||||||
|
return numel;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> Torch::makeShapeLLVMCompatible(ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t> updatedShape(shape);
|
||||||
|
int64_t kDynamic = ShapedType::kDynamic;
|
||||||
|
for (unsigned i = 0; i < shape.size(); i++) {
|
||||||
|
assert(shape[i] >= 0 || shape[i] == kUnknownSize);
|
||||||
|
if (shape[i] == kUnknownSize)
|
||||||
|
updatedShape[i] = kDynamic;
|
||||||
|
}
|
||||||
|
return updatedShape;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t> updatedShape(shape);
|
||||||
|
int64_t kDynamic = ShapedType::kDynamic;
|
||||||
|
for (unsigned i = 0; i < shape.size(); i++) {
|
||||||
|
assert(shape[i] >= 0 || shape[i] == kDynamic);
|
||||||
|
if (shape[i] == kDynamic)
|
||||||
|
updatedShape[i] = kUnknownSize;
|
||||||
|
}
|
||||||
|
return updatedShape;
|
||||||
|
}
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
||||||
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
|
|
@ -397,6 +397,27 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "TorchMLIRTorchConversionToMLProgram",
|
||||||
|
srcs = glob([
|
||||||
|
"lib/Conversion/*.h",
|
||||||
|
"lib/Conversion/TorchConversionToMLProgram/*.cpp",
|
||||||
|
]),
|
||||||
|
hdrs = glob([
|
||||||
|
"include/torch-mlir/Conversion/TorchConversionToMLProgram/*.h",
|
||||||
|
]),
|
||||||
|
strip_include_prefix = "include",
|
||||||
|
deps = [
|
||||||
|
":TorchMLIRConversionPassesIncGen",
|
||||||
|
":TorchMLIRConversionUtils",
|
||||||
|
":TorchMLIRTorchBackendTypeConversion",
|
||||||
|
":TorchMLIRTorchConversionDialect",
|
||||||
|
"@llvm-project//mlir:Dialect",
|
||||||
|
"@llvm-project//mlir:LinalgDialect",
|
||||||
|
"@llvm-project//mlir:MLProgramDialect",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "TorchMLIRTorchToTMTensor",
|
name = "TorchMLIRTorchToTMTensor",
|
||||||
srcs = glob([
|
srcs = glob([
|
||||||
|
@ -450,6 +471,7 @@ cc_library(
|
||||||
":TorchMLIRTorchToSCF",
|
":TorchMLIRTorchToSCF",
|
||||||
":TorchMLIRTorchToTMTensor",
|
":TorchMLIRTorchToTMTensor",
|
||||||
":TorchMLIRTorchToTosa",
|
":TorchMLIRTorchToTosa",
|
||||||
|
":TorchMLIRTorchConversionToMLProgram",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -473,6 +495,7 @@ cc_library(
|
||||||
":TorchMLIRTorchToSCF",
|
":TorchMLIRTorchToSCF",
|
||||||
":TorchMLIRTorchToTMTensor",
|
":TorchMLIRTorchToTMTensor",
|
||||||
":TorchMLIRTorchToTosa",
|
":TorchMLIRTorchToTosa",
|
||||||
|
":TorchMLIRTorchConversionToMLProgram",
|
||||||
"@llvm-project//mlir:ConversionPasses",
|
"@llvm-project//mlir:ConversionPasses",
|
||||||
"@llvm-project//mlir:FuncDialect",
|
"@llvm-project//mlir:FuncDialect",
|
||||||
"@llvm-project//mlir:LinalgDialect",
|
"@llvm-project//mlir:LinalgDialect",
|
||||||
|
@ -758,6 +781,7 @@ cc_library(
|
||||||
":TorchMLIRTorchConversionDialect",
|
":TorchMLIRTorchConversionDialect",
|
||||||
"@llvm-project//mlir:ArithTransforms",
|
"@llvm-project//mlir:ArithTransforms",
|
||||||
"@llvm-project//mlir:LinalgDialect",
|
"@llvm-project//mlir:LinalgDialect",
|
||||||
|
"@llvm-project//mlir:MLProgramDialect",
|
||||||
"@llvm-project//mlir:LinalgTransforms",
|
"@llvm-project//mlir:LinalgTransforms",
|
||||||
"@llvm-project//mlir:MathTransforms",
|
"@llvm-project//mlir:MathTransforms",
|
||||||
"@llvm-project//mlir:MemRefDialect",
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
|
|
Loading…
Reference in New Issue