mirror of https://github.com/llvm/torch-mlir
[onnx] Add support for constants of `i1`s (#2978)
`getRawBuffer` expects a densely packed vector of `i1` values however `onnx` does not densely pack the values. Include code to handle the packing / unpacking.pull/2980/head
parent
4d01b0f1a3
commit
933db87a07
|
@ -700,9 +700,19 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ty = cast<ShapedType>(attr.getType());
|
auto ty = cast<ShapedType>(attr.getType());
|
||||||
|
ElementsAttr denseAttr;
|
||||||
auto ptr = attr.getRawHandle().getBlob()->getData();
|
auto ptr = attr.getRawHandle().getBlob()->getData();
|
||||||
DenseElementsAttr denseAttr =
|
if (cast<ShapedType>(attr.getType()).getElementType().isInteger(1)) {
|
||||||
DenseElementsAttr::getFromRawBuffer(ty, ptr);
|
llvm::SmallVector<APInt> newContents;
|
||||||
|
for (auto val : ptr) {
|
||||||
|
APInt apval(1, val);
|
||||||
|
newContents.push_back(apval);
|
||||||
|
}
|
||||||
|
denseAttr = DenseElementsAttr::get(ty, newContents);
|
||||||
|
} else {
|
||||||
|
denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
|
||||||
|
}
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||||
binder.op, resultType, denseAttr);
|
binder.op, resultType, denseAttr);
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -1409,6 +1409,25 @@ func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : s
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @dense_constant_i1
|
||||||
|
func.func @dense_constant_i1() -> !torch.vtensor<[5],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64} {
|
||||||
|
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, false, true, true]> : tensor<5xi1>) : !torch.vtensor<[5],i1>
|
||||||
|
// CHECK: return %[[CST]] : !torch.vtensor<[5],i1>
|
||||||
|
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<5xi1>} : () -> !torch.vtensor<[5],i1>
|
||||||
|
return %0 : !torch.vtensor<[5],i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
{-#
|
||||||
|
dialect_resources: {
|
||||||
|
builtin: {
|
||||||
|
_: "0x080000000100000101"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#-}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: @test_flatten_4d_axis_2
|
// CHECK-LABEL: @test_flatten_4d_axis_2
|
||||||
func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
|
// CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
|
||||||
|
|
Loading…
Reference in New Issue