mirror of https://github.com/llvm/torch-mlir
[onnx] Add support for constants of `i1`s (#2978)
`getRawBuffer` expects a densely packed vector of `i1` values however `onnx` does not densely pack the values. Include code to handle the packing / unpacking.pull/2980/head
parent
4d01b0f1a3
commit
933db87a07
|
@ -700,9 +700,19 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
}
|
||||
|
||||
auto ty = cast<ShapedType>(attr.getType());
|
||||
ElementsAttr denseAttr;
|
||||
auto ptr = attr.getRawHandle().getBlob()->getData();
|
||||
DenseElementsAttr denseAttr =
|
||||
DenseElementsAttr::getFromRawBuffer(ty, ptr);
|
||||
if (cast<ShapedType>(attr.getType()).getElementType().isInteger(1)) {
|
||||
llvm::SmallVector<APInt> newContents;
|
||||
for (auto val : ptr) {
|
||||
APInt apval(1, val);
|
||||
newContents.push_back(apval);
|
||||
}
|
||||
denseAttr = DenseElementsAttr::get(ty, newContents);
|
||||
} else {
|
||||
denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||
binder.op, resultType, denseAttr);
|
||||
return success();
|
||||
|
|
|
@ -1409,6 +1409,25 @@ func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : s
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @dense_constant_i1
|
||||
func.func @dense_constant_i1() -> !torch.vtensor<[5],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64} {
|
||||
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, false, true, true]> : tensor<5xi1>) : !torch.vtensor<[5],i1>
|
||||
// CHECK: return %[[CST]] : !torch.vtensor<[5],i1>
|
||||
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<5xi1>} : () -> !torch.vtensor<[5],i1>
|
||||
return %0 : !torch.vtensor<[5],i1>
|
||||
}
|
||||
|
||||
{-#
|
||||
dialect_resources: {
|
||||
builtin: {
|
||||
_: "0x080000000100000101"
|
||||
}
|
||||
}
|
||||
#-}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// CHECK-LABEL: @test_flatten_4d_axis_2
|
||||
func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
|
||||
|
|
Loading…
Reference in New Issue