[torch-mlir][sparse] preserve sparsity during lowering torch to linalg (#2809)

This preserves sparsity at the most obvious places where Torch tensors are lowered to MLIR RankedTensorType tensors. Other places are marked for audit. Includes some initial lowering tests.
parent da7c6d2c16
commit 46a25d7241
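The recurring idiom in this patch is to pass the source type's encoding attribute as the third argument of RankedTensorType::get, so that a #sparse_tensor.encoding attached to the Torch tensor survives into the builtin tensor type instead of being silently dropped. A minimal sketch of the pattern, assuming an already-ranked source type (the helper name cloneShapeWithEncoding is hypothetical, not part of the patch):

// Sketch only: illustrates the pattern applied at several lowering sites below;
// the helper itself is hypothetical and not code from this commit.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Rebuild a ranked tensor type with a new shape and element type while keeping
// the original encoding (e.g. a #sparse_tensor.encoding). Calling the
// two-argument form with shape and element type alone would drop the sparsity.
static RankedTensorType cloneShapeWithEncoding(RankedTensorType source,
                                               ArrayRef<int64_t> newShape,
                                               Type newElementType) {
  return RankedTensorType::get(newShape, newElementType, source.getEncoding());
}

This is exactly the shape of the edits in the matmul broadcast hunks below, which switch from the two-argument to the three-argument form of RankedTensorType::get; the remaining sites only gain TODO comments for a later audit.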
@@ -978,6 +978,7 @@ public:
      return success();
    }

    // TODO: audit possibility of sparsity on these tensors
    Type adjustedResultType = RankedTensorType::get(
        makeShapeLLVMCompatible(outputShape), resultType.getElementType());
    Type adjustedInputType = RankedTensorType::get(
@@ -1005,6 +1006,7 @@ public:
      intermediateShape.push_back(sum);
    }

    // TODO: audit possibility of sparsity on these tensor
    Type intermediateResultType =
        RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape),
                              resultType.getElementType());
@@ -1657,6 +1659,7 @@ public:
    auto srcType = src.getType().cast<RankedTensorType>();
    int64_t srcRank = srcType.getRank();
    SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
    // TODO: audit possibility of sparsity on these tensor
    auto abstractSrcType = RankedTensorType::get(
        makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
    Value abstractSrc =
@@ -206,8 +206,8 @@ namespace {
//
// TODO: Find an optimal lowering.
// current lowering is not optimal for bags of large embeddings.
// Since it traverses the output tensor multiple times.
//
// Since it traverses the output tensor multiple times.
//
//

class ConvertAtenEmbeddingBagPaddingIdxOp
@@ -377,8 +377,8 @@ public:
    // TODO: Improve usage of static shape information.
    SmallVector<int64_t> lhsTargetShape(lhsBroadcastToShape.size(),
                                        ShapedType::kDynamic);
    auto lhsBroadcastType =
        RankedTensorType::get(lhsTargetShape, lhsType.getElementType());
    auto lhsBroadcastType = RankedTensorType::get(
        lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding());
    if (failed(torch_to_linalg::broadcastToGivenShape(
            op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType,
            broadcastedLhs))) {
@@ -387,8 +387,8 @@ public:
    }
    SmallVector<int64_t> rhsTargetShape(rhsBroadcastToShape.size(),
                                        ShapedType::kDynamic);
    auto rhsBroadcastType =
        RankedTensorType::get(rhsTargetShape, rhsType.getElementType());
    auto rhsBroadcastType = RankedTensorType::get(
        rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding());
    if (failed(torch_to_linalg::broadcastToGivenShape(
            op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType,
            broadcastedRhs))) {
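A note on the two broadcast hunks above: for a dense operand, getEncoding() returns a null Attribute, and RankedTensorType is uniqued, so the three-argument form builds exactly the same type as the old two-argument call. Threading the encoding through is therefore a no-op for dense tensors and only changes the resulting type when a sparse (or other) encoding is actually present. A small self-contained check of that assumption (not code from the patch):

// Sanity-check sketch, not part of this commit: with no encoding attached, the
// two-argument and three-argument forms of RankedTensorType::get produce the
// same uniqued type, so the change is harmless for the dense case.
#include <cassert>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static void denseEncodingIsANoOp(MLIRContext *ctx) {
  Type f32 = Float32Type::get(ctx);
  RankedTensorType dense = RankedTensorType::get({8, 16}, f32);
  // getEncoding() is a null Attribute here, so the uniqued type is identical.
  assert(RankedTensorType::get({8, 16}, f32, dense.getEncoding()) == dense);
}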
@@ -880,7 +880,7 @@ public:
      if(numSpacialDims != 2)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: only 2D grouped convolution supported");


      // Special depthwise case
      auto inShape = makeShapeTorchCompatible(
          input.getType().cast<RankedTensorType>().getShape());
@@ -894,6 +894,7 @@ public:
          (weightShape[0] == kUnknownSize ? kUnknownSize
                                          : weightShape[0] * weightShape[1]),
          weightShape[2], weightShape[3]};
      // TODO: audit possibility of sparsity on this tensor
      Type collapsedType = RankedTensorType::get(
          makeShapeLLVMCompatible(collapsedShape), elementType);
      Value collapsedWeight = rewriter.create<tensor::CollapseShapeOp>(
@@ -87,6 +87,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
    *pad = castIntToIndex(b, loc, *pad);

  Type elementType = input.getType().cast<RankedTensorType>().getElementType();
  // TODO: audit possibility of sparsity on this tensor
  Type inputType =
      RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
                                SmallVector<int64_t>(inRank, kUnknownSize))),
@@ -467,7 +467,8 @@ TensorType ValueTensorType::toBuiltinTensor() const {
  Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype());
  if (!elementType)
    return nullptr;
  return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType);
  return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType,
                               getOptionalSparsity());
}

LogicalResult
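The hunk above in ValueTensorType::toBuiltinTensor is the load-bearing change: it is the single point where a !torch.vtensor type is turned into a builtin tensor type, so passing getOptionalSparsity() as the encoding is what lets every lowering that goes through the type converter see the sparse layout, as the new tests below verify. A hedged illustration of the observable effect (the helper and its name are mine, not part of the patch; it assumes an existing ValueTensorType value):

// Illustration only, not code from this commit: after the change, the builtin
// type produced by toBuiltinTensor() carries the vtensor's sparsity encoding,
// so the standard sparse_tensor query can see it on the converted type.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;

static bool lowersToSparseBuiltinTensor(torch::Torch::ValueTensorType vtTy) {
  // toBuiltinTensor() may return null when the dtype has no builtin equivalent.
  auto rtt = dyn_cast_or_null<RankedTensorType>(vtTy.toBuiltinTensor());
  return rtt && sparse_tensor::getSparseTensorEncoding(rtt);
}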
@@ -0,0 +1,36 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// -----

#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
// CHECK-LABEL: func.func @sum(
// CHECK-SAME:  %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32>
// CHECK:       %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[64,64],f32,#[[$CSR]]> -> tensor<64x64xf32, #[[$CSR]]>
// CHECK:       linalg.generic {{{.*}}} ins(%[[S]] : tensor<64x64xf32, #[[$CSR]]>)
func.func @sum(%arg0: !torch.vtensor<[64,64],f32,#CSR>) -> !torch.vtensor<[],f32> {
  %none = torch.constant.none
  %0 = torch.aten.sum %arg0, %none
      : !torch.vtensor<[64,64],f32,#CSR>, !torch.none -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}

// -----

#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
// CHECK-LABEL: func.func @SpMM(
// CHECK-SAME:  %[[A:.*]]: !torch.vtensor<[8,16],f32,#[[$CSR]]>,
// CHECK-SAME:  %[[B:.*]]: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32>
// CHECK:       %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[8,16],f32,#[[$CSR]]> -> tensor<8x16xf32, #[[$CSR]]>
// CHECK:       %[[T:.*]] = torch_c.to_builtin_tensor %[[B]] : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32>
// CHECK:       linalg.matmul ins(%[[S]], %[[T]] : tensor<8x16xf32, #[[$CSR]]>, tensor<16x8xf32>)
func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
                %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> {
  %0 = torch.aten.matmul %arg0, %arg1
      : !torch.vtensor<[8,16],f32,#CSR>,
        !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
  return %0 : !torch.vtensor<[8,8],f32>
}