[sparse] propagate sparsity properly when decomposing torch operations. (#3318)

Peiming Liu 2024-05-15 10:09:27 -07:00 committed by GitHub
parent ba32b9cee7
commit ccb772cd0f
11 changed files with 146 additions and 13 deletions

View File

@@ -53,6 +53,9 @@ public:
/// convenient API.
Type getOptionalDtype() const;
/// Get the raw optional sparse tensor encoding.
Attribute getOptionalSparsity() const;
/// Return true if this type has a list of sizes.
bool hasSizes() const { return getOptionalSizes().has_value(); }
@@ -93,6 +96,10 @@ public:
Type getWithSizesAndDtype(std::optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) const;
Type getWithSizesAndDtypeAndSparsity(
std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
Attribute optionalSparsity) const;
/// Return a type with the same shape and dtype as this one, but with
/// value semantics.
ValueTensorType getWithValueSemantics() const;
@@ -129,23 +136,31 @@ namespace Torch {
inline std::optional<ArrayRef<int64_t>>
BaseTensorType::getOptionalSizes() const {
if (auto tensor = dyn_cast<NonValueTensorType>())
if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
return tensor.getOptionalSizes();
if (auto tensor = dyn_cast<ValueTensorType>())
if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
return tensor.getOptionalSizes();
llvm_unreachable("not a BaseTensorType!");
}
inline Type BaseTensorType::getOptionalDtype() const {
if (auto tensor = dyn_cast<NonValueTensorType>())
if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
return tensor.getOptionalDtype();
if (auto tensor = dyn_cast<ValueTensorType>())
if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
return tensor.getOptionalDtype();
llvm_unreachable("not a BaseTensorType!");
}
inline Attribute BaseTensorType::getOptionalSparsity() const {
if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
return tensor.getOptionalSparsity();
if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
return tensor.getOptionalSparsity();
llvm_unreachable("not a BaseTensorType!");
}
inline bool BaseTensorType::classof(Type type) {
return type.isa<NonValueTensorType, ValueTensorType>();
return mlir::isa<NonValueTensorType, ValueTensorType>(type);
}
} // namespace Torch
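The new accessors let decomposition code rebuild a tensor type without silently dropping its sparse encoding. A minimal sketch of the intended call pattern, assuming a BaseTensorType value `tensorType` and a caller-computed shape vector `sizes` (both names are placeholders, not part of this patch):

// Sketch only: carry the operand's encoding along when rebuilding the type.
Attribute sparsity = tensorType.getOptionalSparsity();
Type resultType = tensorType.getWithSizesAndDtypeAndSparsity(
    llvm::ArrayRef(sizes), tensorType.getOptionalDtype(), sparsity);
// The old two-argument getWithSizesAndDtype(sizes, dtype) produced an
// equivalent type with the encoding stripped, which is what callers in this
// commit are migrated away from.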

View File

@@ -0,0 +1,28 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#ifndef TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
#define TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace torch {
namespace Torch {
// Create a new SparseTensorEncodingAttr based on the provided `attr`, but with
// a new dense level inserted at `dim`.
FailureOr<Attribute> getSparsityWithDenseLTAtDim(Attribute attr, Value dim);
} // namespace Torch
} // namespace torch
} // namespace mlir
#endif // TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H

View File

@@ -1880,9 +1880,11 @@ public:
op, adaptor, rewriter, resultShape, offsets, strides))) {
return failure();
}
SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
auto sliceType = RankedTensorType::get(
dynShape, resultType.getElementType(), resultType.getEncoding());
Value result = rewriter.create<tensor::ExtractSliceOp>(
loc, input, offsets, resultShape, strides);
loc, sliceType, input, offsets, resultShape, strides);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();

View File

@@ -235,6 +235,18 @@ Type BaseTensorType::getWithSizesAndDtype(
llvm_unreachable("not a BaseTensorType!");
}
Type BaseTensorType::getWithSizesAndDtypeAndSparsity(
std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
Attribute optionalSparsity) const {
if (mlir::isa<NonValueTensorType>(*this))
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype,
optionalSparsity);
if (mlir::isa<ValueTensorType>(*this))
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype,
optionalSparsity);
llvm_unreachable("not a BaseTensorType!");
}
ValueTensorType BaseTensorType::getWithValueSemantics() const {
if (auto tensor = dyn_cast<NonValueTensorType>())
return tensor.getWithValueSemantics();

View File

@@ -71,10 +71,10 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
}
}
Type resultType = tensorType.getWithSizesAndDtype(
Type resultType = tensorType.getWithSizesAndDtypeAndSparsity(
!tensorType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
: llvm::ArrayRef(sizes),
tensorType.getOptionalDtype());
tensorType.getOptionalDtype(), tensorType.getOptionalSparsity());
return resultType;
}

View File

@@ -1,5 +1,6 @@
add_mlir_dialect_library(TorchMLIRTorchUtils
Utils.cpp
SparsityUtils.cpp
TorchUpstream.cpp
ADDITIONAL_HEADER_DIRS

View File

@@ -0,0 +1,55 @@
//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>
using namespace mlir;
using namespace mlir::sparse_tensor;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
FailureOr<Attribute> Torch::getSparsityWithDenseLTAtDim(Attribute attr,
Value dim) {
if (!attr)
return Attribute();
auto enc = cast<SparseTensorEncodingAttr>(attr);
int64_t dimInt = 0;
int64_t rank = enc.getDimRank() + 1;
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, rank);
if (!isValidDim(dimInt, rank)) {
return failure();
}
if (!enc.isIdentity()) {
// TODO: support block sparsity and permutation (CSC).
return failure();
}
auto denseLT = *LevelType::buildLvlType(LevelFormat::Dense, true, true);
SmallVector<LevelType> lvlTps = llvm::to_vector(enc.getLvlTypes());
lvlTps.insert(lvlTps.begin() + dimInt, denseLT);
auto dim2Lvl = AffineMap::getMultiDimIdentityMap(rank, attr.getContext());
return SparseTensorEncodingAttr::get(
enc.getContext(), lvlTps, dim2Lvl, AffineMap(), enc.getPosWidth(),
enc.getCrdWidth(), enc.getExplicitVal(), enc.getImplicitVal());
}
// Do not know how to handle dynamic dimension.
return failure();
}
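On the caller side the helper has three distinct outcomes, and the decomposition code further below relies on all of them: a null `attr` (dense input) succeeds and yields a null Attribute, a dynamic `dim` or a non-identity dim-to-lvl mapping yields failure(), and otherwise a new encoding with one extra dense level is returned. A hedged usage sketch, with `inputType`, `dim`, and `newShape` standing in for whatever the caller has at hand:

// Sketch only; variable names are illustrative.
FailureOr<Attribute> enc =
    getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
if (failed(enc))
  return failure(); // bail out rather than fabricate a wrong encoding
Type newType = inputType.getWithSizesAndDtypeAndSparsity(
    newShape, inputType.getOptionalDtype(), enc.value());

For instance, a rank-2 CSR operand (levels [dense, compressed] with an identity map) and a constant dim of 0 would come back as a rank-3 encoding with levels [dense, dense, compressed] and an identity dim-to-lvl map.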

View File

@@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"
using namespace mlir;
using namespace mlir::torch;
@@ -318,6 +319,11 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
if (!inputType.hasSizes()) {
return rewriter.notifyMatchFailure(op, "input tensor must have size");
}
FailureOr<Attribute> enc =
getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
if (failed(enc)) {
return failure();
}
SmallVector<int64_t> unsqueezedShape;
ArrayRef<int64_t> inputShape = inputType.getSizes();
@@ -334,8 +340,8 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
} else {
unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
}
Type unsqueezedType = inputType.getWithSizesAndDtype(
unsqueezedShape, inputType.getOptionalDtype());
Type unsqueezedType = inputType.getWithSizesAndDtypeAndSparsity(
unsqueezedShape, inputType.getOptionalDtype(), enc.value());
Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
op->getLoc(), unsqueezedType, input, dim);
return unsqueezed;

View File

@@ -138,8 +138,6 @@ LOWERING_PIPELINE = (
"builtin.module("
+ ",".join(
[
"func.func(refback-generalize-tensor-pad)",
"func.func(refback-generalize-tensor-concat)",
# Apply some optimizations. It would be great if MLIR had more useful
# optimizations that worked out of the box here.
# Note: When measured, this doesn't seem to actually help that much
@@ -157,6 +155,10 @@ LOWERING_PIPELINE = (
"sparse-storage-specifier-to-llvm",
# Buffer deallocation pass does not know how to handle realloc.
"func.func(expand-realloc)",
# Generalize pad and concat after the sparse compiler, as they are handled
# differently when the operations involve sparse operands.
"func.func(refback-generalize-tensor-pad)",
"func.func(refback-generalize-tensor-concat)",
# Bufferize.
"func.func(scf-bufferize)",
"func.func(tm-tensor-bufferize)",

View File

@@ -134,6 +134,16 @@ def sparse_export(
# elif opname == "_to_dense":
# # hack (assumes we never really want the to_dense for now)
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif opname == "select" and node.args[0].meta.get("sparsity", None):
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
)
return prog

View File

@@ -90,6 +90,7 @@ gentbl_cc_library(
cc_library(
name = "TorchMLIRTorchDialectUtils",
srcs = [
"lib/Dialect/Torch/Utils/SparsityUtils.cpp",
"lib/Dialect/Torch/Utils/TorchUpstream.cpp",
"lib/Dialect/Torch/Utils/Utils.cpp",
],
@@ -97,6 +98,7 @@ cc_library(
"include/torch-mlir/Dialect/Torch/IR/TorchOps.h",
"include/torch-mlir/Dialect/Torch/IR/TorchTraits.h",
"include/torch-mlir/Dialect/Torch/IR/TorchTypes.h",
"include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h",
"include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h",
"include/torch-mlir/Dialect/Torch/Utils/Utils.h",
],