[onnx] Add support for constants of `i1`s (#2978)

`getRawBuffer` expects a densely packed vector of `i1` values however
`onnx` does not densely pack the values. Include code to handle the
packing / unpacking.
pull/2980/head
Rob Suderman 2024-03-05 13:55:13 -08:00 committed by GitHub
parent 4d01b0f1a3
commit 933db87a07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 2 deletions

View File

@ -700,9 +700,19 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
}
auto ty = cast<ShapedType>(attr.getType());
ElementsAttr denseAttr;
auto ptr = attr.getRawHandle().getBlob()->getData();
DenseElementsAttr denseAttr =
DenseElementsAttr::getFromRawBuffer(ty, ptr);
if (cast<ShapedType>(attr.getType()).getElementType().isInteger(1)) {
llvm::SmallVector<APInt> newContents;
for (auto val : ptr) {
APInt apval(1, val);
newContents.push_back(apval);
}
denseAttr = DenseElementsAttr::get(ty, newContents);
} else {
denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
}
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
binder.op, resultType, denseAttr);
return success();

View File

@ -1409,6 +1409,25 @@ func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : s
// -----
// CHECK-LABEL: @dense_constant_i1
func.func @dense_constant_i1() -> !torch.vtensor<[5],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64} {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, false, true, true]> : tensor<5xi1>) : !torch.vtensor<[5],i1>
// CHECK: return %[[CST]] : !torch.vtensor<[5],i1>
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<5xi1>} : () -> !torch.vtensor<[5],i1>
return %0 : !torch.vtensor<[5],i1>
}
{-#
dialect_resources: {
builtin: {
_: "0x080000000100000101"
}
}
#-}
// -----
// CHECK-LABEL: @test_flatten_4d_axis_2
func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2