From 197b3b475c2fa4c452f08c79f2cab1c7482d6ccc Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman@gmail.com>
Date: Mon, 15 Jan 2024 09:31:22 -0800
Subject: [PATCH] [onnx] Convert `onnx.constant` to `torch` literal tensor
 (#2748)

Handles the multiple cases of `onnx` constant values and converts them
to `torch` literal tensors. This can include splats with a single
integer or floating point value, a set of explicit integer values, or
an elements array attr of values.
---
 .../Conversion/TorchOnnxToTorch/Patterns.h    | 13 +++
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    | 53 +++++++++++++++++++
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir   | 41 ++++++++++++++
 3 files changed, 107 insertions(+)

diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
index b6189b375..44e33ab09 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -190,6 +190,19 @@ struct OpBinder {
     return failure();
   }
 
+  ParseResult denseElementsAttr(ElementsAttr &elementsattr,
+                                StringRef nameSuffix) {
+    SmallString<64> name("torch.onnx.");
+    name.append(nameSuffix);
+    Attribute attr = op->getAttr(name);
+    if (!attr || !isa<ElementsAttr>(attr)) {
+      return failure();
+    }
+
+    elementsattr = cast<ElementsAttr>(attr);
+    return success();
+  }
+
   ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
                                      std::string defaultValue = "") {
     SmallString<64> name("torch.onnx.");
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 2cfc9940a..aa3b5fc01 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -590,6 +590,59 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             tensorList, cstDim);
         return success();
       });
+  patterns.onOp(
+      "Constant", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        if (binder.tensorResultType(resultType))
+          return failure();
+        auto dtype = resultType.getDtype();
+        Value scalarValue;
+
+        float floatValue;
+        if (binder.op->hasAttr("torch.onnx.value_float") &&
+            !binder.f32FloatAttr(floatValue, "value_float", 0.0)) {
+          auto splatAttr =
+              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
+                                     rewriter.getFloatAttr(dtype, floatValue));
+          rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
+              binder.op, resultType, splatAttr);
+          return success();
+        }
+
+        int64_t intValue;
+        if (binder.op->hasAttr("torch.onnx.value_int") &&
+            !binder.s64IntegerAttr(intValue, "value_int", 0)) {
+          auto splatAttr =
+              SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype),
+                                     rewriter.getIntegerAttr(dtype, intValue));
+          rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
+              binder.op, resultType, splatAttr);
+          return success();
+        }
+
+        if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value")
+                                    .dyn_cast_or_null<ElementsAttr>()) {
+          rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
+              binder.op, resultType, attr);
+          return success();
+        }
+
+        llvm::SmallVector<int64_t> intValues;
+        if (!binder.s64IntegerArrayAttr(intValues, "value_ints", {}) &&
+            !intValues.empty()) {
+          llvm::SmallVector<APInt> apValues;
+          for (auto intVal : intValues) {
+            apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal));
+          }
+          auto attr = DenseElementsAttr::get(
+              resultType.toBuiltinTensor().clone(dtype), apValues);
+          rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
+              binder.op, resultType, attr);
+          return success();
+        }
+
+        return failure();
+      });
   patterns.onOp(
       "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         std::string autoPad;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index fc9706127..f8bc219dc 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -979,3 +979,44 @@ func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f3
   %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "CRD"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32>
   return %0 : !torch.vtensor<[1,2,4,6],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @float_constant
+func.func @float_constant() -> !torch.vtensor<[], f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
+  // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<2.500000e-01> : tensor<f32>) : !torch.vtensor<[],f32>
+  // CHECK: return %[[CST]]
+  %0 = torch.operator "onnx.Constant"() {torch.onnx.value_float = 0.25 : f32} : () -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @int_constant
+func.func @int_constant() -> !torch.vtensor<[], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
+  // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<79> : tensor<si64>) : !torch.vtensor<[],si64>
+  // CHECK: return %[[CST]]
+  %0 = torch.operator "onnx.Constant"() {torch.onnx.value_int = 79 : si64} : () -> !torch.vtensor<[],si64>
+  return %0 : !torch.vtensor<[],si64>
+}
+
+// -----
+
+// CHECK-LABEL: @dense_constant
+func.func @dense_constant() -> !torch.vtensor<[1], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
+  // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<13> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  // CHECK: return %[[CST]]
+  %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<13> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
+  return %0 : !torch.vtensor<[1],si64>
+}
+
+// -----
+
+// CHECK-LABEL: @ints_constant
+func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
+  // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[7, 9]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
+  // CHECK: return %[[CST]]
+  %0 = "torch.operator"() <{name = "onnx.Constant"}> {torch.onnx.value_ints = [7 : si64, 9 : si64]} : () -> !torch.vtensor<[2],si64>
+  return %0 : !torch.vtensor<[2],si64>
+}
+