Make torch.constant.float print a little nicer.

This printing is chosen to be similar to how MLIR prints the values by
default.
pull/239/head
Sean Silva 2021-06-22 14:25:16 -07:00
parent 60a947b4a7
commit 90c6c64fd6
2 changed files with 34 additions and 4 deletions

View File

@ -674,7 +674,24 @@ OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef<Attribute> operands) {
void Torch::ConstantFloatOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "float");
// Calculate a stringified version of the number, compatible with MLIR
// identifier syntax. (in practice, this just removes the '+' from 'e+' in
// float string representation).
SmallVector<char> buf;
value().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
/*TruncateZero=*/false);
auto isValidMLIRIdentifierChar = [](char c) {
return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' ||
c == '-';
};
auto numberStr = llvm::to_vector<16>(
llvm::make_filter_range(buf, isValidMLIRIdentifierChar));
// Construct the identifier string.
buf.clear();
llvm::append_range(buf, StringRef("float"));
llvm::append_range(buf, numberStr);
setNameFn(getResult(), StringRef(buf.data(), buf.size()));
}
//===----------------------------------------------------------------------===//

View File

@ -84,10 +84,23 @@ func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
%true = torch.constant.bool true
// CHECK: %false = torch.constant.bool false
%false = torch.constant.bool false
// CHECK: %int3 = torch.constant.int 3
%int3 = torch.constant.int 3
// CHECK: %float = torch.constant.float 4.250000e+01
%float = torch.constant.float 4.250000e+01
// CHECK: %int-3 = torch.constant.int -3
%int-3 = torch.constant.int -3
// CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00
%float1.000000e00 = torch.constant.float 1.000000e+00
// CHECK: %float-1.000000e00 = torch.constant.float -1.000000e+00
%float-1.000000e00 = torch.constant.float -1.000000e+00
// CHECK: %float1.000000e-10 = torch.constant.float 1.000000e-10
%float1.000000e-10 = torch.constant.float 1.000000e-10
// CHECK: %float1.000000e10 = torch.constant.float 1.000000e+10
%float1.000000e10 = torch.constant.float 1.000000e+10
// CHECK: %float4.250000e01 = torch.constant.float 4.250000e+01
%float4.250000e01 = torch.constant.float 4.250000e+01
%tensor = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
// CHECK: %none = torch.constant.none
%none = torch.constant.none
@ -113,7 +126,7 @@ torch.class_type @test {
torch.nn_module {
torch.slot "b", %true : !torch.bool
torch.slot "i", %int3 : !torch.int
torch.slot "f", %float : !torch.float
torch.slot "f", %float1.000000e00 : !torch.float
torch.slot "t", %tensor : !torch.tensor
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
torch.slot "ob", %none : !torch.none