[torch-mlir][sparse] preserve sparsity during lowering torch to linalg (#2809)

This preserves sparsity at the most obvious places where TORCH tensors are
lowered to MLIR RankedTensorType tensors. Other places are marked for audit
with a TODO. Some initial lowering tests are included.
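
For context, the recurring fix below is a single extra argument: MLIR's
RankedTensorType::get has a three-argument overload whose trailing Attribute
is the encoding slot where a #sparse_tensor.encoding lives. A minimal sketch
of the idea, using only core MLIR (the helper name cloneWithShape is
hypothetical and not part of this patch):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper, for illustration only: rebuild a tensor type with a
// new shape while keeping the element type and, crucially, the encoding.
// On a sparse tensor the encoding carries the #sparse_tensor.encoding; the
// two-argument RankedTensorType::get overload drops it and silently
// densifies the type.
static RankedTensorType cloneWithShape(RankedTensorType srcType,
                                       ArrayRef<int64_t> newShape) {
  return RankedTensorType::get(newShape, srcType.getElementType(),
                               srcType.getEncoding());
}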
Aart Bik 2024-01-26 10:54:59 -08:00 committed by GitHub
parent da7c6d2c16
commit 46a25d7241
6 changed files with 50 additions and 8 deletions


@@ -978,6 +978,7 @@ public:
       return success();
     }
+    // TODO: audit possibility of sparsity on these tensors
     Type adjustedResultType = RankedTensorType::get(
         makeShapeLLVMCompatible(outputShape), resultType.getElementType());
     Type adjustedInputType = RankedTensorType::get(
@@ -1005,6 +1006,7 @@ public:
       intermediateShape.push_back(sum);
     }
+    // TODO: audit possibility of sparsity on these tensors
     Type intermediateResultType =
         RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape),
                               resultType.getElementType());
@@ -1657,6 +1659,7 @@ public:
     auto srcType = src.getType().cast<RankedTensorType>();
     int64_t srcRank = srcType.getRank();
     SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
+    // TODO: audit possibility of sparsity on these tensors
     auto abstractSrcType = RankedTensorType::get(
         makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
     Value abstractSrc =
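
Note that the TODOs added in this file flag intermediate types that are still
built with the two-argument RankedTensorType::get, i.e. without an encoding;
lowering a sparse input through these paths would silently densify it, so each
site is marked for audit rather than changed blindly.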


@@ -206,8 +206,8 @@ namespace {
 //
 // TODO: Find an optimal lowering.
 // current lowering is not optimal for bags of large embeddings.
-// Since it traverses the output tensor multiple times.
-//
+// Since it traverses the output tensor multiple times.
+//
 //
 class ConvertAtenEmbeddingBagPaddingIdxOp


@@ -377,8 +377,8 @@ public:
     // TODO: Improve usage of static shape information.
     SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
                                         ShapedType::kDynamic);
-    auto lhsBroadcastType =
-        RankedTensorType::get(lhsTargetShape, lhsType.getElementType());
+    auto lhsBroadcastType = RankedTensorType::get(
+        lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
             op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType,
             broadcastedLhs))) {
@@ -387,8 +387,8 @@ public:
     }
     SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
                                         ShapedType::kDynamic);
-    auto rhsBroadcastType =
-        RankedTensorType::get(rhsTargetShape, rhsType.getElementType());
+    auto rhsBroadcastType = RankedTensorType::get(
+        rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
             op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType,
             broadcastedRhs))) {
@@ -880,7 +880,7 @@ public:
     if(numSpacialDims != 2)
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only 2D grouped convolution supported");
     // Special depthwise case
     auto inShape = makeShapeTorchCompatible(
         input.getType().cast<RankedTensorType>().getShape());
@@ -894,6 +894,7 @@ public:
         (weightShape[0] == kUnknownSize ? kUnknownSize
                                         : weightShape[0] * weightShape[1]),
         weightShape[2], weightShape[3]};
+    // TODO: audit possibility of sparsity on this tensor
     Type collapsedType = RankedTensorType::get(
         makeShapeLLVMCompatible(collapsedShape), elementType);
     Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(


@@ -87,6 +87,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
     *pad = castIntToIndex(b, loc, *pad);
   Type elementType = input.getType().cast<RankedTensorType>().getElementType();
+  // TODO: audit possibility of sparsity on this tensor
   Type inputType =
       RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
                                 SmallVector<int64_t>(inRank, kUnknownSize))),


@@ -467,7 +467,8 @@ TensorType ValueTensorType::toBuiltinTensor() const {
   Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
   if (!elementType)
     return nullptr;
-  return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType);
+  return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
+                               getOptionalSparsity());
 }

 LogicalResult
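
This is the central change: ValueTensorType::toBuiltinTensor() now forwards
the optional sparsity attribute into the builtin tensor type's encoding, so a
type such as !torch.vtensor<[8,16],f32,#CSR> converts to
tensor<8x16xf32, #CSR> rather than a plain tensor<8x16xf32>. The new test
file below checks exactly this through the -convert-torch-to-linalg patterns.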


@@ -0,0 +1,36 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-LABEL: func.func @sum(
+// CHECK-SAME:  %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32>
+// CHECK:       %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[64,64],f32,#[[$CSR]]> -> tensor<64x64xf32, #[[$CSR]]>
+// CHECK:       linalg.generic {{{.*}}} ins(%[[S]] : tensor<64x64xf32, #[[$CSR]]>)
+func.func @sum(%arg0: !torch.vtensor<[64,64],f32,#CSR>) -> !torch.vtensor<[],f32> {
+  %none = torch.constant.none
+  %0 = torch.aten.sum %arg0, %none
+      : !torch.vtensor<[64,64],f32,#CSR>, !torch.none -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-LABEL: func.func @SpMM(
+// CHECK-SAME:  %[[A:.*]]: !torch.vtensor<[8,16],f32,#[[$CSR]]>,
+// CHECK-SAME:  %[[B:.*]]: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32>
+// CHECK:       %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[8,16],f32,#[[$CSR]]> -> tensor<8x16xf32, #[[$CSR]]>
+// CHECK:       %[[T:.*]] = torch_c.to_builtin_tensor %[[B]] : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
+// CHECK:       linalg.matmul ins(%[[S]], %[[T]] : tensor<8x16xf32, #[[$CSR]]>, tensor<16x8xf32>)
+func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
+                %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> {
+  %0 = torch.aten.matmul %arg0, %arg1
+      : !torch.vtensor<[8,16],f32,#CSR>,
+        !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
+  return %0 : !torch.vtensor<[8,8],f32>
+}
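
To run the new test by hand, follow its RUN line: pipe the file through
torch-mlir-opt -convert-torch-to-linalg -split-input-file -verify-diagnostics
and feed the result to FileCheck with the same file as the check source (each
%s in the RUN line stands for the test file's path). Under the project's
lit-based suite it runs together with the other TorchToLinalg conversion
tests.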