[Bazel] Resolve circular dependency and add targets for conversion to MLProgram dialect (#1694)

A circular dependency was introduced in e7edcc62fd. 

Specifically, the `makeShapeLLVMCompatible` and `makeShapeTorchCompatible` utilities were being called from `lib/Dialect/Torch/IR/TorchTypes.cpp` and `lib/Dialect/Torch/IR/TorchOps.cpp`, which are built as part of the `:TorchMLIRTorchDialect` Bazel target. That forced `:TorchMLIRTorchDialect` to depend on `:TorchMLIRConversionUtils`, which in turn already depends on `:TorchMLIRTorchDialect`, creating a circular dependency.
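
To make the cycle concrete, here is a minimal sketch of the dependency edges in Bazel terms (target definitions abbreviated for illustration; they are not copied from the actual BUILD file):

```starlark
# Sketch only: srcs/hdrs/deps are abbreviated for illustration.
cc_library(
    name = "TorchMLIRTorchDialect",
    srcs = [
        "lib/Dialect/Torch/IR/TorchOps.cpp",   # calls makeShapeTorchCompatible(...)
        "lib/Dialect/Torch/IR/TorchTypes.cpp", # calls makeShapeLLVMCompatible(...)
    ],
    deps = [":TorchMLIRConversionUtils"],  # new edge required after e7edcc62fd
)

cc_library(
    name = "TorchMLIRConversionUtils",
    srcs = ["lib/Conversion/Utils/Utils.cpp"],
    deps = [":TorchMLIRTorchDialect"],  # pre-existing edge -> cycle
)
```

Bazel rejects such a graph with a cycle-in-dependency-graph error, which is how this surfaced in the Bazel build.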

This commit resolves the cycle by moving those utilities from `lib/Conversion/Utils/Utils.cpp` to `lib/Dialect/Torch/Utils/Utils.cpp`. Please let me know if there's a better way to fix this and I will update the code.
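
For context, the two helpers translate shapes between the Torch dialect's unknown-size sentinel and the builtin `ShapedType::kDynamic` sentinel. Below is a minimal usage sketch, illustrative only and not part of this change; it assumes `Torch::kUnknownSize` is visible via the Torch dialect headers named in the includes.

```cpp
// Illustrative sketch of the moved shape utilities; not part of the diff.
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" // assumed location of Torch::kUnknownSize
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"   // new home of the shape helpers

#include "llvm/ADT/SmallVector.h"
#include <cassert>

using namespace mlir;
using namespace mlir::torch;

void shapeRoundTripExample() {
  // A Torch-style shape with one unknown dimension.
  SmallVector<int64_t> torchShape = {2, Torch::kUnknownSize, 3};

  // Torch convention -> builtin convention (kUnknownSize becomes ShapedType::kDynamic).
  SmallVector<int64_t> builtinShape = Torch::makeShapeLLVMCompatible(torchShape);

  // Builtin convention -> Torch convention (ShapedType::kDynamic becomes kUnknownSize).
  SmallVector<int64_t> torchShapeAgain = Torch::makeShapeTorchCompatible(builtinShape);

  // The round trip is lossless.
  assert(torchShapeAgain == torchShape);
  (void)torchShapeAgain; // silence unused-variable warnings when asserts are disabled
}
```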

This commit also adds the Bazel targets required to build the new conversion from Torch to the ML Program dialect that was introduced in f416953600.

Bazel build GHA triggered manually to verify: https://github.com/sjain-stanford/torch-mlir/actions/runs/3645944517
commit f8a2592905 (parent a54b334578)
Author: Sambhav Jain
Date:   2022-12-08 09:49:54 -08:00 (committed by GitHub)
9 changed files with 67 additions and 45 deletions

--- a/include/torch-mlir/Conversion/Utils/Utils.h
+++ b/include/torch-mlir/Conversion/Utils/Utils.h
@@ -86,12 +86,6 @@ Value convertScalarToDtype(
     OpBuilder &b, Location loc, Value scalar, Type dtype,
     llvm::Optional<Type> srcOriginalDtype = llvm::None);
-
-// Return the number of elements of a tensor if the shape is static; otherwise,
-// return -1.
-int64_t getNumberOfElements(RankedTensorType inputType);
-SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
-SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
 
 } // namespace Torch
 } // namespace torch
 } // namespace mlir

--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -55,6 +55,12 @@ bool isViewLikeOp(Operation *op);
 Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc,
                                         float value, Type dtype);
+
+// Return the number of elements of a tensor if the shape is static; otherwise,
+// return -1.
+int64_t getNumberOfElements(RankedTensorType inputType);
+SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
+SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
 
 } // namespace Torch
 } // namespace torch
 } // namespace mlir

--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -10,7 +10,6 @@
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
-#include "torch-mlir/Conversion/Utils/Utils.h"
 
 #include "../PassDetail.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"

--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -316,41 +316,6 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
   llvm_unreachable("convertScalarToDtype should handle all the types");
 }
-
-// Return the number of elements of a tensor if the shape is static; otherwise,
-// return -1.
-int64_t getNumberOfElements(RankedTensorType inputType) {
-  if (!inputType.hasStaticShape())
-    return -1;
-  SmallVector<int64_t> inputShape =
-      makeShapeTorchCompatible(inputType.getShape());
-  int64_t numel = 1;
-  for (int64_t i = 0; i < inputType.getRank(); i++)
-    numel *= inputShape[i];
-  return numel;
-}
-
-SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> updatedShape(shape);
-  int64_t kDynamic = ShapedType::kDynamic;
-  for (unsigned i = 0; i < shape.size(); i++) {
-    assert(shape[i] >= 0 || shape[i] == kUnknownSize);
-    if (shape[i] == kUnknownSize)
-      updatedShape[i] = kDynamic;
-  }
-  return updatedShape;
-}
-
-SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> updatedShape(shape);
-  int64_t kDynamic = ShapedType::kDynamic;
-  for (unsigned i = 0; i < shape.size(); i++) {
-    assert(shape[i] >= 0 || shape[i] == kDynamic);
-    if (shape[i] == kDynamic)
-      updatedShape[i] = kUnknownSize;
-  }
-  return updatedShape;
-}
 
 } // namespace Torch
 } // namespace torch
 } // namespace mlir

--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -8,7 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
-#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -9,7 +9,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 
 #include "mlir/IR/DialectImplementation.h"
-#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "llvm/ADT/STLExtras.h"

--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -178,3 +178,38 @@ Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
   llvm::report_fatal_error(
       "unhandled type for getConstantWithGivenDtypeAndValue");
 }
+
+// Return the number of elements of a tensor if the shape is static; otherwise,
+// return -1.
+int64_t Torch::getNumberOfElements(RankedTensorType inputType) {
+  if (!inputType.hasStaticShape())
+    return -1;
+  SmallVector<int64_t> inputShape =
+      makeShapeTorchCompatible(inputType.getShape());
+  int64_t numel = 1;
+  for (int64_t i = 0; i < inputType.getRank(); i++)
+    numel *= inputShape[i];
+  return numel;
+}
+
+SmallVector<int64_t> Torch::makeShapeLLVMCompatible(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> updatedShape(shape);
+  int64_t kDynamic = ShapedType::kDynamic;
+  for (unsigned i = 0; i < shape.size(); i++) {
+    assert(shape[i] >= 0 || shape[i] == kUnknownSize);
+    if (shape[i] == kUnknownSize)
+      updatedShape[i] = kDynamic;
+  }
+  return updatedShape;
+}
+
+SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> updatedShape(shape);
+  int64_t kDynamic = ShapedType::kDynamic;
+  for (unsigned i = 0; i < shape.size(); i++) {
+    assert(shape[i] >= 0 || shape[i] == kDynamic);
+    if (shape[i] == kDynamic)
+      updatedShape[i] = kUnknownSize;
+  }
+  return updatedShape;
+}

--- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
+++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
@@ -8,7 +8,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-#include "torch-mlir/Conversion/Utils/Utils.h"
 
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"

--- a/utils/bazel/torch-mlir-overlay/BUILD.bazel
+++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel
@@ -397,6 +397,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TorchMLIRTorchConversionToMLProgram",
+    srcs = glob([
+        "lib/Conversion/*.h",
+        "lib/Conversion/TorchConversionToMLProgram/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/torch-mlir/Conversion/TorchConversionToMLProgram/*.h",
+    ]),
+    strip_include_prefix = "include",
+    deps = [
+        ":TorchMLIRConversionPassesIncGen",
+        ":TorchMLIRConversionUtils",
+        ":TorchMLIRTorchBackendTypeConversion",
+        ":TorchMLIRTorchConversionDialect",
+        "@llvm-project//mlir:Dialect",
+        "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:MLProgramDialect",
+    ],
+)
+
 cc_library(
     name = "TorchMLIRTorchToTMTensor",
     srcs = glob([
@@ -450,6 +471,7 @@ cc_library(
         ":TorchMLIRTorchToSCF",
         ":TorchMLIRTorchToTMTensor",
        ":TorchMLIRTorchToTosa",
+        ":TorchMLIRTorchConversionToMLProgram",
     ],
 )
 
@@ -473,6 +495,7 @@ cc_library(
         ":TorchMLIRTorchToSCF",
         ":TorchMLIRTorchToTMTensor",
         ":TorchMLIRTorchToTosa",
+        ":TorchMLIRTorchConversionToMLProgram",
         "@llvm-project//mlir:ConversionPasses",
         "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:LinalgDialect",
@@ -758,6 +781,7 @@ cc_library(
         ":TorchMLIRTorchConversionDialect",
         "@llvm-project//mlir:ArithTransforms",
         "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:MLProgramDialect",
        "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:MathTransforms",
         "@llvm-project//mlir:MemRefDialect",