[torch-mlir][sparse] recognize to_dense primitive (#3308)

also maps simply to sparse_tensor.convert
the sparsity types do the rest!
pull/3241/head
Aart Bik 2024-05-08 22:50:17 -07:00 committed by GitHub
parent 89bb7404c1
commit a033bbfe6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 3 deletions

View File

@ -2451,8 +2451,8 @@ private:
}; };
// Static initializer. // Static initializer.
SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = { SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
"torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc", "torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr",
"torch.aten._to_bsr", "torch.aten._to_bsc", "torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc",
}; };
} // namespace } // namespace

View File

@ -54,7 +54,7 @@ func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32>) // CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32>)
// CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32> -> tensor<128x64x30x30x6xf32> // CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32> -> tensor<128x64x30x30x6xf32>
// CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32> to tensor<128x64x30x30x6xf32, #[[$ST]]> // CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32> to tensor<128x64x30x30x6xf32, #[[$ST]]>
// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]> // CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]> -> !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
// CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]> // CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>) func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
-> !torch.vtensor<[128,64,30,30,6],f32,#sparse> { -> !torch.vtensor<[128,64,30,30,6],f32,#sparse> {
@ -66,3 +66,35 @@ func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
-> !torch.vtensor<[128,64,30,30,6],f32,#sparse> -> !torch.vtensor<[128,64,30,30,6],f32,#sparse>
return %result : !torch.vtensor<[128,64,30,30,6],f32,#sparse> return %result : !torch.vtensor<[128,64,30,30,6],f32,#sparse>
} }
// -----
#sparse = #sparse_tensor.encoding<{
map = (d0, d1, d2, d3, d4) ->
(d0 : compressed(nonunique),
d1 : singleton(nonunique, soa),
d2 : singleton(nonunique, soa),
d3 : singleton(nonunique, soa),
d4 : singleton(soa)
),
posWidth = 64,
crdWidth = 64
}>
// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3, d4) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(nonunique, soa), d3 : singleton(nonunique, soa), d4 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
// CHECK-LABEL: func.func @deactivate(
// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>)
// CHECK: %[[D:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]> -> tensor<128x64x30x30x6xf32, #[[$ST]]>
// CHECK: %[[C:.*]] = sparse_tensor.convert %0 : tensor<128x64x30x30x6xf32, #[[$ST]]> to tensor<128x64x30x30x6xf32>
// CHECK: %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32> -> !torch.vtensor<[128,64,30,30,6],f32>
// CHECK: return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32>
func.func @deactivate(%arg0: !torch.vtensor<[128,64,30,30,6],f32,#sparse>)
-> !torch.vtensor<[128,64,30,30,6],f32> {
%none_0 = torch.constant.none
%none_1 = torch.constant.none
%none_2 = torch.constant.none
%result = torch.operator "torch.aten._to_dense"(%arg0, %none_0, %none_1, %none_2)
: (!torch.vtensor<[128,64,30,30,6],f32,#sparse>, !torch.none, !torch.none, !torch.none)
-> !torch.vtensor<[128,64,30,30,6],f32>
return %result : !torch.vtensor<[128,64,30,30,6],f32>
}