diff options
Diffstat (limited to 'resources/libraries/python/MLRsearch/MultipleLossRatioSearch.py')
-rw-r--r-- | resources/libraries/python/MLRsearch/MultipleLossRatioSearch.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/resources/libraries/python/MLRsearch/MultipleLossRatioSearch.py b/resources/libraries/python/MLRsearch/MultipleLossRatioSearch.py index dd21444496..0e6c8cfa58 100644 --- a/resources/libraries/python/MLRsearch/MultipleLossRatioSearch.py +++ b/resources/libraries/python/MLRsearch/MultipleLossRatioSearch.py @@ -330,6 +330,7 @@ class MultipleLossRatioSearch: cur_lo1, cur_hi1, pre_lo, pre_hi, cur_lo2, cur_hi2 = bounds pre_lo_improves = self.improves(pre_lo, cur_lo1, cur_hi1) pre_hi_improves = self.improves(pre_hi, cur_lo1, cur_hi1) + # TODO: Detect also the other case for initial bisect, see below. if pre_lo_improves and pre_hi_improves: # We allowed larger width for previous phase # as single bisect here guarantees only one re-measurement. @@ -342,6 +343,10 @@ class MultipleLossRatioSearch: self.debug(f"Re-measuring lower bound for {ratio}, tr: {new_tr}") return new_tr if pre_hi_improves: + # This can also happen when we did not do initial bisect + # for this ratio yet, but the previous duration lower bound + # for this ratio got already re-measured as previous duration + # upper bound for previous ratio. new_tr = pre_hi.target_tr self.debug(f"Re-measuring upper bound for {ratio}, tr: {new_tr}") return new_tr @@ -397,7 +402,7 @@ class MultipleLossRatioSearch: If no second tightest (nor previous) upper bound is available, the behavior is governed by second_needed argument. - If true, return None, if false, start from width goal. + If true, return None. If false, start from width goal. This is useful, as if a bisect is possible, we want to give it a chance. @@ -414,6 +419,9 @@ class MultipleLossRatioSearch: """ state = self.state old_tr = cur_hi1.target_tr + if state.min_rate >= old_tr: + self.debug(u"Extend down hits min rate.") + return None next_bound = cur_hi2 if self.improves(pre_hi, cur_hi1, cur_hi2): next_bound = pre_hi @@ -427,9 +435,6 @@ class MultipleLossRatioSearch: old_tr, old_width, self.expansion_coefficient ) new_tr = max(new_tr, state.min_rate) - if new_tr >= old_tr: - self.debug(u"Extend down hits max rate.") - return None return new_tr def _extend_up(self, cur_lo1, cur_lo2, pre_lo): @@ -446,6 +451,9 @@ class MultipleLossRatioSearch: """ state = self.state old_tr = cur_lo1.target_tr + if state.max_rate <= old_tr: + self.debug(u"Extend up hits max rate.") + return None next_bound = cur_lo2 if self.improves(pre_lo, cur_lo2, cur_lo1): next_bound = pre_lo @@ -455,9 +463,6 @@ class MultipleLossRatioSearch: old_width = max(old_width, state.width_goal) new_tr = multiple_step_up(old_tr, old_width, self.expansion_coefficient) new_tr = min(new_tr, state.max_rate) - if new_tr <= old_tr: - self.debug(u"Extend up hits max rate.") - return None return new_tr def _bisect(self, lower_bound, upper_bound): |