mirror of https://github.com/llvm/torch-mlir
build: update llvm tag to bebc9695 (#1415)
Summary of changes: - Renamed OptionalArrayRefParameter since the name conflicts with an upstream symbol that has a different meaning (https://reviews.llvm.org/D133819) - Removed extraneous dependency between TorchMLIRTorchToMhlo and ChloOps, since the existing dependency on MhloDialect is sufficient - Fixed code to prevent warnings related to comparisons between signed and unsigned valuespull/1417/head
parent
3e27aa2be3
commit
a60acf272d
|
@ -1 +1 @@
|
||||||
Subproject commit 458598ccc50c5118107f05d60f3d043772a91f26
|
Subproject commit bebc96956b76bdbc36f1d82a788c810e5b12e2c5
|
|
@ -1 +1 @@
|
||||||
Subproject commit cd9da150e729fd046109e7962e5f63f5fe067a3b
|
Subproject commit 7b0ecf7827e3fc07d2af90e147bcedc165bc78ac
|
|
@ -45,7 +45,7 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// For standard ArrayRefs, which require allocation.
|
// For standard ArrayRefs, which require allocation.
|
||||||
class OptionalArrayRefParameter<string arrayOf, string desc = ""> :
|
class OptionalArrayRefTorchParameter<string arrayOf, string desc = ""> :
|
||||||
AttrOrTypeParameter<
|
AttrOrTypeParameter<
|
||||||
"::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> {
|
"::llvm::Optional<::llvm::ArrayRef<" # arrayOf # ">>", desc> {
|
||||||
let allocator = [{
|
let allocator = [{
|
||||||
|
@ -146,7 +146,7 @@ class AnyTorchTensorType<string name, string typeMnemonic>
|
||||||
- `getElementType()` -> `getDtype()` (but be sure that `hasDtype()` though).
|
- `getElementType()` -> `getDtype()` (but be sure that `hasDtype()` though).
|
||||||
}];
|
}];
|
||||||
let parameters = (ins
|
let parameters = (ins
|
||||||
OptionalArrayRefParameter<"int64_t", "sizes of dimensions">:$optionalSizes,
|
OptionalArrayRefTorchParameter<"int64_t", "sizes of dimensions">:$optionalSizes,
|
||||||
"::mlir::Type":$optionalDtype
|
"::mlir::Type":$optionalDtype
|
||||||
);
|
);
|
||||||
let genVerifyDecl = 1;
|
let genVerifyDecl = 1;
|
||||||
|
|
|
@ -22,7 +22,6 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
||||||
Core
|
Core
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
ChloOps
|
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRPass
|
MLIRPass
|
||||||
MhloDialect
|
MhloDialect
|
||||||
|
|
|
@ -102,13 +102,13 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
|
||||||
}
|
}
|
||||||
SmallVector<int64_t> outShape;
|
SmallVector<int64_t> outShape;
|
||||||
// set batch dims, will skip invalid dimensions
|
// set batch dims, will skip invalid dimensions
|
||||||
for (size_t k = 0; k < lhsShape.size(); ++k) {
|
for (int64_t k = 0; k < static_cast<int64_t>(lhsShape.size()); ++k) {
|
||||||
if (k == lhsResultDim || k == lhsContractingDim)
|
if (k == lhsResultDim || k == lhsContractingDim)
|
||||||
continue;
|
continue;
|
||||||
outShape.push_back(lhsShape[k]);
|
outShape.push_back(lhsShape[k]);
|
||||||
}
|
}
|
||||||
for (size_t k = 0, b = 0; k < rhsShape.size(); ++k) {
|
for (int64_t k = 0, b = 0; k < static_cast<int64_t>(rhsShape.size()); ++k) {
|
||||||
if (b >= outShape.size())
|
if (b >= static_cast<int64_t>(outShape.size()))
|
||||||
break;
|
break;
|
||||||
if (k == rhsResultDim || k == rhsContractingDim)
|
if (k == rhsResultDim || k == rhsContractingDim)
|
||||||
continue;
|
continue;
|
||||||
|
@ -119,10 +119,10 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
|
||||||
}
|
}
|
||||||
|
|
||||||
// set result dimensions
|
// set result dimensions
|
||||||
if (lhsResultDim < lhsShape.size() && lhsResultDim >= 0) {
|
if (lhsResultDim < static_cast<int64_t>(lhsShape.size()) && lhsResultDim >= 0) {
|
||||||
outShape.push_back(lhsShape[lhsResultDim]);
|
outShape.push_back(lhsShape[lhsResultDim]);
|
||||||
}
|
}
|
||||||
if (rhsResultDim < rhsShape.size() && rhsResultDim >= 0) {
|
if (rhsResultDim < static_cast<int64_t>(rhsShape.size()) && rhsResultDim >= 0) {
|
||||||
outShape.push_back(rhsShape[rhsResultDim]);
|
outShape.push_back(rhsShape[rhsResultDim]);
|
||||||
}
|
}
|
||||||
return RankedTensorType::get(outShape, lhsTy.getElementType());
|
return RankedTensorType::get(outShape, lhsTy.getElementType());
|
||||||
|
|
Loading…
Reference in New Issue