mirror of https://github.com/llvm/torch-mlir
[sparse] Propagate sparsity properly when decomposing torch operations. (#3318)
parent ba32b9cee7
commit ccb772cd0f

@@ -53,6 +53,9 @@ public:
   /// convenient API.
   Type getOptionalDtype() const;
 
+  /// Get the raw optional sparse tensor encoding.
+  Attribute getOptionalSparsity() const;
+
   /// Return true if this type has a list of sizes.
   bool hasSizes() const { return getOptionalSizes().has_value(); }
 

@@ -93,6 +96,10 @@ public:
   Type getWithSizesAndDtype(std::optional<ArrayRef<int64_t>> optionalSizes,
                             Type optionalDtype) const;
 
+  Type getWithSizesAndDtypeAndSparsity(
+      std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
+      Attribute optionalSparsity) const;
+
   /// Return a type with the same shape and dtype as this one, but with
   /// value semantics.
   ValueTensorType getWithValueSemantics() const;
 

@@ -129,23 +136,31 @@ namespace Torch {
 
 inline std::optional<ArrayRef<int64_t>>
 BaseTensorType::getOptionalSizes() const {
-  if (auto tensor = dyn_cast<NonValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
     return tensor.getOptionalSizes();
-  if (auto tensor = dyn_cast<ValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
     return tensor.getOptionalSizes();
   llvm_unreachable("not a BaseTensorType!");
 }
 
 inline Type BaseTensorType::getOptionalDtype() const {
-  if (auto tensor = dyn_cast<NonValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
     return tensor.getOptionalDtype();
-  if (auto tensor = dyn_cast<ValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
     return tensor.getOptionalDtype();
   llvm_unreachable("not a BaseTensorType!");
 }
 
+inline Attribute BaseTensorType::getOptionalSparsity() const {
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
+    return tensor.getOptionalSparsity();
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
+    return tensor.getOptionalSparsity();
+  llvm_unreachable("not a BaseTensorType!");
+}
+
 inline bool BaseTensorType::classof(Type type) {
-  return type.isa<NonValueTensorType, ValueTensorType>();
+  return mlir::isa<NonValueTensorType, ValueTensorType>(type);
 }
 
 } // namespace Torch

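Together, the new accessor and builder let a transform rebuild a tensor type without losing the sparse encoding. A minimal sketch of that pattern, using only the API added above (the helper name `withNewSizes` is illustrative, not part of the patch; `computeReductionType` further down does effectively the same thing for reductions):

#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch::Torch;

// Rebuild `tensorType` with new sizes while keeping its dtype and its sparse
// encoding (if any) intact.
static Type withNewSizes(BaseTensorType tensorType, ArrayRef<int64_t> sizes) {
  return tensorType.getWithSizesAndDtypeAndSparsity(
      sizes, tensorType.getOptionalDtype(),
      tensorType.getOptionalSparsity());
}
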
@@ -0,0 +1,28 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+#ifndef TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
+#define TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace torch {
+namespace Torch {
+
+// Create a new SparseTensorEncodingAttr based on the provided `attr`, but with
+// a new dense level inserted at `dim`.
+FailureOr<Attribute> getSparsityWithDenseLTAtDim(Attribute attr, Value dim);
+
+} // namespace Torch
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H

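This declaration is the single entry point the commit adds for sparsity propagation. A sketch of the intended call pattern, mirroring the `unsqueezeTensor` change later in the commit (the wrapper name `rebuildTypeWithPropagatedSparsity` and the `newSizes` parameter are illustrative only):

#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"

using namespace mlir;
using namespace mlir::torch::Torch;

static FailureOr<Type>
rebuildTypeWithPropagatedSparsity(BaseTensorType inputType, Value dim,
                                  ArrayRef<int64_t> newSizes) {
  // Yields the input encoding with a dense level inserted at `dim`, a null
  // Attribute when the input is not sparse, or failure when `dim` is not a
  // constant int or the encoding is unsupported (block sparsity, CSC).
  FailureOr<Attribute> enc =
      getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
  if (failed(enc))
    return failure();
  return inputType.getWithSizesAndDtypeAndSparsity(
      newSizes, inputType.getOptionalDtype(), enc.value());
}
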
@@ -1880,9 +1880,11 @@ public:
             op, adaptor, rewriter, resultShape, offsets, strides))) {
       return failure();
     }
 
+    SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
+    auto sliceType = RankedTensorType::get(
+        dynShape, resultType.getElementType(), resultType.getEncoding());
     Value result = rewriter.create<tensor::ExtractSliceOp>(
-        loc, input, offsets, resultShape, strides);
+        loc, sliceType, input, offsets, resultShape, strides);
 
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
     return success();

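The fix gives the `tensor.extract_slice` an explicit result type with all-dynamic sizes but the original element type and encoding, so a sparse encoding survives the slice; the trailing `tensor.cast` then refines the shape back to `resultType`. The shape-erasing step in isolation (a sketch; `getDynamicSliceType` is a hypothetical name, not in the patch):

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Build a rank-preserving tensor type with all-dynamic sizes but the same
// element type and (possibly sparse) encoding as `resultType`.
static RankedTensorType getDynamicSliceType(RankedTensorType resultType) {
  SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
  return RankedTensorType::get(dynShape, resultType.getElementType(),
                               resultType.getEncoding());
}
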
@@ -235,6 +235,18 @@ Type BaseTensorType::getWithSizesAndDtype(
   llvm_unreachable("not a BaseTensorType!");
 }
 
+Type BaseTensorType::getWithSizesAndDtypeAndSparsity(
+    std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
+    Attribute optionalSparsity) const {
+  if (mlir::isa<NonValueTensorType>(*this))
+    return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype,
+                                   optionalSparsity);
+  if (mlir::isa<ValueTensorType>(*this))
+    return ValueTensorType::get(getContext(), optionalSizes, optionalDtype,
+                                optionalSparsity);
+  llvm_unreachable("not a BaseTensorType!");
+}
+
 ValueTensorType BaseTensorType::getWithValueSemantics() const {
   if (auto tensor = dyn_cast<NonValueTensorType>())
     return tensor.getWithValueSemantics();

@@ -71,10 +71,10 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
     }
   }
 
-  Type resultType = tensorType.getWithSizesAndDtype(
+  Type resultType = tensorType.getWithSizesAndDtypeAndSparsity(
      !tensorType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
                             : llvm::ArrayRef(sizes),
-      tensorType.getOptionalDtype());
+      tensorType.getOptionalDtype(), tensorType.getOptionalSparsity());
   return resultType;
 }
 

@@ -1,5 +1,6 @@
 add_mlir_dialect_library(TorchMLIRTorchUtils
   Utils.cpp
+  SparsityUtils.cpp
   TorchUpstream.cpp
 
   ADDITIONAL_HEADER_DIRS

@@ -0,0 +1,55 @@
+//===----------------------------------------------------------------------===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cstdint>
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+FailureOr<Attribute> Torch::getSparsityWithDenseLTAtDim(Attribute attr,
+                                                        Value dim) {
+  if (!attr)
+    return Attribute();
+
+  auto enc = cast<SparseTensorEncodingAttr>(attr);
+  int64_t dimInt = 0;
+  int64_t rank = enc.getDimRank() + 1;
+  if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
+    dimInt = toPositiveDim(dimInt, rank);
+    if (!isValidDim(dimInt, rank)) {
+      return failure();
+    }
+    if (!enc.isIdentity()) {
+      // TODO: support block sparsity and permutation (CSC).
+      return failure();
+    }
+    auto denseLT = *LevelType::buildLvlType(LevelFormat::Dense, true, true);
+    SmallVector<LevelType> lvlTps = llvm::to_vector(enc.getLvlTypes());
+    lvlTps.insert(lvlTps.begin() + dimInt, denseLT);
+    auto dim2Lvl = AffineMap::getMultiDimIdentityMap(rank, attr.getContext());
+    return SparseTensorEncodingAttr::get(
+        enc.getContext(), lvlTps, dim2Lvl, AffineMap(), enc.getPosWidth(),
+        enc.getCrdWidth(), enc.getExplicitVal(), enc.getImplicitVal());
+  }
+  // Do not know how to handle dynamic dimension.
+  return failure();
+}

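The core of the helper is the level splice: it copies the encoding's level types, inserts one dense level at the (positive) dim, and rebuilds an identity dim-to-level map over the new rank. A standalone sketch of just that splice, assuming a rank-2 CSR-style encoding ([dense, compressed]) and an unsqueeze at dim 0 (the function name is hypothetical):

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir::sparse_tensor;

llvm::SmallVector<LevelType> unsqueezeCsrLevelsAtDimZero() {
  LevelType dense = *LevelType::buildLvlType(LevelFormat::Dense, true, true);
  LevelType compressed =
      *LevelType::buildLvlType(LevelFormat::Compressed, true, true);
  llvm::SmallVector<LevelType> lvlTps = {dense, compressed}; // CSR, rank 2
  lvlTps.insert(lvlTps.begin(), dense);                      // dim == 0
  return lvlTps; // [dense, dense, compressed], rank 3
}

Because the dim-to-level map is rebuilt as a fresh identity map over the new rank, encodings whose map is not the identity (block sparsity, CSC) are rejected above rather than remapped.
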
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"
 
 using namespace mlir;
 using namespace mlir::torch;

@@ -318,6 +319,11 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
   if (!inputType.hasSizes()) {
     return rewriter.notifyMatchFailure(op, "input tensor must have size");
   }
+  FailureOr<Attribute> enc =
+      getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
+  if (failed(enc)) {
+    return failure();
+  }
 
   SmallVector<int64_t> unsqueezedShape;
   ArrayRef<int64_t> inputShape = inputType.getSizes();

@@ -334,8 +340,8 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
   } else {
     unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
   }
-  Type unsqueezedType = inputType.getWithSizesAndDtype(
-      unsqueezedShape, inputType.getOptionalDtype());
+  Type unsqueezedType = inputType.getWithSizesAndDtypeAndSparsity(
+      unsqueezedShape, inputType.getOptionalDtype(), enc.value());
   Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
       op->getLoc(), unsqueezedType, input, dim);
   return unsqueezed;

@@ -138,8 +138,6 @@ LOWERING_PIPELINE = (
     "builtin.module("
     + ",".join(
         [
-            "func.func(refback-generalize-tensor-pad)",
-            "func.func(refback-generalize-tensor-concat)",
             # Apply some optimizations. It would be great if MLIR had more useful
             # optimizations that worked out of the box here.
             # Note: When measured, this doesn't seem to actually help that much

@@ -157,6 +155,10 @@ LOWERING_PIPELINE = (
            "sparse-storage-specifier-to-llvm",
            # Buffer deallocation pass does not know how to handle realloc.
            "func.func(expand-realloc)",
+           # Generalize pad and concat after the sparse compiler, as they are
+           # handled differently when the operations involve sparse operands.
+           "func.func(refback-generalize-tensor-pad)",
+           "func.func(refback-generalize-tensor-concat)",
            # Bufferize.
            "func.func(scf-bufferize)",
            "func.func(tm-tensor-bufferize)",

@@ -134,6 +134,16 @@ def sparse_export(
         # elif opname == "_to_dense":
         #     # hack (assumes we never really want the to_dense for now)
         #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
+        elif opname == "select" and node.args[0].meta.get("sparsity", None):
+            dim = len(node.meta.get("val").shape)
+            node.meta["sparsity"] = SparsityMeta(
+                torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
+            )
+        elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
+            dim = len(node.meta.get("val").shape)
+            node.meta["sparsity"] = SparsityMeta(
+                torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
+            )
     return prog
 
 

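For reference, `SparsityMeta` here is the dataclass defined earlier in this test file; assuming its positional order is (layout, batch_dim, sparse_dim, dense_dim, blocksize, pos_dtype, crd_dtype), the `select` case marks all `dim` result dimensions of the COO result as sparse, while the `stack` case keeps `dim - 1` sparse dimensions plus one dense dimension.
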
@@ -90,6 +90,7 @@ gentbl_cc_library(
 cc_library(
     name = "TorchMLIRTorchDialectUtils",
     srcs = [
+        "lib/Dialect/Torch/Utils/SparsityUtils.cpp",
         "lib/Dialect/Torch/Utils/TorchUpstream.cpp",
         "lib/Dialect/Torch/Utils/Utils.cpp",
     ],

@@ -97,6 +98,7 @@ cc_library(
         "include/torch-mlir/Dialect/Torch/IR/TorchOps.h",
         "include/torch-mlir/Dialect/Torch/IR/TorchTraits.h",
         "include/torch-mlir/Dialect/Torch/IR/TorchTypes.h",
+        "include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h",
         "include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h",
         "include/torch-mlir/Dialect/Torch/Utils/Utils.h",
     ],