# Source: mirror of https://github.com/llvm/torch-mlir (Python, 105 lines, 4.5 KiB)
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
# Also available under a BSD-style license. See LICENSE.
|
|
|
|
import torch
|
|
|
|
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
|
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
|
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|
|
|
|
|
# ==============================================================================
# Global parameters

# Largest valid (shifted) segment id; larger ids are reset to 0 in forward().
NUM_SEGMENTS = 42
# Number of histogram bins per segment over the prediction range [0, 1].
NUM_BINS = 5000
# Number of logits (examples) the model and the test case operate on.
NUM_LOGITS = 5000
|
|
|
|
class HistogramBinningCalibrationByFeature(torch.nn.Module):
    """Histogram-binning calibration of predictions, bucketed per segment.

    Each logit is turned into a prediction with a sigmoid, mapped to a bin
    index inside the group of bins belonging to its segment, and then
    blended with the positives/examples ratio stored for that bin when the
    bin has seen more than ``bin_ctr_in_use_after`` examples.
    """

    def __init__(self):
        super().__init__()
        self._num_segments = NUM_SEGMENTS
        self._num_bins = NUM_BINS
        self._num_logits = NUM_LOGITS
        # One group of `_num_bins` bins per segment id, plus one extra group
        # for dense segment value 0, which forward() uses for rows without a
        # valid segment (valid ids are shifted up by one there).
        _num_interval = (self._num_segments + 1) * self._num_bins
        lower_bound = 0
        upper_bound = 1
        # Width of a single bin over [lower_bound, upper_bound].
        bin_width = (upper_bound - lower_bound) / self._num_bins
        self.step = bin_width
        self.register_buffer(
            "_boundaries",
            torch.arange(lower_bound + bin_width,
                         upper_bound - bin_width / 2,
                         bin_width),
        )
        # Per-bin counters; both start out as all zeros.
        self.register_buffer(
            "_bin_num_examples",
            torch.empty([_num_interval], dtype=torch.float64).fill_(0.0),
        )
        self.register_buffer(
            "_bin_num_positives",
            torch.empty([_num_interval], dtype=torch.float64).fill_(0.0),
        )
        self.register_buffer("_bin_ids", torch.arange(_num_interval))
        self.positive_weight = torch.tensor([0.4])
        # A bin's statistics are only used once it holds more than this many
        # examples (see `isvalid` in forward()).
        self.bin_ctr_in_use_after = 0
        # Blend weights for the bin ratio and the uncalibrated prediction.
        self.bin_ctr_weight_value = 0.9995
        self.oneminusbin_ctr_weight_value = 0.0005
        self._iteration = 0

    # NOTE(review): the annotations below appear to declare three rank-1
    # arguments of unknown size (int32, int32, float32); the meaning of the
    # trailing `True` flag is defined by `annotate_args` -- confirm there.
    @export
    @annotate_args([
        None,
        ([-1], torch.int32, True),
        ([-1], torch.int32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, segment_value, segment_lengths, logit):
        """Calibrate ``logit`` and return the bin each element falls into.

        Args:
            segment_value: Integer tensor of segment ids, indexed through
                ``segment_lengths``.
            segment_lengths: Integer tensor with at least
                ``self._num_logits + 1`` entries. Consecutive entries are
                compared and the first ``self._num_logits`` entries are used
                as indices into ``segment_value``, i.e. it is consumed as
                offsets rather than lengths.
            logit: Float tensor of raw scores.

        Returns:
            A tuple ``(calibrated_prediction_data, bin_ids_data)``: the
            calibrated predictions and the flat (segment-adjusted) bin index
            used for each element.
        """
        origin_prediction = torch.sigmoid(
            logit + torch.log(self.positive_weight))

        # Dense per-row segment id: 0 means "no valid segment".
        dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32)
        # A row owns a segment value only if its offset range is non-empty.
        validoffsets = torch.gt(
            segment_lengths[1:self._num_logits + 1],
            segment_lengths[0:self._num_logits])
        # Shift ids by one so that 0 stays reserved for "no valid segment".
        gathered_segment_values = (
            segment_value[segment_lengths[0:self._num_logits].long()] + 1
        ).int()
        dense_segment_value = torch.where(
            validoffsets, gathered_segment_values, dense_segment_value)

        # Ids above the configured number of segments are reset to 0.
        zeros = torch.empty_like(
            dense_segment_value, dtype=torch.int32).fill_(0)
        isnotvalid = torch.gt(dense_segment_value, self._num_segments)
        dense_segment_value = torch.where(
            isnotvalid, zeros, dense_segment_value)

        # Bin index inside a segment's group, then offset by the group start.
        bin_ids_data = torch.ceil(origin_prediction / self.step) - 1
        bin_ids_data = bin_ids_data.long()
        curr_segment_value = dense_segment_value * self._num_bins
        bin_ids_data = bin_ids_data + curr_segment_value

        # Observed positive ratio of the selected bin. Empty bins yield
        # 0 / 0 here; those entries are replaced by the fallback below.
        curr_segment_value = self._bin_num_positives[bin_ids_data]
        curr_bin_num_examples = self._bin_num_examples[bin_ids_data]
        curr_segment_value = curr_segment_value / curr_bin_num_examples
        curr_segment_value = curr_segment_value.float()
        curr_segment_value = (
            curr_segment_value * self.bin_ctr_weight_value
            + origin_prediction * self.oneminusbin_ctr_weight_value
        )

        # Fall back to the uncalibrated prediction for bins without enough
        # examples. With the zero-initialised counters from __init__ this
        # fallback is taken for every element.
        isvalid = torch.gt(curr_bin_num_examples,
                           self.bin_ctr_in_use_after)
        calibrated_prediction_data = torch.where(
            isvalid, curr_segment_value, origin_prediction.float())
        return calibrated_prediction_data, bin_ids_data
|
|
|
|
|
|
@register_test_case(module_factory=lambda: HistogramBinningCalibrationByFeature())
def HBC_basic(module, tu: TestUtils):
    """Run the calibration module once on randomly generated inputs.

    Builds ``NUM_LOGITS`` random logits, random per-row segment lengths of
    0 or 1, the matching cumulative offsets (with a leading 0), and one
    random segment id per non-empty row, zero-padded to ``NUM_LOGITS``.
    """
    logits = torch.rand(NUM_LOGITS, dtype=torch.float)
    segment_lengths: torch.Tensor = torch.randint(
        0, 2, (NUM_LOGITS,), dtype=torch.int)
    # Offsets are the running sum of the lengths, prefixed with 0, so they
    # hold NUM_LOGITS + 1 entries.
    segment_offsets: torch.Tensor = torch.cat(
        (torch.tensor([0]), torch.cumsum(segment_lengths, 0)), 0)
    num_values: int = int(torch.sum(segment_lengths).item())
    segment_values: torch.Tensor = torch.randint(
        0,
        NUM_SEGMENTS,
        (num_values,),
    )
    # Pad with zeros of the same integer dtype so the concatenation does not
    # depend on int/float type promotion.
    padding = torch.zeros(
        NUM_LOGITS - segment_values.numel(), dtype=segment_values.dtype)
    segment_values = torch.cat((segment_values, padding), 0)
    # Input shapes: (5000,), (5001,), (5000,)
    module.forward(segment_values.int(), segment_offsets.int(), logits)