aboutsummaryrefslogtreecommitdiffstats
path: root/netmodel/model/result_value.py
blob: 1812d5c40498b1cf8d228d6c480d3da394ddedd1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2017 Cisco and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pprint
import time
import traceback

from netmodel.network.packet       import ErrorPacket
from netmodel.model.query          import Query as Record

# Result types: values stored in the "type" field of a ResultValue
# (SUCCESS and ERROR are also used as values of the "code" field).
SUCCESS, WARNING, ERROR = 0, 1, 2

# Result origins: used as the first element of the "origin" field.
CORE, GATEWAY = 0, 1

class ResultValue(dict):

    ALLOWED_FIELDS = set(["origin", "type", "code", "value", "description", 
            "traceback", "ts"])

    def __init__(self, *args, **kwargs):
        if args:
            if kwargs:
                raise Exception("Bad initialization for ResultValue")

            if len(args) == 1 and isinstance(args[0], dict):
                kwargs = args[0]

        given = set(kwargs.keys())
        cstr_success = set(["code", "origin", "value"]) <= given
        cstr_error   = set(["code", "type", "origin", "description"]) <= given
        assert given <= self.ALLOWED_FIELDS, \
                "Wrong fields in ResultValue constructor: %r" % \
                    (given - self.ALLOWED_FIELDS)
        assert cstr_success or cstr_error, \
            "Incomplete set of fields in ResultValue constructor: %r" % given

        dict.__init__(self, **kwargs)

        # Set missing fields to None
        for field in self.ALLOWED_FIELDS - given:
            self[field] = None
        if not "ts" in self:
            self["ts"] = time.time()

    def get_code(self):
        """
        Returns:
            The code transported in this ResultValue instance/
        """
        return self["code"]

    @classmethod
    def get(self, records, errors):
        num_errors = len(errors)

        if num_errors == 0:
            return ResultValue.success(records)
        elif records:
            return ResultValue.warning(records, errors)
        else:
            return ResultValue.errors(errors)

    @classmethod
    def success(self, result):
        return ResultValue(
            code        = SUCCESS,
            type        = SUCCESS,
            origin      = [CORE, 0],
            value       = result
        )

    @staticmethod
    def warning(result, errors):
        return ResultValue(
            code        = ERROR, 
            type        = WARNING,
            origin      = [CORE, 0],
            value       = result,
            description = errors
        )

    @staticmethod
    def error(description, code = ERROR):
        assert isinstance(description, str),\
            "Invalid description = %s (%s)" % (description, type(description))
        assert isinstance(code, int),\
            "Invalid code = %s (%s)" % (code, type(code))

        return ResultValue(
            type        = ERROR,
            code        = code,
            origin      = [CORE, 0],
            description = [ErrorPacket(type = ERROR, code = code, 
                message = description, traceback = None)]
        )

    @staticmethod
    def errors(errors):
        """
        Make a ResultValue corresponding to an error and
        gathering a set of ErrorPacket instances.
        Args:
            errors: A list of ErrorPacket instances.
        Returns:
            The corresponding ResultValue instance.
        """
        assert isinstance(errors, list),\
            "Invalid errors = %s (%s)" % (errors, type(errors))

        return ResultValue(
            type        = ERROR,
            code        = ERROR,
            origin      = [CORE, 0],
            description = errors
        )

    def is_warning(self):
        return self["type"] == WARNING

    def is_success(self):
        return self["type"] == SUCCESS and self["code"] == SUCCESS

    def get_all(self):
        """
        Retrieve the Records embedded in this ResultValue.
        Raises:
            RuntimeError: in case of failure.
        Returns:
            A Records instance.
        """
        if not self.is_success() and not self.is_warning():
            raise RuntimeError("Error executing query: %s" % \
                    (self["description"]))
        try:
            records = self["value"]
            if len(records) > 0 and not isinstance(records[0], Record):
                raise TypeError("Please put Record instances in ResultValue")
            return records
        except AttributeError as e:
            raise RuntimeError(e)

    def get_one(self):
        """
        Retrieve the only Record embeded in this ResultValue.
        Raises:
            RuntimeError: if there is 0 or more that 1 Record in
                this ResultValue.
        Returns:
            A list of Records (and not of dict).
        """
        records = self.get_all()
        num_records = len(records)
        if num_records != 1:
            raise RuntimeError('Cannot call get_one() with multiple records')
        return records.get_one()

    def get_error_message(self):
        return "%r" % self["description"]

    @staticmethod
    def to_html(raw_dict):
        return pprint.pformat(raw_dict).replace("\\n","<br/>")

    def to_dict(self):
        return dict(self)