mirror of https://github.com/llvm/torch-mlir
Make torch.constant.float print a little nicer.
This printing is chosen to be similar to how MLIR prints the values by default.pull/239/head
parent
60a947b4a7
commit
90c6c64fd6
|
@ -674,7 +674,24 @@ OpFoldResult Torch::ConstantFloatOp::fold(ArrayRef<Attribute> operands) {
|
|||
|
||||
void Torch::ConstantFloatOp::getAsmResultNames(
|
||||
function_ref<void(Value, StringRef)> setNameFn) {
|
||||
setNameFn(getResult(), "float");
|
||||
// Calculate a stringified version of the number, compatible with MLIR
|
||||
// identifier syntax. (in practice, this just removes the '+' from 'e+' in
|
||||
// float string representation).
|
||||
SmallVector<char> buf;
|
||||
value().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
|
||||
/*TruncateZero=*/false);
|
||||
auto isValidMLIRIdentifierChar = [](char c) {
|
||||
return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' ||
|
||||
c == '-';
|
||||
};
|
||||
auto numberStr = llvm::to_vector<16>(
|
||||
llvm::make_filter_range(buf, isValidMLIRIdentifierChar));
|
||||
|
||||
// Construct the identifier string.
|
||||
buf.clear();
|
||||
llvm::append_range(buf, StringRef("float"));
|
||||
llvm::append_range(buf, numberStr);
|
||||
setNameFn(getResult(), StringRef(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -84,10 +84,23 @@ func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
|
|||
%true = torch.constant.bool true
|
||||
// CHECK: %false = torch.constant.bool false
|
||||
%false = torch.constant.bool false
|
||||
|
||||
// CHECK: %int3 = torch.constant.int 3
|
||||
%int3 = torch.constant.int 3
|
||||
// CHECK: %float = torch.constant.float 4.250000e+01
|
||||
%float = torch.constant.float 4.250000e+01
|
||||
// CHECK: %int-3 = torch.constant.int -3
|
||||
%int-3 = torch.constant.int -3
|
||||
|
||||
// CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00
|
||||
%float1.000000e00 = torch.constant.float 1.000000e+00
|
||||
// CHECK: %float-1.000000e00 = torch.constant.float -1.000000e+00
|
||||
%float-1.000000e00 = torch.constant.float -1.000000e+00
|
||||
// CHECK: %float1.000000e-10 = torch.constant.float 1.000000e-10
|
||||
%float1.000000e-10 = torch.constant.float 1.000000e-10
|
||||
// CHECK: %float1.000000e10 = torch.constant.float 1.000000e+10
|
||||
%float1.000000e10 = torch.constant.float 1.000000e+10
|
||||
// CHECK: %float4.250000e01 = torch.constant.float 4.250000e+01
|
||||
%float4.250000e01 = torch.constant.float 4.250000e+01
|
||||
|
||||
%tensor = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
// CHECK: %none = torch.constant.none
|
||||
%none = torch.constant.none
|
||||
|
@ -113,7 +126,7 @@ torch.class_type @test {
|
|||
torch.nn_module {
|
||||
torch.slot "b", %true : !torch.bool
|
||||
torch.slot "i", %int3 : !torch.int
|
||||
torch.slot "f", %float : !torch.float
|
||||
torch.slot "f", %float1.000000e00 : !torch.float
|
||||
torch.slot "t", %tensor : !torch.tensor
|
||||
torch.slot "submodule", %submodule : !torch.nn.Module<"empty">
|
||||
torch.slot "ob", %none : !torch.none
|
||||
|
|
Loading…
Reference in New Issue