From 90c6c64fd61b31cc1f023a78a956168bb45c8e58 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 22 Jun 2021 14:25:16 -0700 Subject: [PATCH] Make torch.constant.float print a little nicer. This printing is chosen to be similar to how MLIR prints the values by default. --- lib/Dialect/Torch/IR/TorchOps.cpp | 19 ++++++++++++++++++- test/Dialect/Torch/ops.mlir | 19 ++++++++++++++++--- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 00ddce894..13e0952fe 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -674,7 +674,24 @@ OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef operands) { void Torch::ConstantFloatOp::getAsmResultNames( function_ref setNameFn) { - setNameFn(getResult(), "float"); + // Calculate a stringified version of the number, compatible with MLIR + // identifier syntax. (in practice, this just removes the '+' from 'e+' in + // float string representation). + SmallVector buf; + value().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, + /*TruncateZero=*/false); + auto isValidMLIRIdentifierChar = [](char c) { + return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' || + c == '-'; + }; + auto numberStr = llvm::to_vector<16>( + llvm::make_filter_range(buf, isValidMLIRIdentifierChar)); + + // Construct the identifier string. + buf.clear(); + llvm::append_range(buf, StringRef("float")); + llvm::append_range(buf, numberStr); + setNameFn(getResult(), StringRef(buf.data(), buf.size())); } //===----------------------------------------------------------------------===// diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index a1074ab55..e718e84e4 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -84,10 +84,23 @@ func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int { %true = torch.constant.bool true // CHECK: %false = torch.constant.bool false %false = torch.constant.bool false + // CHECK: %int3 = torch.constant.int 3 %int3 = torch.constant.int 3 -// CHECK: %float = torch.constant.float 4.250000e+01 -%float = torch.constant.float 4.250000e+01 +// CHECK: %int-3 = torch.constant.int -3 +%int-3 = torch.constant.int -3 + +// CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00 +%float1.000000e00 = torch.constant.float 1.000000e+00 +// CHECK: %float-1.000000e00 = torch.constant.float -1.000000e+00 +%float-1.000000e00 = torch.constant.float -1.000000e+00 +// CHECK: %float1.000000e-10 = torch.constant.float 1.000000e-10 +%float1.000000e-10 = torch.constant.float 1.000000e-10 +// CHECK: %float1.000000e10 = torch.constant.float 1.000000e+10 +%float1.000000e10 = torch.constant.float 1.000000e+10 +// CHECK: %float4.250000e01 = torch.constant.float 4.250000e+01 +%float4.250000e01 = torch.constant.float 4.250000e+01 + %tensor = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor // CHECK: %none = torch.constant.none %none = torch.constant.none @@ -113,7 +126,7 @@ torch.class_type @test { torch.nn_module { torch.slot "b", %true : !torch.bool torch.slot "i", %int3 : !torch.int - torch.slot "f", %float : !torch.float + torch.slot "f", %float1.000000e00 : !torch.float torch.slot "t", %tensor : !torch.tensor torch.slot "submodule", %submodule : !torch.nn.Module<"empty"> torch.slot "ob", %none : !torch.none