diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 1c356db89..52cd59e89 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -700,9 +700,19 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } auto ty = cast(attr.getType()); + ElementsAttr denseAttr; auto ptr = attr.getRawHandle().getBlob()->getData(); - DenseElementsAttr denseAttr = - DenseElementsAttr::getFromRawBuffer(ty, ptr); + if (cast(attr.getType()).getElementType().isInteger(1)) { + llvm::SmallVector newContents; + for (auto val : ptr) { + APInt apval(1, val); + newContents.push_back(apval); + } + denseAttr = DenseElementsAttr::get(ty, newContents); + } else { + denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr); + } + rewriter.replaceOpWithNewOp( binder.op, resultType, denseAttr); return success(); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 7dc262228..7c465d74b 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1409,6 +1409,25 @@ func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : s // ----- +// CHECK-LABEL: @dense_constant_i1 +func.func @dense_constant_i1() -> !torch.vtensor<[5],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, false, true, true]> : tensor<5xi1>) : !torch.vtensor<[5],i1> + // CHECK: return %[[CST]] : !torch.vtensor<[5],i1> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<5xi1>} : () -> !torch.vtensor<[5],i1> + return %0 : !torch.vtensor<[5],i1> +} + +{-# + dialect_resources: { + builtin: { + _: "0x080000000100000101" + } + } +#-} + +// ----- + + // CHECK-LABEL: @test_flatten_4d_axis_2 func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2