From 4361178caafe567449906d70a543ee98167c39ac Mon Sep 17 00:00:00 2001
From: Aart Bik
Date: Thu, 25 Apr 2024 11:32:07 -0700
Subject: [PATCH] [torch-mlir][sparse] recognize sparse tensor conversion (#3226)

Sparse tensor conversions are represented by special aten operators.
This PR ensures the conversions are recognized (instead of failing the
full torch aten lowering to linalg).
---
 lib/Conversion/TorchToLinalg/DataMovement.cpp | 42 +++++++++++++++++++
 .../TorchToLinalg/TorchToLinalg.cpp           |  9 ++--
 test/Conversion/TorchToLinalg/sparse.mlir     | 32 ++++++++++++++
 3 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index a94f8882e..ad1a17b0a 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
@@ -2423,6 +2424,42 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  static bool isSparsePrimitive(StringRef prim) {
+    return llvm::find(legalizedNames, prim) != legalizedNames.end();
+  }
+
+  // Rewriting method.
+  LogicalResult
+  matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isSparsePrimitive(op.getNameAttr()))
+      return failure();
+    // Conversion is completely specified by information in the sparse tensor
+    // type. Thus, we can rewrite all legalizedNames to the same construct.
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+    rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
+        op, resultType, adaptor.getOperands()[0]);
+    return success();
+  }
+
+private:
+  // The operators that legalize to sparse tensor conversions.
+  static SmallVector<StringRef> legalizedNames;
+};
+// Static initializer.
+SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
+    "torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc",
+    "torch.aten._to_bsr",    "torch.aten._to_bsc",
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -2469,4 +2506,9 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
+  // Rewrite all special sparse conversions hidden as operators.
+  target.addDynamicallyLegalOp<Torch::OperatorOp>([&](Torch::OperatorOp op) {
+    return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
+  });
+  patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
 }
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 1f9f4b17b..a4451041f 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -53,10 +54,10 @@ public:
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ConversionTarget target(*context);
-    target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
-                           cf::ControlFlowDialect, math::MathDialect,
-                           tensor::TensorDialect, arith::ArithDialect,
-                           complex::ComplexDialect>();
+    target.addLegalDialect<
+        linalg::LinalgDialect, func::FuncDialect, cf::ControlFlowDialect,
+        math::MathDialect, sparse_tensor::SparseTensorDialect,
+        tensor::TensorDialect, arith::ArithDialect, complex::ComplexDialect>();
     target.addLegalOp();
 
     TypeConverter typeConverter;
diff --git a/test/Conversion/TorchToLinalg/sparse.mlir b/test/Conversion/TorchToLinalg/sparse.mlir
index 5d952fde3..4dc580ea3 100644
--- a/test/Conversion/TorchToLinalg/sparse.mlir
+++ b/test/Conversion/TorchToLinalg/sparse.mlir
@@ -34,3 +34,35 @@ func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
                        !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
   return %0 : !torch.vtensor<[8,8],f32>
 }
+
+// -----
+
+#sparse = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3, d4) ->
+        (d0 : compressed(nonunique),
+         d1 : singleton(nonunique, soa),
+         d2 : singleton(nonunique, soa),
+         d3 : singleton(nonunique, soa),
+         d4 : singleton(soa)
+  ),
+  posWidth = 64,
+  crdWidth = 64
+}>
+
+// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3, d4) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(nonunique, soa), d3 : singleton(nonunique, soa), d4 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
+// CHECK-LABEL: func.func @activate(
+// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32>)
+// CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32> -> tensor<128x64x30x30x6xf32>
+// CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32> to tensor<128x64x30x30x6xf32, #[[$ST]]>
+// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]>
+// CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
+func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
+    -> !torch.vtensor<[128,64,30,30,6],f32,#sparse> {
+  %none_0 = torch.constant.none
+  %none_1 = torch.constant.none
+  %none_2 = torch.constant.none
+  %result = torch.operator "torch.aten._to_sparse"(%arg0, %none_0, %none_1, %none_2)
+    : (!torch.vtensor<[128,64,30,30,6],f32>, !torch.none, !torch.none, !torch.none)
+    -> !torch.vtensor<[128,64,30,30,6],f32,#sparse>
+  return %result : !torch.vtensor<[128,64,30,30,6],f32,#sparse>
+}
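
For local verification, the new @activate case can be exercised with the usual lit-style torch-mlir-opt / FileCheck flow; a minimal sketch, assuming a built torch-mlir-opt and FileCheck are on the PATH and that the file's RUN line invokes the standard -convert-torch-to-linalg pass (an assumption based on the test directory and the CHECK output, not shown in this hunk):

  torch-mlir-opt -convert-torch-to-linalg -split-input-file \
      test/Conversion/TorchToLinalg/sparse.mlir \
    | FileCheck test/Conversion/TorchToLinalg/sparse.mlir

The -split-input-file flag matters here because the added // ----- separator makes @activate an independent test case within the file.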