mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] recognize sparse tensor conversion (#3226)
Sparse tensor conversions are represented by special aten operators. This PR ensures the conversions are recognized (instead of failing the full torch aten lowering to linalg).pull/3238/head
parent
9e2fe47c5d
commit
4361178caa
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
||||
|
@ -2423,6 +2424,42 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
static bool isSparsePrimitive(StringRef prim) {
|
||||
return llvm::find(legalizedNames, prim) != legalizedNames.end();
|
||||
}
|
||||
|
||||
// Rewriting method.
|
||||
LogicalResult
|
||||
matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!isSparsePrimitive(op.getNameAttr()))
|
||||
return failure();
|
||||
// Conversion is completed specified by information in the sparse tensor
|
||||
// type. Thus, we can rewrite all legalizedNames to the same construct.
|
||||
RankedTensorType resultType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
|
||||
op, resultType, adaptor.getOperands()[0]);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// The operators that legalize to sparse tensor conversions.
|
||||
static SmallVector<StringRef> legalizedNames;
|
||||
};
|
||||
// Static initializer.
|
||||
SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
|
||||
"torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc",
|
||||
"torch.aten._to_bsr", "torch.aten._to_bsc",
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
|
@ -2469,4 +2506,9 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenDiagEmbedOp>();
|
||||
patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
|
||||
// Rewrite all special sparse conversions hidden as operators.
|
||||
target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
|
||||
return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
|
||||
});
|
||||
patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -53,10 +54,10 @@ public:
|
|||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
|
||||
cf::ControlFlowDialect, math::MathDialect,
|
||||
tensor::TensorDialect, arith::ArithDialect,
|
||||
complex::ComplexDialect>();
|
||||
target.addLegalDialect<
|
||||
linalg::LinalgDialect, func::FuncDialect, cf::ControlFlowDialect,
|
||||
math::MathDialect, sparse_tensor::SparseTensorDialect,
|
||||
tensor::TensorDialect, arith::ArithDialect, complex::ComplexDialect>();
|
||||
target.addLegalOp<TorchConversion::GetNextSeedOp>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
|
|
|
@ -34,3 +34,35 @@ func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
|
|||
!torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
|
||||
return %0 : !torch.vtensor<[8,8],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#sparse = #sparse_tensor.encoding<{
|
||||
map = (d0, d1, d2, d3, d4) ->
|
||||
(d0 : compressed(nonunique),
|
||||
d1 : singleton(nonunique, soa),
|
||||
d2 : singleton(nonunique, soa),
|
||||
d3 : singleton(nonunique, soa),
|
||||
d4 : singleton(soa)
|
||||
),
|
||||
posWidth = 64,
|
||||
crdWidth = 64
|
||||
}>
|
||||
|
||||
// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3, d4) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(nonunique, soa), d3 : singleton(nonunique, soa), d4 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
|
||||
// CHECK-LABEL: func.func @activate(
|
||||
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32>)
|
||||
// CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32> -> tensor<128x64x30x30x6xf32>
|
||||
// CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32> to tensor<128x64x30x30x6xf32, #[[$ST]]>
|
||||
// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]>
|
||||
// CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
|
||||
func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
|
||||
-> !torch.vtensor<[128,64,30,30,6],f32,#sparse> {
|
||||
%none_0 = torch.constant.none
|
||||
%none_1 = torch.constant.none
|
||||
%none_2 = torch.constant.none
|
||||
%result = torch.operator "torch.aten._to_sparse"(%arg0, %none_0, %none_1, %none_2)
|
||||
: (!torch.vtensor<[128,64,30,30,6],f32>, !torch.none, !torch.none, !torch.none)
|
||||
-> !torch.vtensor<[128,64,30,30,6],f32,#sparse>
|
||||
return %result : !torch.vtensor<[128,64,30,30,6],f32,#sparse>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue