mirror of https://github.com/llvm/torch-mlir
[torch-mlir][sparse] recognize to_dense primitive (#3308)
Also maps simply to sparse_tensor.convert; the sparsity encodings on the types do the rest!
parent 89bb7404c1
commit a033bbfe6c
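In essence, each of these conversion primitives lowers to a single sparse_tensor.convert between builtin tensor types, and the sparse encodings carried on those types drive the actual conversion. A minimal sketch of the lowered IR for the new _to_dense case, assuming an illustrative [8,16] shape and a hypothetical #CSR encoding (neither is taken from this commit):

    #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

    // The sparse torch-level value is bridged to a builtin tensor ...
    %t = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32,#CSR> -> tensor<8x16xf32, #CSR>
    // ... densified by converting away the encoding ...
    %d = sparse_tensor.convert %t : tensor<8x16xf32, #CSR> to tensor<8x16xf32>
    // ... and bridged back to the torch level without an encoding.
    %r = torch_c.from_builtin_tensor %d : tensor<8x16xf32> -> !torch.vtensor<[8,16],f32>

The existing _to_sparse/_to_csr/... primitives follow the same pattern with the encodings on source and destination swapped or changed.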
@@ -2451,8 +2451,8 @@ private:
 };
 
 // Static initializer.
 SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
-    "torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc",
-    "torch.aten._to_bsr", "torch.aten._to_bsc",
+    "torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr",
+    "torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc",
 };
 
 } // namespace
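legalizedNames is the list of torch.operator names that ConvertSparseOperatorOp matches, so recognizing _to_dense only requires adding its name here; the shared rewrite emits the single sparse_tensor.convert for every entry. Sparse-to-sparse reformatting entries such as _to_csr fall out of the same mechanism, e.g. (encodings illustrative, not from this commit):

    %c = sparse_tensor.convert %t : tensor<8x16xf32, #COO> to tensor<8x16xf32, #CSR>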
@@ -54,7 +54,7 @@ func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
 // CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32>)
 // CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32> -> tensor<128x64x30x30x6xf32>
 // CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32> to tensor<128x64x30x30x6xf32, #[[$ST]]>
-// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]>
+// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]> -> !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
 // CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
 func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
     -> !torch.vtensor<[128,64,30,30,6],f32,#sparse> {
@@ -66,3 +66,35 @@ func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
     -> !torch.vtensor<[128,64,30,30,6],f32,#sparse>
   return %result : !torch.vtensor<[128,64,30,30,6],f32,#sparse>
 }
+
+// -----
+
+#sparse = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3, d4) ->
+    (d0 : compressed(nonunique),
+     d1 : singleton(nonunique, soa),
+     d2 : singleton(nonunique, soa),
+     d3 : singleton(nonunique, soa),
+     d4 : singleton(soa)
+    ),
+  posWidth = 64,
+  crdWidth = 64
+}>
+
+// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3, d4) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(nonunique, soa), d3 : singleton(nonunique, soa), d4 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
+// CHECK-LABEL: func.func @deactivate(
+// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>)
+// CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]> -> tensor<128x64x30x30x6xf32, #[[$ST]]>
+// CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32, #[[$ST]]> to tensor<128x64x30x30x6xf32>
+// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32> -> !torch.vtensor<[128,64,30,30,6],f32>
+// CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32>
+func.func @deactivate(%arg0: !torch.vtensor<[128,64,30,30,6],f32,#sparse>)
+    -> !torch.vtensor<[128,64,30,30,6],f32> {
+  %none_0 = torch.constant.none
+  %none_1 = torch.constant.none
+  %none_2 = torch.constant.none
+  %result = torch.operator "torch.aten._to_dense"(%arg0, %none_0, %none_1, %none_2)
+      : (!torch.vtensor<[128,64,30,30,6],f32,#sparse>, !torch.none, !torch.none, !torch.none)
+      -> !torch.vtensor<[128,64,30,30,6],f32>
+  return %result : !torch.vtensor<[128,64,30,30,6],f32>
+}
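The // ----- separator indicates the test file is processed with -split-input-file, so the file's RUN line (not part of this hunk) is presumably along the lines of

    // RUN: torch-mlir-opt %s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

where the exact pass flags are an assumption rather than something shown in this diff.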