[torch-mlir][sparse] recognize sparse tensor conversion (#3226)

Sparse tensor conversions are represented by special aten operators.
This PR ensures these conversions are recognized and lowered, rather
than causing the full torch aten lowering to linalg to fail.
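Background: aten calls without a dedicated ODS definition are imported into the Torch dialect as generic torch.operator ops whose name attribute carries the aten operator name (as the test below shows), so the lowering keys off that string. A minimal sketch of the recognition idiom, with a hypothetical helper name:

#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

// Sketch: recognize a sparse conversion primitive purely by the name
// attribute carried on a generic torch.operator op.
static bool isSparseConversionName(mlir::torch::Torch::OperatorOp op) {
  llvm::StringRef name = op.getNameAttr().getValue();
  return name == "torch.aten._to_sparse" || name == "torch.aten._to_csr" ||
         name == "torch.aten._to_csc" || name == "torch.aten._to_bsr" ||
         name == "torch.aten._to_bsc";
}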
Aart Bik 2024-04-25 11:32:07 -07:00 committed by GitHub
parent 9e2fe47c5d
commit 4361178caa
3 changed files with 79 additions and 4 deletions


@@ -20,6 +20,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
@@ -2423,6 +2424,42 @@ public:
};
} // namespace
namespace {
class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
public:
using OpConversionPattern::OpConversionPattern;
static bool isSparsePrimitive(StringRef prim) {
return llvm::find(legalizedNames, prim) != legalizedNames.end();
}
// Rewriting method.
LogicalResult
matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isSparsePrimitive(op.getNameAttr()))
return failure();
// Conversion is completely specified by information in the sparse tensor
// type. Thus, we can rewrite all legalizedNames to the same construct.
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
op, resultType, adaptor.getOperands()[0]);
return success();
}
private:
// The operators that legalize to sparse tensor conversions.
static SmallVector<StringRef> legalizedNames;
};
// Static initializer.
SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
"torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc",
"torch.aten._to_bsr", "torch.aten._to_bsc",
};
} // namespace
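The rewrite itself carries no layout-specific logic: whether the target is CSR, CSC, BSR, or BSC, and which position/coordinate widths to use, is all encoded in the #sparse_tensor.encoding attribute of the converted result type, which sparse_tensor.convert consumes. A small sketch of querying that encoding via the standard SparseTensor dialect API (the helper is hypothetical, not part of the commit):

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

// Sketch: getSparseTensorEncoding returns a null attribute for dense
// tensors and the #sparse_tensor.encoding attribute for sparse ones.
static bool hasSparseEncoding(mlir::RankedTensorType type) {
  return static_cast<bool>(
      mlir::sparse_tensor::getSparseTensorEncoding(type));
}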
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -2469,4 +2506,9 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
target.addIllegalOp<AtenDiagEmbedOp>();
patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
// Rewrite all special sparse conversions hidden as operators.
target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
});
patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
}
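For reference, the patterns and legality rules populated here are applied by MLIR's standard dialect-conversion driver, as in the pass shown in the next file. A condensed, hypothetical sketch of that driver step (names and error handling are illustrative, not part of this commit):

#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/ErrorHandling.h"

// Sketch: apply the populated patterns under the given target. Any
// torch.operator whose name is a sparse primitive is now dynamically
// illegal, so conversion succeeds only once such ops have been rewritten
// into sparse_tensor.convert.
static void applySketch(mlir::Operation *root, mlir::ConversionTarget &target,
                        mlir::RewritePatternSet &patterns) {
  if (mlir::failed(mlir::applyPartialConversion(root, target,
                                                std::move(patterns))))
    // A real pass would call signalPassFailure() here instead of aborting.
    llvm::report_fatal_error("torch-to-linalg conversion failed");
}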


@@ -17,6 +17,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -53,10 +54,10 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
-    target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
-                           cf::ControlFlowDialect, math::MathDialect,
-                           tensor::TensorDialect, arith::ArithDialect,
-                           complex::ComplexDialect>();
+    target.addLegalDialect<
+        linalg::LinalgDialect, func::FuncDialect, cf::ControlFlowDialect,
+        math::MathDialect, sparse_tensor::SparseTensorDialect,
+        tensor::TensorDialect, arith::ArithDialect, complex::ComplexDialect>();
target.addLegalOp<TorchConversion::GetNextSeedOp>();
TypeConverter typeConverter;


@@ -34,3 +34,35 @@ func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
!torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
return %0 : !torch.vtensor<[8,8],f32>
}
// -----
#sparse = #sparse_tensor.encoding<{
map = (d0, d1, d2, d3, d4) ->
(d0 : compressed(nonunique),
d1 : singleton(nonunique, soa),
d2 : singleton(nonunique, soa),
d3 : singleton(nonunique, soa),
d4 : singleton(soa)
),
posWidth = 64,
crdWidth = 64
}>
// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3, d4) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(nonunique, soa), d3 : singleton(nonunique, soa), d4 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
// CHECK-LABEL: func.func @activate(
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32>)
// CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[128,64,30,30,6],f32> -> tensor<128x64x30x30x6xf32>
// CHECK: %[[C:.*]] = sparse_tensor.convert %[[D]] : tensor<128x64x30x30x6xf32> to tensor<128x64x30x30x6xf32, #[[$ST]]>
// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]>
// CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
-> !torch.vtensor<[128,64,30,30,6],f32,#sparse> {
%none_0 = torch.constant.none
%none_1 = torch.constant.none
%none_2 = torch.constant.none
%result = torch.operator "torch.aten._to_sparse"(%arg0, %none_0, %none_1, %none_2)
: (!torch.vtensor<[128,64,30,30,6],f32>, !torch.none, !torch.none, !torch.none)
-> !torch.vtensor<[128,64,30,30,6],f32,#sparse>
return %result : !torch.vtensor<[128,64,30,30,6],f32,#sparse>
}