aboutsummaryrefslogtreecommitdiffstats
path: root/PyPI/jumpavg/jumpavg/BitCountingClassifier.py
blob: 9a723199d25a8a9e59c59bbeada66df5bec2775e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Copyright (c) 2018 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module holding BitCountingClassifier class.

This is the main class to be used by callers."""

from AbstractGroupClassifier import AbstractGroupClassifier
from BitCountingGroup import BitCountingGroup
from BitCountingGroupList import BitCountingGroupList
from BitCountingMetadataFactory import BitCountingMetadataFactory
from ClassifiedMetadataFactory import ClassifiedMetadataFactory


class BitCountingClassifier(AbstractGroupClassifier):
    """Classifier using Minimal Description Length principle."""

    def classify(self, values):
        """Return the values in groups of optimal bit count.

        The values are partitioned into consecutive groups so that
        the resulting group list has the lowest ``bits`` attribute
        among the candidates examined. Each group is then labelled
        by comparing its average with the average of the group before it:
        "normal" for the first group (or an equal average),
        "regression" for a lower average, "progression" for a higher one.

        The current implementation could be a static method,
        but we might support options in later versions,
        for example for choosing encodings.

        :param values: Sequence of runs to classify.
            One-shot iterables are accepted, they are copied into a list.
        :type values: Iterable of float or of AvgStdevMetadata
        :returns: Classified group list, empty if there are no values.
        :rtype: BitCountingGroupList
        """
        # The values are traversed twice (max search, then partitioning),
        # so a generator has to be materialized before the first pass.
        values = list(values)
        if not values:
            # Nothing to partition; avoid indexing into an empty group list.
            return BitCountingGroupList()
        max_value = BitCountingMetadataFactory.find_max_value(values)
        factory = BitCountingMetadataFactory(max_value)
        # opened_at[i]: group list whose last group starts at values[i]
        # and has been extended by every value seen since.
        opened_at = []
        # closed_before[i]: cheapest group list covering values[:i].
        closed_before = [BitCountingGroupList()]
        for index, value in enumerate(values):
            # Candidate where the current value starts a brand new group.
            singleton = BitCountingGroup(factory, [value])
            newly_opened = closed_before[index].with_group_appended(singleton)
            opened_at.append(newly_opened)
            record_group_list = newly_opened
            # Candidates where the current value extends an older last group.
            for previous in range(index):
                previous_opened_list = opened_at[previous]
                still_opened = (
                    previous_opened_list.with_value_added_to_last_group(value))
                opened_at[previous] = still_opened
                # Strict comparison: on a tie the earlier candidate is kept.
                if still_opened.bits < record_group_list.bits:
                    record_group_list = still_opened
            closed_before.append(record_group_list)
        partition = closed_before[-1]
        # The first group is compared with itself, so it ends up "normal".
        previous_average = partition[0].metadata.avg
        for group in partition:
            average = group.metadata.avg
            if average == previous_average:
                classification = "normal"
            elif average < previous_average:
                classification = "regression"
            elif average > previous_average:
                classification = "progression"
            else:
                # No comparison holds (e.g. NaN); leave metadata untouched.
                classification = None
            if classification is not None:
                group.metadata = ClassifiedMetadataFactory.with_classification(
                    group.metadata, classification)
            previous_average = group.metadata.avg
        return partition