summaryrefslogtreecommitdiffstats
path: root/scripts/external_libs/zmq/auth/base.py
blob: 9b4aaed79d25286f91d6790bb97a5f06f339428a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""Base implementation of 0MQ authentication."""

# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.

import logging

import zmq
from zmq.utils import z85
from zmq.utils.strtypes import bytes, unicode, b, u
from zmq.error import _check_version

from .certs import load_certificates


# Sentinel value for the ``location`` argument of ``Authenticator.configure_curve``:
# when passed, every CURVE client key is accepted without a certificate lookup.
CURVE_ALLOW_ANY = '*'
# ZAP version handled here: requests whose version frame differs are rejected,
# and this value is sent as the first frame of every reply.
VERSION = b'1.0'

class Authenticator(object):
    """Implementation of ZAP authentication for zmq connections.

    Note:
    - libzmq provides four levels of security: default NULL (which the Authenticator does
      not see), and authenticated NULL, PLAIN, and CURVE, which the Authenticator can see.
    - until you add policies, all incoming NULL connections are allowed
    (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.
    """

    def __init__(self, context=None, encoding='utf-8', log=None):
        """Create an authenticator with no policies configured.

        context: zmq context later used to create the ZAP socket; when falsy,
            ``zmq.Context.instance()`` is used.
        encoding: text encoding used to decode ZAP frames and to encode the
            user id sent in replies.
        log: logger to report to; when falsy, the ``'zmq.auth'`` logger is used.
        """
        # Runs before any state is set up (presumably raises when libzmq is
        # older than 4.0 and has no security support -- confirm in zmq.error).
        _check_version((4, 0), "security")

        if not context:
            context = zmq.Context.instance()
        if not log:
            log = logging.getLogger('zmq.auth')

        self.context = context
        self.log = log
        self.encoding = encoding

        # The ZAP socket only exists between start() and stop().
        self.zap_socket = None

        # Address filters; allow() and deny() keep at most one of them populated.
        self.whitelist, self.blacklist = set(), set()

        # CURVE: when True every client key is accepted.
        self.allow_any = False
        # PLAIN: domain -> {username: password}
        self.passwords = {}
        # CURVE: domain -> certificates loaded for that domain, keyed by public key
        self.certs = {}
    
    def start(self):
        """Open a REP socket on the ZAP inproc endpoint and keep it in ``self.zap_socket``."""
        zap = self.context.socket(zmq.REP)
        # Store the socket before configuring it so stop() can close it even
        # if one of the following steps fails.
        self.zap_socket = zap
        zap.linger = 1
        zap.bind("inproc://zeromq.zap.01")

    def stop(self):
        """Close the ZAP socket"""
        if self.zap_socket:
            self.zap_socket.close()
        self.zap_socket = None

    def allow(self, *addresses):
        """Allow (whitelist) IP address(es).
        
        Connections from addresses not in the whitelist will be rejected.
        
        - For NULL, all clients from this address will be accepted.
        - For PLAIN and CURVE, they will be allowed to continue with authentication.
        
        whitelist is mutually exclusive with blacklist.
        """
        if self.blacklist:
            raise ValueError("Only use a whitelist or a blacklist, not both")
        self.whitelist.update(addresses)

    def deny(self, *addresses):
        """Deny (blacklist) IP address(es).
        
        Addresses not in the blacklist will be allowed to continue with authentication.
        
        Blacklist is mutually exclusive with whitelist.
        """
        if self.whitelist:
            raise ValueError("Only use a whitelist or a blacklist, not both")
        self.blacklist.update(addresses)

    def configure_plain(self, domain='*', passwords=None):
        """Configure PLAIN authentication for a given domain.
        
        PLAIN authentication uses a plain-text password file.
        To cover all domains, use "*".
        You can modify the password file at any time; it is reloaded automatically.
        """
        if passwords:
            self.passwords[domain] = passwords

    def configure_curve(self, domain='*', location=None):
        """Set up CURVE authentication for a domain.

        domain: ZAP domain the certificates apply to.  ``'*'`` is the default
            domain, used for requests that carry an empty domain.
        location: either ``CURVE_ALLOW_ANY``, which accepts every client key
            without checking it, or the place ``load_certificates`` reads the
            public client certificates from.

        The allow-any switch is a single flag on the authenticator, not a
        per-domain setting: every call overwrites it.  A failure while
        loading certificates is logged and otherwise ignored, leaving any
        previously loaded certificates for ``domain`` in place.
        """
        if location == CURVE_ALLOW_ANY:
            self.allow_any = True
            return

        # Any other location switches back to certificate-based checking.
        self.allow_any = False
        try:
            loaded = load_certificates(location)
            self.certs[domain] = loaded
        except Exception as err:
            self.log.error("Failed to load CURVE certs from %s: %s", location, err)

    def handle_zap_message(self, msg):
        """Perform ZAP authentication"""
        if len(msg) < 6:
            self.log.error("Invalid ZAP message, not enough frames: %r", msg)
            if len(msg) < 2:
                self.log.error("Not enough information to reply")
            else:
                self._send_zap_reply(msg[1], b"400", b"Not enough frames")
            return
        
        version, request_id, domain, address, identity, mechanism = msg[:6]
        credentials = msg[6:]
        
        domain = u(domain, self.encoding, 'replace')
        address = u(address, self.encoding, 'replace')

        if (version != VERSION):
            self.log.error("Invalid ZAP version: %r", msg)
            self._send_zap_reply(request_id, b"400", b"Invalid version")
            return

        self.log.debug("version: %r, request_id: %r, domain: %r,"
                      " address: %r, identity: %r, mechanism: %r",
                      version, request_id, domain,
                      address, identity, mechanism,
        )


        # Is address is explicitly whitelisted or blacklisted?
        allowed = False
        denied = False
        reason = b"NO ACCESS"

        if self.whitelist:
            if address in self.whitelist:
                allowed = True
                self.log.debug("PASSED (whitelist) address=%s", address)
            else:
                denied = True
                reason = b"Address not in whitelist"
                self.log.debug("DENIED (not in whitelist) address=%s", address)

        elif self.blacklist:
            if address in self.blacklist:
                denied = True
                reason = b"Address is blacklisted"
                self.log.debug("DENIED (blacklist) address=%s", address)
            else:
                allowed = True
                self.log.debug("PASSED (not in blacklist) address=%s", address)

        # Perform authentication mechanism-specific checks if necessary
        username = u("user")
        if not denied:

            if mechanism == b'NULL' and not allowed:
                # For NULL, we allow if the address wasn't blacklisted
                self.log.debug("ALLOWED (NULL)")
                allowed = True

            elif mechanism == b'PLAIN':
                # For PLAIN, even a whitelisted address must authenticate
                if len(credentials) != 2:
                    self.log.error("Invalid PLAIN credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                username, password = [ u(c, self.encoding, 'replace') for c in credentials ]
                allowed, reason = self._authenticate_plain(domain, username, password)

            elif mechanism == b'CURVE':
                # For CURVE, even a whitelisted address must authenticate
                if len(credentials) != 1:
                    self.log.error("Invalid CURVE credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                key = credentials[0]
                allowed, reason = self._authenticate_curve(domain, key)

        if allowed:
            self._send_zap_reply(request_id, b"200", b"OK", username)
        else:
            self._send_zap_reply(request_id, b"400", reason)

    def _authenticate_plain(self, domain, username, password):
        """PLAIN ZAP authentication"""
        allowed = False
        reason = b""
        if self.passwords:
            # If no domain is not specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.passwords:
                if username in self.passwords[domain]:
                    if password == self.passwords[domain][username]:
                        allowed = True
                    else:
                        reason = b"Invalid password"
                else:
                    reason = b"Invalid username"
            else:
                reason = b"Invalid domain"

            if allowed:
                self.log.debug("ALLOWED (PLAIN) domain=%s username=%s password=%s",
                    domain, username, password,
                )
            else:
                self.log.debug("DENIED %s", reason)

        else:
            reason = b"No passwords defined"
            self.log.debug("DENIED (PLAIN) %s", reason)

        return allowed, reason

    def _authenticate_curve(self, domain, client_key):
        """CURVE ZAP authentication"""
        allowed = False
        reason = b""
        if self.allow_any:
            allowed = True
            reason = b"OK"
            self.log.debug("ALLOWED (CURVE allow any client)")
        else:
            # If no explicit domain is specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.certs:
                # The certs dict stores keys in z85 format, convert binary key to z85 bytes
                z85_client_key = z85.encode(client_key)
                if z85_client_key in self.certs[domain] or self.certs[domain] == b'OK':
                    allowed = True
                    reason = b"OK"
                else:
                    reason = b"Unknown key"

                status = "ALLOWED" if allowed else "DENIED"
                self.log.debug("%s (CURVE) domain=%s client_key=%s",
                    status, domain, z85_client_key,
                )
            else:
                reason = b"Unknown domain"

        return allowed, reason

    def _send_zap_reply(self, request_id, status_code, status_text, user_id='user'):
        """Send the ZAP reply that completes an authentication request.

        request_id: request id frame copied from the request.
        status_code, status_text: bytes status and description to report.
        user_id: identity reported to the peer.  It is only sent together
            with status ``b'200'`` (every other status carries an empty user
            id) and is encoded with ``self.encoding`` when given as text.
        """
        if status_code != b'200':
            user_id = b''
        elif isinstance(user_id, unicode):
            user_id = user_id.encode(self.encoding, 'replace')

        self.log.debug("ZAP reply code=%s text=%s", status_code, status_text)

        metadata = b''  # not currently used
        self.zap_socket.send_multipart(
            [VERSION, request_id, status_code, status_text, user_id, metadata]
        )

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ['Authenticator', 'CURVE_ALLOW_ANY']