From 46a25d72412c1bb00bd947a44b1c7dde7bd7ef53 Mon Sep 17 00:00:00 2001
From: Aart Bik <39774503+aartbik@users.noreply.github.com>
Date: Fri, 26 Jan 2024 10:54:59 -0800
Subject: [PATCH] [torch-mlir][sparse] preserve sparsity during lowering torch
 to linalg (#2809)

This preserves sparsity at the most obvious places of lowering TORCH
tensors to MLIR RankedTensorType tensors. Other places are marked
for audit. Initial lowering tests are included.
---
 lib/Conversion/TorchToLinalg/DataMovement.cpp |  3 ++
 .../TorchToLinalg/IndirectDataMovement.cpp    |  4 +--
 lib/Conversion/TorchToLinalg/Linear.cpp       | 11 +++---
 lib/Conversion/TorchToLinalg/Utils.cpp        |  1 +
 lib/Dialect/Torch/IR/TorchTypes.cpp           |  3 +-
 test/Conversion/TorchToLinalg/sparse.mlir     | 36 +++++++++++++++++++
 6 files changed, 50 insertions(+), 8 deletions(-)
 create mode 100644 test/Conversion/TorchToLinalg/sparse.mlir

diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 297a0f4c2..e96d65970 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -978,6 +978,7 @@ public:
       return success();
     }
 
+    // TODO: audit possibility of sparsity on these tensors
     Type adjustedResultType = RankedTensorType::get(
         makeShapeLLVMCompatible(outputShape), resultType.getElementType());
     Type adjustedInputType = RankedTensorType::get(
@@ -1005,6 +1006,7 @@ public:
       intermediateShape.push_back(sum);
     }
 
+    // TODO: audit possibility of sparsity on these tensors
    Type intermediateResultType =
        RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape),
                              resultType.getElementType());
@@ -1657,6 +1659,7 @@ public:
     auto srcType = src.getType().cast<RankedTensorType>();
     int64_t srcRank = srcType.getRank();
     SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
+    // TODO: audit possibility of sparsity on these tensors
     auto abstractSrcType = RankedTensorType::get(
         makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
     Value abstractSrc =
diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index 277341bea..f9ee56070 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -206,8 +206,8 @@ namespace {
 //
 // TODO: Find an optimal lowering.
 // current lowering is not optimal for bags of large embeddings.
-// Since it traverses the output tensor multiple times. 
-// 
+// Since it traverses the output tensor multiple times.
+//
 //
 class ConvertAtenEmbeddingBagPaddingIdxOp
     : public OpConversionPattern<AtenEmbeddingBagPaddingIdxOp> {
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index d818b99c0..6d0d72075 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -377,8 +377,8 @@ public:
     // TODO: Improve usage of static shape information.
     SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
                                         ShapedType::kDynamic);
-    auto lhsBroadcastType =
-        RankedTensorType::get(lhsTargetShape, lhsType.getElementType());
+    auto lhsBroadcastType = RankedTensorType::get(
+        lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
             op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType,
             broadcastedLhs))) {
@@ -387,8 +387,8 @@ public:
     }
     SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
                                         ShapedType::kDynamic);
-    auto rhsBroadcastType =
-        RankedTensorType::get(rhsTargetShape, rhsType.getElementType());
+    auto rhsBroadcastType = RankedTensorType::get(
+        rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
     if (failed(torch_to_linalg::broadcastToGivenShape(
             op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType,
             broadcastedRhs))) {
@@ -880,7 +880,7 @@ public:
       if(numSpacialDims != 2)
         return rewriter.notifyMatchFailure(
             op, "unimplemented: only 2D grouped convolution supported");
-      
+
       // Special depthwise case
       auto inShape = makeShapeTorchCompatible(
           input.getType().cast<RankedTensorType>().getShape());
@@ -894,6 +894,7 @@ public:
           (weightShape[0] == kUnknownSize ? kUnknownSize
                                           : weightShape[0] * weightShape[1]),
           weightShape[2], weightShape[3]};
+      // TODO: audit possibility of sparsity on this tensor
       Type collapsedType = RankedTensorType::get(
           makeShapeLLVMCompatible(collapsedShape), elementType);
       Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 77459aca3..8bff5034c 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -87,6 +87,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
     *pad = castIntToIndex(b, loc, *pad);
 
   Type elementType = input.getType().cast<RankedTensorType>().getElementType();
+  // TODO: audit possibility of sparsity on this tensor
   Type inputType =
       RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
                                 SmallVector<int64_t>(inRank, kUnknownSize))),
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index b5b63954f..a154fb465 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -467,7 +467,8 @@ TensorType ValueTensorType::toBuiltinTensor() const {
   Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
   if (!elementType)
     return nullptr;
-  return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType);
+  return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
+                               getOptionalSparsity());
 }
 
 LogicalResult
diff --git a/test/Conversion/TorchToLinalg/sparse.mlir b/test/Conversion/TorchToLinalg/sparse.mlir
new file mode 100644
index 000000000..5d952fde3
--- /dev/null
+++ b/test/Conversion/TorchToLinalg/sparse.mlir
@@ -0,0 +1,36 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-LABEL: func.func @sum(
+// CHECK-SAME:  %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32>
+// CHECK:       %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[64,64],f32,#[[$CSR]]> -> tensor<64x64xf32, #[[$CSR]]>
+// CHECK:       linalg.generic {{{.*}}} ins(%[[S]] : tensor<64x64xf32, #[[$CSR]]>)
+func.func @sum(%arg0: !torch.vtensor<[64,64],f32,#CSR>) -> !torch.vtensor<[],f32> {
+  %none = torch.constant.none
+  %0 = torch.aten.sum %arg0, %none
+     : !torch.vtensor<[64,64],f32,#CSR>, !torch.none -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+
+// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+// CHECK-LABEL: func.func @SpMM(
+// CHECK-SAME:  %[[A:.*]]: !torch.vtensor<[8,16],f32,#[[$CSR]]>,
+// CHECK-SAME:  %[[B:.*]]: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32>
+// CHECK:       %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[8,16],f32,#[[$CSR]]> -> tensor<8x16xf32, #[[$CSR]]>
+// CHECK:       %[[T:.*]] = torch_c.to_builtin_tensor %[[B]] : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
+// CHECK:       linalg.matmul ins(%[[S]], %[[T]] : tensor<8x16xf32, #[[$CSR]]>, tensor<16x8xf32>)
+func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
+                %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> {
+  %0 = torch.aten.matmul %arg0, %arg1
+     : !torch.vtensor<[8,16],f32,#CSR>,
+       !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
+  return %0 : !torch.vtensor<[8,8],f32>
+}
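
The recurring pattern in the C++ changes above is to pass the source type's encoding attribute along whenever a RankedTensorType is rebuilt, so that a #sparse_tensor.encoding is carried into the linalg-level types instead of being silently dropped; a null encoding keeps the usual dense behavior. A rough sketch of that pattern follows, assuming the standard MLIR headers; the helper name getTypeWithSameEncoding is made up for illustration and does not exist in the tree.

    #include <cstdint>
    #include "llvm/ADT/ArrayRef.h"
    #include "mlir/IR/BuiltinTypes.h"

    // Hypothetical helper, not part of the patch: rebuild a ranked tensor type
    // with a new shape while preserving both the element type and any sparsity
    // encoding of the source type. A null encoding means dense, so this is a
    // no-op for dense inputs.
    static mlir::RankedTensorType
    getTypeWithSameEncoding(mlir::RankedTensorType srcType,
                            llvm::ArrayRef<int64_t> newShape) {
      return mlir::RankedTensorType::get(newShape, srcType.getElementType(),
                                         srcType.getEncoding());
    }

Lowerings that still construct their result types without an encoding are marked with a "// TODO: audit possibility of sparsity" comment rather than changed blindly, since not every rewrite has been checked for sparsity-safety yet.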