[sparse] propagate sparsity properly when decomposing torch operations. (#3318)

pull/3349/head
Peiming Liu 2024-05-15 10:09:27 -07:00 committed by GitHub
parent ba32b9cee7
commit ccb772cd0f
11 changed files with 146 additions and 13 deletions

View File

@@ -53,6 +53,9 @@ public:
   /// convenient API.
   Type getOptionalDtype() const;

+  /// Get the raw optional sparse tensor encoding.
+  Attribute getOptionalSparsity() const;
+
   /// Return true if this type has a list of sizes.
   bool hasSizes() const { return getOptionalSizes().has_value(); }
@@ -93,6 +96,10 @@ public:
   Type getWithSizesAndDtype(std::optional<ArrayRef<int64_t>> optionalSizes,
                             Type optionalDtype) const;

+  Type getWithSizesAndDtypeAndSparsity(
+      std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
+      Attribute optionalSparsity) const;
+
   /// Return a type with the same shape and dtype as this one, but with
   /// value semantics.
   ValueTensorType getWithValueSemantics() const;
@@ -129,23 +136,31 @@ namespace Torch {
 inline std::optional<ArrayRef<int64_t>>
 BaseTensorType::getOptionalSizes() const {
-  if (auto tensor = dyn_cast<NonValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
     return tensor.getOptionalSizes();
-  if (auto tensor = dyn_cast<ValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
     return tensor.getOptionalSizes();
   llvm_unreachable("not a BaseTensorType!");
 }

 inline Type BaseTensorType::getOptionalDtype() const {
-  if (auto tensor = dyn_cast<NonValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
     return tensor.getOptionalDtype();
-  if (auto tensor = dyn_cast<ValueTensorType>())
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
     return tensor.getOptionalDtype();
   llvm_unreachable("not a BaseTensorType!");
 }

+inline Attribute BaseTensorType::getOptionalSparsity() const {
+  if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this))
+    return tensor.getOptionalSparsity();
+  if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this))
+    return tensor.getOptionalSparsity();
+  llvm_unreachable("not a BaseTensorType!");
+}
+
 inline bool BaseTensorType::classof(Type type) {
-  return type.isa<NonValueTensorType, ValueTensorType>();
+  return mlir::isa<NonValueTensorType, ValueTensorType>(type);
 }

 } // namespace Torch
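
As a usage sketch (illustrative, not part of the patch): decomposition code that rebuilds a tensor type can now forward the encoding alongside sizes and dtype; `input` below is a hypothetical Value whose type is a sized BaseTensorType.

// Rebuild the type with one extra unit dim, forwarding dtype and encoding.
auto srcType = mlir::cast<BaseTensorType>(input.getType());
SmallVector<int64_t> sizes = llvm::to_vector(srcType.getSizes());
sizes.push_back(1);
Type newType = srcType.getWithSizesAndDtypeAndSparsity(
    llvm::ArrayRef(sizes), srcType.getOptionalDtype(),
    srcType.getOptionalSparsity()); // previously the encoding was dropped here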

View File

@@ -0,0 +1,28 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
+#define TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace torch {
+namespace Torch {
+
+// Create a new SparseTensorEncodingAttr based on the provided `attr`, but with
+// a new dense level inserted at `dim`.
+FailureOr<Attribute> getSparsityWithDenseLTAtDim(Attribute attr, Value dim);
+
+} // namespace Torch
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H
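
The intended call pattern, as used later in this patch's unsqueezeTensor change (a sketch; `inputType` and `dim` come from the surrounding rewrite pattern):

// A null input attribute (dense tensor) legitimately yields a null result;
// failure() means a non-constant dim or an unsupported encoding.
FailureOr<Attribute> enc =
    getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
if (failed(enc))
  return failure();
Attribute newSparsity = enc.value(); // may be null for dense inputs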

View File

@@ -1880,9 +1880,11 @@ public:
             op, adaptor, rewriter, resultShape, offsets, strides))) {
       return failure();
     }

+    SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
+    auto sliceType = RankedTensorType::get(
+        dynShape, resultType.getElementType(), resultType.getEncoding());
     Value result = rewriter.create<tensor::ExtractSliceOp>(
-        loc, input, offsets, resultShape, strides);
+        loc, sliceType, input, offsets, resultShape, strides);
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);

     return success();
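
An annotated reading of the hunk above (my comments; the presumed rationale is that the type extract_slice would otherwise infer carries no sparse encoding):

// An explicit all-dynamic result type keeps resultType.getEncoding() on the
// extract_slice instead of letting type inference produce a dense tensor.
SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic);
auto sliceType = RankedTensorType::get(
    dynShape, resultType.getElementType(), resultType.getEncoding());
Value slice = rewriter.create<tensor::ExtractSliceOp>(
    loc, sliceType, input, offsets, resultShape, strides);
// The trailing cast then only refines the dynamic shape to the static result
// type; source and target share the encoding, so the cast stays legal.
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, slice);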

View File

@@ -235,6 +235,18 @@ Type BaseTensorType::getWithSizesAndDtype(
   llvm_unreachable("not a BaseTensorType!");
 }

+Type BaseTensorType::getWithSizesAndDtypeAndSparsity(
+    std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype,
+    Attribute optionalSparsity) const {
+  if (mlir::isa<NonValueTensorType>(*this))
+    return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype,
+                                   optionalSparsity);
+  if (mlir::isa<ValueTensorType>(*this))
+    return ValueTensorType::get(getContext(), optionalSizes, optionalDtype,
+                                optionalSparsity);
+  llvm_unreachable("not a BaseTensorType!");
+}
+
 ValueTensorType BaseTensorType::getWithValueSemantics() const {
   if (auto tensor = dyn_cast<NonValueTensorType>())
     return tensor.getWithValueSemantics();

View File

@@ -71,10 +71,10 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
     }
   }

-  Type resultType = tensorType.getWithSizesAndDtype(
+  Type resultType = tensorType.getWithSizesAndDtypeAndSparsity(
       !tensorType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
                              : llvm::ArrayRef(sizes),
-      tensorType.getOptionalDtype());
+      tensorType.getOptionalDtype(), tensorType.getOptionalSparsity());
   return resultType;
 }

View File

@@ -1,5 +1,6 @@
 add_mlir_dialect_library(TorchMLIRTorchUtils
   Utils.cpp
+  SparsityUtils.cpp
   TorchUpstream.cpp

   ADDITIONAL_HEADER_DIRS

View File

@@ -0,0 +1,55 @@
+//===----------------------------------------------------------------------===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cstdint>
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+FailureOr<Attribute> Torch::getSparsityWithDenseLTAtDim(Attribute attr,
+                                                        Value dim) {
+  if (!attr)
+    return Attribute();
+
+  auto enc = cast<SparseTensorEncodingAttr>(attr);
+  int64_t dimInt = 0;
+  int64_t rank = enc.getDimRank() + 1;
+  if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
+    dimInt = toPositiveDim(dimInt, rank);
+    if (!isValidDim(dimInt, rank)) {
+      return failure();
+    }
+    if (!enc.isIdentity()) {
+      // TODO: support block sparsity and permutation (e.g., CSC).
+      return failure();
+    }
+    auto denseLT = *LevelType::buildLvlType(LevelFormat::Dense, true, true);
+    SmallVector<LevelType> lvlTps = llvm::to_vector(enc.getLvlTypes());
+    lvlTps.insert(lvlTps.begin() + dimInt, denseLT);
+    auto dim2Lvl = AffineMap::getMultiDimIdentityMap(rank, attr.getContext());
+    return SparseTensorEncodingAttr::get(
+        enc.getContext(), lvlTps, dim2Lvl, AffineMap(), enc.getPosWidth(),
+        enc.getCrdWidth(), enc.getExplicitVal(), enc.getImplicitVal());
+  }
+  // We do not know how to handle a dynamic dimension.
+  return failure();
+}
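
To make the effect concrete, a hedged before/after illustration (the example encodings are mine, not from the patch): unsqueezing a CSR matrix at dim 0.

// Before (2-D CSR input):
//   #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
// After inserting a dense level at dim 0 (rank 3, identity map rebuilt):
//   #sparse_tensor.encoding<{
//     map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed) }>
// Pos/crd widths and explicit/implicit values carry over unchanged; a
// non-identity map (CSC, block sparsity) or a non-constant dim yields failure().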

View File

@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
+#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"

 using namespace mlir;
 using namespace mlir::torch;
@@ -318,6 +319,11 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
   if (!inputType.hasSizes()) {
     return rewriter.notifyMatchFailure(op, "input tensor must have size");
   }
+  FailureOr<Attribute> enc =
+      getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
+  if (failed(enc)) {
+    return failure();
+  }

   SmallVector<int64_t> unsqueezedShape;
   ArrayRef<int64_t> inputShape = inputType.getSizes();
@@ -334,8 +340,8 @@
   } else {
     unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
   }
-  Type unsqueezedType = inputType.getWithSizesAndDtype(
-      unsqueezedShape, inputType.getOptionalDtype());
+  Type unsqueezedType = inputType.getWithSizesAndDtypeAndSparsity(
+      unsqueezedShape, inputType.getOptionalDtype(), enc.value());
   Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
       op->getLoc(), unsqueezedType, input, dim);
   return unsqueezed;

View File

@@ -138,8 +138,6 @@ LOWERING_PIPELINE = (
     "builtin.module("
     + ",".join(
         [
-            "func.func(refback-generalize-tensor-pad)",
-            "func.func(refback-generalize-tensor-concat)",
             # Apply some optimizations. It would be great if MLIR had more useful
             # optimizations that worked out of the box here.
             # Note: When measured, this doesn't seem to actually help that much
@@ -157,6 +155,10 @@ LOWERING_PIPELINE = (
             "sparse-storage-specifier-to-llvm",
             # Buffer deallocation pass does not know how to handle realloc.
             "func.func(expand-realloc)",
+            # Generalize pad and concat after the sparse compiler, as they are
+            # handled differently when the operations involve sparse operands.
+            "func.func(refback-generalize-tensor-pad)",
+            "func.func(refback-generalize-tensor-concat)",
             # Bufferize.
             "func.func(scf-bufferize)",
             "func.func(tm-tensor-bufferize)",

View File

@@ -134,6 +134,16 @@ def sparse_export(
         # elif opname == "_to_dense":
         #     # hack (assumes we never really want the to_dense for now)
         #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
+        elif opname == "select" and node.args[0].meta.get("sparsity", None):
+            dim = len(node.meta.get("val").shape)
+            node.meta["sparsity"] = SparsityMeta(
+                torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
+            )
+        elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
+            dim = len(node.meta.get("val").shape)
+            node.meta["sparsity"] = SparsityMeta(
+                torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
+            )

     return prog

View File

@@ -90,6 +90,7 @@ gentbl_cc_library(
 cc_library(
     name = "TorchMLIRTorchDialectUtils",
     srcs = [
+        "lib/Dialect/Torch/Utils/SparsityUtils.cpp",
         "lib/Dialect/Torch/Utils/TorchUpstream.cpp",
         "lib/Dialect/Torch/Utils/Utils.cpp",
     ],
@@ -97,6 +98,7 @@ cc_library(
         "include/torch-mlir/Dialect/Torch/IR/TorchOps.h",
         "include/torch-mlir/Dialect/Torch/IR/TorchTraits.h",
         "include/torch-mlir/Dialect/Torch/IR/TorchTypes.h",
+        "include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h",
         "include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h",
         "include/torch-mlir/Dialect/Torch/Utils/Utils.h",
     ],