diff options
author | imarom <imarom@cisco.com> | 2015-08-18 10:26:51 +0300 |
---|---|---|
committer | imarom <imarom@cisco.com> | 2015-08-18 10:26:51 +0300 |
commit | cbc645cba025f2098031350fc1323e6ffff33633 (patch) | |
tree | db867b2eb09d5b38d47e28a9876bc21f51f8cc57 /src | |
parent | 396b54ea57308890c29c1d9746f0cce32d990cf8 (diff) |
new files for Python console
Diffstat (limited to 'src')
125 files changed, 13537 insertions, 0 deletions
diff --git a/src/console/trex_console.py b/src/console/trex_console.py new file mode 100755 index 00000000..404a2ee0 --- /dev/null +++ b/src/console/trex_console.py @@ -0,0 +1,80 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +import cmd +from trex_rpc_client import RpcClient + +class TrexConsole(cmd.Cmd): + """Trex Console""" + + def __init__(self, rpc_client): + cmd.Cmd.__init__(self) + self.prompt = "TRex > " + self.intro = "\n-=TRex Console V1.0=-\n" + self.rpc_client = rpc_client + self.verbose = False + + # a cool hack - i stole this function and added space + def completenames(self, text, *ignored): + dotext = 'do_'+text + return [a[3:]+' ' for a in self.get_names() if a.startswith(dotext)] + + def do_verbose (self, line): + '''shows or set verbose mode\nusage: verbose [on/off]\n''' + if line == "": + print "verbose is " + ("on" if self.verbose else "off") + elif line == "on": + self.verbose = True + self.rpc_client.set_verbose(True) + print "verbose set to on\n" + + elif line == "off": + self.verbose = False + self.rpc_client.set_verbose(False) + print "verbose set to off\n" + else: + print "please specify 'on' or 'off'\n" + + + def do_query_server(self, line): + '''\nquery the RPC server for supported remote commands\n''' + rc, msg = self.rpc_client.query_rpc_server() + if not rc: + print "\n*** Failed to query RPC server: " + str(msg) + + print "RPC server supports the following commands: \n\n" + msg + + def do_ping (self, line): + '''pings the RPC server \n''' + print "Pinging RPC server" + + rc, msg = self.rpc_client.ping_rpc_server() + if rc: + print "[SUCCESS]\n" + else: + print "[FAILED]\n" + + def do_quit(self, line): + '''\nexit the client\n''' + return True + + def default(self, line): + print "'{0}' is an unrecognized command\n".format(line) + + # aliasing + do_EOF = do_q = do_quit + +def main (): + # RPC client + try: + rpc_client = RpcClient(5050) + rpc_client.connect() + except Exception as e: + print "\n*** " + str(e) + "\n" + exit(-1) + + # console + console = TrexConsole(rpc_client) + console.cmdloop() + +if __name__ == '__main__': + main() diff --git a/src/console/trex_rpc_client.py b/src/console/trex_rpc_client.py new file mode 100644 index 00000000..e6ff5fee --- /dev/null +++ b/src/console/trex_rpc_client.py @@ -0,0 +1,101 @@ +
+import zmq
+import json
+from time import sleep
+
+class RpcClient():
+
+ def create_jsonrpc_v2 (self, method_name, params = {}, id = None):
+ msg = {}
+ msg["jsonrpc"] = "2.0"
+ msg["method"] = method_name
+
+ msg["params"] = {}
+ for key, value in params.iteritems():
+ msg["params"][key] = value
+
+ msg["id"] = id
+
+ return json.dumps(msg)
+
+ def invoke_rpc_method (self, method_name, params = {}, block = True):
+ msg = self.create_jsonrpc_v2(method_name, params, id = 1)
+
+ if self.verbose:
+ print "\nSending Request To Server: " + str(msg) + "\n"
+
+ self.socket.send(msg)
+
+ got_response = False
+
+ if block:
+ response = self.socket.recv()
+ got_response = True
+ else:
+ for i in xrange(0 ,5):
+ try:
+ response = self.socket.recv(flags = zmq.NOBLOCK)
+ got_response = True
+ break
+ except zmq.error.Again:
+ sleep(0.1)
+
+ if not got_response:
+ return False, "Failed To Get Server Response"
+
+ if self.verbose:
+ print "Server Response: " + str(response) + "\n"
+
+ # decode
+ response_json = json.loads(response)
+
+ if (response_json.get("jsonrpc") != "2.0"):
+ return False, "Bad Server Response: " + str(response)
+
+ if ("error" in response_json):
+ return False, "Server Has Reported An Error: " + str(response)
+
+ if ("result" not in response_json):
+ return False, "Bad Server Response: " + str(response)
+
+ return True, response_json["result"]
+
+
+ def ping_rpc_server (self):
+
+ return self.invoke_rpc_method("rpc_ping", block = False)
+
+
+ def query_rpc_server (self):
+ return self.invoke_rpc_method("rpc_get_reg_cmds")
+
+ def __init__ (self, port):
+ self.context = zmq.Context()
+
+ # Socket to talk to server
+ self.transport = "tcp://localhost:{0}".format(port)
+
+ self.verbose = False
+
+ def set_verbose (self, mode):
+ self.verbose = mode
+
+ def connect (self):
+
+ print "\nConnecting To RPC Server On {0}".format(self.transport)
+
+ self.socket = self.context.socket(zmq.REQ)
+ self.socket.connect(self.transport)
+
+ # ping the server
+ rc, err = self.ping_rpc_server()
+ if not rc:
+ self.context.destroy(linger = 0)
+ raise Exception(err)
+
+ #print "Connection Established !\n"
+ print "[SUCCESS]\n"
+
+
+
+
diff --git a/src/console/zmq/__init__.py b/src/console/zmq/__init__.py new file mode 100755 index 00000000..3408b3ba --- /dev/null +++ b/src/console/zmq/__init__.py @@ -0,0 +1,64 @@ +"""Python bindings for 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import os +import sys +import glob + +# load bundled libzmq, if there is one: + +here = os.path.dirname(__file__) + +bundled = [] +bundled_sodium = [] +for ext in ('pyd', 'so', 'dll', 'dylib'): + bundled_sodium.extend(glob.glob(os.path.join(here, 'libsodium*.%s*' % ext))) + bundled.extend(glob.glob(os.path.join(here, 'libzmq*.%s*' % ext))) + +if bundled: + import ctypes + if bundled_sodium: + if bundled[0].endswith('.pyd'): + # a Windows Extension + _libsodium = ctypes.cdll.LoadLibrary(bundled_sodium[0]) + else: + _libsodium = ctypes.CDLL(bundled_sodium[0], mode=ctypes.RTLD_GLOBAL) + if bundled[0].endswith('.pyd'): + # a Windows Extension + _libzmq = ctypes.cdll.LoadLibrary(bundled[0]) + else: + _libzmq = ctypes.CDLL(bundled[0], mode=ctypes.RTLD_GLOBAL) + del ctypes +else: + import zipimport + try: + if isinstance(__loader__, zipimport.zipimporter): + # a zipped pyzmq egg + from zmq import libzmq as _libzmq + except (NameError, ImportError): + pass + finally: + del zipimport + +del os, sys, glob, here, bundled, bundled_sodium, ext + +# zmq top-level imports + +from zmq import backend +from zmq.backend import * +from zmq import sugar +from zmq.sugar import * +from zmq import devices + +def get_includes(): + """Return a list of directories to include for linking against pyzmq with cython.""" + from os.path import join, dirname, abspath, pardir + base = dirname(__file__) + parent = abspath(join(base, pardir)) + return [ parent ] + [ join(parent, base, subdir) for subdir in ('utils',) ] + + +__all__ = ['get_includes'] + sugar.__all__ + backend.__all__ + diff --git a/src/console/zmq/auth/__init__.py b/src/console/zmq/auth/__init__.py new file mode 100755 index 00000000..11d3ad6b --- /dev/null +++ b/src/console/zmq/auth/__init__.py @@ -0,0 +1,10 @@ +"""Utilities for ZAP authentication. + +To run authentication in a background thread, see :mod:`zmq.auth.thread`. +For integration with the tornado eventloop, see :mod:`zmq.auth.ioloop`. + +.. versionadded:: 14.1 +""" + +from .base import * +from .certs import * diff --git a/src/console/zmq/auth/base.py b/src/console/zmq/auth/base.py new file mode 100755 index 00000000..9b4aaed7 --- /dev/null +++ b/src/console/zmq/auth/base.py @@ -0,0 +1,272 @@ +"""Base implementation of 0MQ authentication.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging + +import zmq +from zmq.utils import z85 +from zmq.utils.strtypes import bytes, unicode, b, u +from zmq.error import _check_version + +from .certs import load_certificates + + +CURVE_ALLOW_ANY = '*' +VERSION = b'1.0' + +class Authenticator(object): + """Implementation of ZAP authentication for zmq connections. + + Note: + - libzmq provides four levels of security: default NULL (which the Authenticator does + not see), and authenticated NULL, PLAIN, and CURVE, which the Authenticator can see. + - until you add policies, all incoming NULL connections are allowed + (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied. + """ + + def __init__(self, context=None, encoding='utf-8', log=None): + _check_version((4,0), "security") + self.context = context or zmq.Context.instance() + self.encoding = encoding + self.allow_any = False + self.zap_socket = None + self.whitelist = set() + self.blacklist = set() + # passwords is a dict keyed by domain and contains values + # of dicts with username:password pairs. + self.passwords = {} + # certs is dict keyed by domain and contains values + # of dicts keyed by the public keys from the specified location. + self.certs = {} + self.log = log or logging.getLogger('zmq.auth') + + def start(self): + """Create and bind the ZAP socket""" + self.zap_socket = self.context.socket(zmq.REP) + self.zap_socket.linger = 1 + self.zap_socket.bind("inproc://zeromq.zap.01") + + def stop(self): + """Close the ZAP socket""" + if self.zap_socket: + self.zap_socket.close() + self.zap_socket = None + + def allow(self, *addresses): + """Allow (whitelist) IP address(es). + + Connections from addresses not in the whitelist will be rejected. + + - For NULL, all clients from this address will be accepted. + - For PLAIN and CURVE, they will be allowed to continue with authentication. + + whitelist is mutually exclusive with blacklist. + """ + if self.blacklist: + raise ValueError("Only use a whitelist or a blacklist, not both") + self.whitelist.update(addresses) + + def deny(self, *addresses): + """Deny (blacklist) IP address(es). + + Addresses not in the blacklist will be allowed to continue with authentication. + + Blacklist is mutually exclusive with whitelist. + """ + if self.whitelist: + raise ValueError("Only use a whitelist or a blacklist, not both") + self.blacklist.update(addresses) + + def configure_plain(self, domain='*', passwords=None): + """Configure PLAIN authentication for a given domain. + + PLAIN authentication uses a plain-text password file. + To cover all domains, use "*". + You can modify the password file at any time; it is reloaded automatically. + """ + if passwords: + self.passwords[domain] = passwords + + def configure_curve(self, domain='*', location=None): + """Configure CURVE authentication for a given domain. + + CURVE authentication uses a directory that holds all public client certificates, + i.e. their public keys. + + To cover all domains, use "*". + + You can add and remove certificates in that directory at any time. + + To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location. + """ + # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise + # treat location as a directory that holds the certificates. + if location == CURVE_ALLOW_ANY: + self.allow_any = True + else: + self.allow_any = False + try: + self.certs[domain] = load_certificates(location) + except Exception as e: + self.log.error("Failed to load CURVE certs from %s: %s", location, e) + + def handle_zap_message(self, msg): + """Perform ZAP authentication""" + if len(msg) < 6: + self.log.error("Invalid ZAP message, not enough frames: %r", msg) + if len(msg) < 2: + self.log.error("Not enough information to reply") + else: + self._send_zap_reply(msg[1], b"400", b"Not enough frames") + return + + version, request_id, domain, address, identity, mechanism = msg[:6] + credentials = msg[6:] + + domain = u(domain, self.encoding, 'replace') + address = u(address, self.encoding, 'replace') + + if (version != VERSION): + self.log.error("Invalid ZAP version: %r", msg) + self._send_zap_reply(request_id, b"400", b"Invalid version") + return + + self.log.debug("version: %r, request_id: %r, domain: %r," + " address: %r, identity: %r, mechanism: %r", + version, request_id, domain, + address, identity, mechanism, + ) + + + # Is address is explicitly whitelisted or blacklisted? + allowed = False + denied = False + reason = b"NO ACCESS" + + if self.whitelist: + if address in self.whitelist: + allowed = True + self.log.debug("PASSED (whitelist) address=%s", address) + else: + denied = True + reason = b"Address not in whitelist" + self.log.debug("DENIED (not in whitelist) address=%s", address) + + elif self.blacklist: + if address in self.blacklist: + denied = True + reason = b"Address is blacklisted" + self.log.debug("DENIED (blacklist) address=%s", address) + else: + allowed = True + self.log.debug("PASSED (not in blacklist) address=%s", address) + + # Perform authentication mechanism-specific checks if necessary + username = u("user") + if not denied: + + if mechanism == b'NULL' and not allowed: + # For NULL, we allow if the address wasn't blacklisted + self.log.debug("ALLOWED (NULL)") + allowed = True + + elif mechanism == b'PLAIN': + # For PLAIN, even a whitelisted address must authenticate + if len(credentials) != 2: + self.log.error("Invalid PLAIN credentials: %r", credentials) + self._send_zap_reply(request_id, b"400", b"Invalid credentials") + return + username, password = [ u(c, self.encoding, 'replace') for c in credentials ] + allowed, reason = self._authenticate_plain(domain, username, password) + + elif mechanism == b'CURVE': + # For CURVE, even a whitelisted address must authenticate + if len(credentials) != 1: + self.log.error("Invalid CURVE credentials: %r", credentials) + self._send_zap_reply(request_id, b"400", b"Invalid credentials") + return + key = credentials[0] + allowed, reason = self._authenticate_curve(domain, key) + + if allowed: + self._send_zap_reply(request_id, b"200", b"OK", username) + else: + self._send_zap_reply(request_id, b"400", reason) + + def _authenticate_plain(self, domain, username, password): + """PLAIN ZAP authentication""" + allowed = False + reason = b"" + if self.passwords: + # If no domain is not specified then use the default domain + if not domain: + domain = '*' + + if domain in self.passwords: + if username in self.passwords[domain]: + if password == self.passwords[domain][username]: + allowed = True + else: + reason = b"Invalid password" + else: + reason = b"Invalid username" + else: + reason = b"Invalid domain" + + if allowed: + self.log.debug("ALLOWED (PLAIN) domain=%s username=%s password=%s", + domain, username, password, + ) + else: + self.log.debug("DENIED %s", reason) + + else: + reason = b"No passwords defined" + self.log.debug("DENIED (PLAIN) %s", reason) + + return allowed, reason + + def _authenticate_curve(self, domain, client_key): + """CURVE ZAP authentication""" + allowed = False + reason = b"" + if self.allow_any: + allowed = True + reason = b"OK" + self.log.debug("ALLOWED (CURVE allow any client)") + else: + # If no explicit domain is specified then use the default domain + if not domain: + domain = '*' + + if domain in self.certs: + # The certs dict stores keys in z85 format, convert binary key to z85 bytes + z85_client_key = z85.encode(client_key) + if z85_client_key in self.certs[domain] or self.certs[domain] == b'OK': + allowed = True + reason = b"OK" + else: + reason = b"Unknown key" + + status = "ALLOWED" if allowed else "DENIED" + self.log.debug("%s (CURVE) domain=%s client_key=%s", + status, domain, z85_client_key, + ) + else: + reason = b"Unknown domain" + + return allowed, reason + + def _send_zap_reply(self, request_id, status_code, status_text, user_id='user'): + """Send a ZAP reply to finish the authentication.""" + user_id = user_id if status_code == b'200' else b'' + if isinstance(user_id, unicode): + user_id = user_id.encode(self.encoding, 'replace') + metadata = b'' # not currently used + self.log.debug("ZAP reply code=%s text=%s", status_code, status_text) + reply = [VERSION, request_id, status_code, status_text, user_id, metadata] + self.zap_socket.send_multipart(reply) + +__all__ = ['Authenticator', 'CURVE_ALLOW_ANY'] diff --git a/src/console/zmq/auth/certs.py b/src/console/zmq/auth/certs.py new file mode 100755 index 00000000..4d26ad7b --- /dev/null +++ b/src/console/zmq/auth/certs.py @@ -0,0 +1,119 @@ +"""0MQ authentication related functions and classes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import datetime +import glob +import io +import os +import zmq +from zmq.utils.strtypes import bytes, unicode, b, u + + +_cert_secret_banner = u("""# **** Generated on {0} by pyzmq **** +# ZeroMQ CURVE **Secret** Certificate +# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions. + +""") + +_cert_public_banner = u("""# **** Generated on {0} by pyzmq **** +# ZeroMQ CURVE Public Certificate +# Exchange securely, or use a secure mechanism to verify the contents +# of this file after exchange. Store public certificates in your home +# directory, in the .curve subdirectory. + +""") + +def _write_key_file(key_filename, banner, public_key, secret_key=None, metadata=None, encoding='utf-8'): + """Create a certificate file""" + if isinstance(public_key, bytes): + public_key = public_key.decode(encoding) + if isinstance(secret_key, bytes): + secret_key = secret_key.decode(encoding) + with io.open(key_filename, 'w', encoding='utf8') as f: + f.write(banner.format(datetime.datetime.now())) + + f.write(u('metadata\n')) + if metadata: + for k, v in metadata.items(): + if isinstance(v, bytes): + v = v.decode(encoding) + f.write(u(" {0} = {1}\n").format(k, v)) + + f.write(u('curve\n')) + f.write(u(" public-key = \"{0}\"\n").format(public_key)) + + if secret_key: + f.write(u(" secret-key = \"{0}\"\n").format(secret_key)) + + +def create_certificates(key_dir, name, metadata=None): + """Create zmq certificates. + + Returns the file paths to the public and secret certificate files. + """ + public_key, secret_key = zmq.curve_keypair() + base_filename = os.path.join(key_dir, name) + secret_key_file = "{0}.key_secret".format(base_filename) + public_key_file = "{0}.key".format(base_filename) + now = datetime.datetime.now() + + _write_key_file(public_key_file, + _cert_public_banner.format(now), + public_key) + + _write_key_file(secret_key_file, + _cert_secret_banner.format(now), + public_key, + secret_key=secret_key, + metadata=metadata) + + return public_key_file, secret_key_file + + +def load_certificate(filename): + """Load public and secret key from a zmq certificate. + + Returns (public_key, secret_key) + + If the certificate file only contains the public key, + secret_key will be None. + """ + public_key = None + secret_key = None + if not os.path.exists(filename): + raise IOError("Invalid certificate file: {0}".format(filename)) + + with open(filename, 'rb') as f: + for line in f: + line = line.strip() + if line.startswith(b'#'): + continue + if line.startswith(b'public-key'): + public_key = line.split(b"=", 1)[1].strip(b' \t\'"') + if line.startswith(b'secret-key'): + secret_key = line.split(b"=", 1)[1].strip(b' \t\'"') + if public_key and secret_key: + break + + return public_key, secret_key + + +def load_certificates(directory='.'): + """Load public keys from all certificates in a directory""" + certs = {} + if not os.path.isdir(directory): + raise IOError("Invalid certificate directory: {0}".format(directory)) + # Follow czmq pattern of public keys stored in *.key files. + glob_string = os.path.join(directory, "*.key") + + cert_files = glob.glob(glob_string) + for cert_file in cert_files: + public_key, _ = load_certificate(cert_file) + if public_key: + certs[public_key] = 'OK' + return certs + +__all__ = ['create_certificates', 'load_certificate', 'load_certificates'] diff --git a/src/console/zmq/auth/ioloop.py b/src/console/zmq/auth/ioloop.py new file mode 100755 index 00000000..1f448b47 --- /dev/null +++ b/src/console/zmq/auth/ioloop.py @@ -0,0 +1,34 @@ +"""ZAP Authenticator integrated with the tornado IOLoop. + +.. versionadded:: 14.1 +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.eventloop import ioloop, zmqstream +from .base import Authenticator + + +class IOLoopAuthenticator(Authenticator): + """ZAP authentication for use in the tornado IOLoop""" + + def __init__(self, context=None, encoding='utf-8', log=None, io_loop=None): + super(IOLoopAuthenticator, self).__init__(context) + self.zap_stream = None + self.io_loop = io_loop or ioloop.IOLoop.instance() + + def start(self): + """Start ZAP authentication""" + super(IOLoopAuthenticator, self).start() + self.zap_stream = zmqstream.ZMQStream(self.zap_socket, self.io_loop) + self.zap_stream.on_recv(self.handle_zap_message) + + def stop(self): + """Stop ZAP authentication""" + if self.zap_stream: + self.zap_stream.close() + self.zap_stream = None + super(IOLoopAuthenticator, self).stop() + +__all__ = ['IOLoopAuthenticator'] diff --git a/src/console/zmq/auth/thread.py b/src/console/zmq/auth/thread.py new file mode 100755 index 00000000..8c3355a9 --- /dev/null +++ b/src/console/zmq/auth/thread.py @@ -0,0 +1,184 @@ +"""ZAP Authenticator in a Python Thread. + +.. versionadded:: 14.1 +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging +from threading import Thread + +import zmq +from zmq.utils import jsonapi +from zmq.utils.strtypes import bytes, unicode, b, u + +from .base import Authenticator + +class AuthenticationThread(Thread): + """A Thread for running a zmq Authenticator + + This is run in the background by ThreadedAuthenticator + """ + + def __init__(self, context, endpoint, encoding='utf-8', log=None): + super(AuthenticationThread, self).__init__() + self.context = context or zmq.Context.instance() + self.encoding = encoding + self.log = log = log or logging.getLogger('zmq.auth') + self.authenticator = Authenticator(context, encoding=encoding, log=log) + + # create a socket to communicate back to main thread. + self.pipe = context.socket(zmq.PAIR) + self.pipe.linger = 1 + self.pipe.connect(endpoint) + + def run(self): + """ Start the Authentication Agent thread task """ + self.authenticator.start() + zap = self.authenticator.zap_socket + poller = zmq.Poller() + poller.register(self.pipe, zmq.POLLIN) + poller.register(zap, zmq.POLLIN) + while True: + try: + socks = dict(poller.poll()) + except zmq.ZMQError: + break # interrupted + + if self.pipe in socks and socks[self.pipe] == zmq.POLLIN: + terminate = self._handle_pipe() + if terminate: + break + + if zap in socks and socks[zap] == zmq.POLLIN: + self._handle_zap() + + self.pipe.close() + self.authenticator.stop() + + def _handle_zap(self): + """ + Handle a message from the ZAP socket. + """ + msg = self.authenticator.zap_socket.recv_multipart() + if not msg: return + self.authenticator.handle_zap_message(msg) + + def _handle_pipe(self): + """ + Handle a message from front-end API. + """ + terminate = False + + # Get the whole message off the pipe in one go + msg = self.pipe.recv_multipart() + + if msg is None: + terminate = True + return terminate + + command = msg[0] + self.log.debug("auth received API command %r", command) + + if command == b'ALLOW': + addresses = [u(m, self.encoding) for m in msg[1:]] + try: + self.authenticator.allow(*addresses) + except Exception as e: + self.log.exception("Failed to allow %s", addresses) + + elif command == b'DENY': + addresses = [u(m, self.encoding) for m in msg[1:]] + try: + self.authenticator.deny(*addresses) + except Exception as e: + self.log.exception("Failed to deny %s", addresses) + + elif command == b'PLAIN': + domain = u(msg[1], self.encoding) + json_passwords = msg[2] + self.authenticator.configure_plain(domain, jsonapi.loads(json_passwords)) + + elif command == b'CURVE': + # For now we don't do anything with domains + domain = u(msg[1], self.encoding) + + # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise + # treat location as a directory that holds the certificates. + location = u(msg[2], self.encoding) + self.authenticator.configure_curve(domain, location) + + elif command == b'TERMINATE': + terminate = True + + else: + self.log.error("Invalid auth command from API: %r", command) + + return terminate + +def _inherit_docstrings(cls): + """inherit docstrings from Authenticator, so we don't duplicate them""" + for name, method in cls.__dict__.items(): + if name.startswith('_'): + continue + upstream_method = getattr(Authenticator, name, None) + if not method.__doc__: + method.__doc__ = upstream_method.__doc__ + return cls + +@_inherit_docstrings +class ThreadAuthenticator(object): + """Run ZAP authentication in a background thread""" + + def __init__(self, context=None, encoding='utf-8', log=None): + self.context = context or zmq.Context.instance() + self.log = log + self.encoding = encoding + self.pipe = None + self.pipe_endpoint = "inproc://{0}.inproc".format(id(self)) + self.thread = None + + def allow(self, *addresses): + self.pipe.send_multipart([b'ALLOW'] + [b(a, self.encoding) for a in addresses]) + + def deny(self, *addresses): + self.pipe.send_multipart([b'DENY'] + [b(a, self.encoding) for a in addresses]) + + def configure_plain(self, domain='*', passwords=None): + self.pipe.send_multipart([b'PLAIN', b(domain, self.encoding), jsonapi.dumps(passwords or {})]) + + def configure_curve(self, domain='*', location=''): + domain = b(domain, self.encoding) + location = b(location, self.encoding) + self.pipe.send_multipart([b'CURVE', domain, location]) + + def start(self): + """Start the authentication thread""" + # create a socket to communicate with auth thread. + self.pipe = self.context.socket(zmq.PAIR) + self.pipe.linger = 1 + self.pipe.bind(self.pipe_endpoint) + self.thread = AuthenticationThread(self.context, self.pipe_endpoint, encoding=self.encoding, log=self.log) + self.thread.start() + + def stop(self): + """Stop the authentication thread""" + if self.pipe: + self.pipe.send(b'TERMINATE') + if self.is_alive(): + self.thread.join() + self.thread = None + self.pipe.close() + self.pipe = None + + def is_alive(self): + """Is the ZAP thread currently running?""" + if self.thread and self.thread.is_alive(): + return True + return False + + def __del__(self): + self.stop() + +__all__ = ['ThreadAuthenticator'] diff --git a/src/console/zmq/backend/__init__.py b/src/console/zmq/backend/__init__.py new file mode 100755 index 00000000..7cac725c --- /dev/null +++ b/src/console/zmq/backend/__init__.py @@ -0,0 +1,45 @@ +"""Import basic exposure of libzmq C API as a backend""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import os +import platform +import sys + +from zmq.utils.sixcerpt import reraise + +from .select import public_api, select_backend + +if 'PYZMQ_BACKEND' in os.environ: + backend = os.environ['PYZMQ_BACKEND'] + if backend in ('cython', 'cffi'): + backend = 'zmq.backend.%s' % backend + _ns = select_backend(backend) +else: + # default to cython, fallback to cffi + # (reverse on PyPy) + if platform.python_implementation() == 'PyPy': + first, second = ('zmq.backend.cffi', 'zmq.backend.cython') + else: + first, second = ('zmq.backend.cython', 'zmq.backend.cffi') + + try: + _ns = select_backend(first) + except Exception: + exc_info = sys.exc_info() + exc = exc_info[1] + try: + _ns = select_backend(second) + except ImportError: + # prevent 'During handling of the above exception...' on py3 + # can't use `raise ... from` on Python 2 + if hasattr(exc, '__cause__'): + exc.__cause__ = None + # raise the *first* error, not the fallback + reraise(*exc_info) + +globals().update(_ns) + +__all__ = public_api diff --git a/src/console/zmq/backend/cffi/__init__.py b/src/console/zmq/backend/cffi/__init__.py new file mode 100755 index 00000000..ca3164d3 --- /dev/null +++ b/src/console/zmq/backend/cffi/__init__.py @@ -0,0 +1,22 @@ +"""CFFI backend (for PyPY)""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.backend.cffi import (constants, error, message, context, socket, + _poll, devices, utils) + +__all__ = [] +for submod in (constants, error, message, context, socket, + _poll, devices, utils): + __all__.extend(submod.__all__) + +from .constants import * +from .error import * +from .message import * +from .context import * +from .socket import * +from .devices import * +from ._poll import * +from ._cffi import zmq_version_info, ffi +from .utils import * diff --git a/src/console/zmq/backend/cffi/_cdefs.h b/src/console/zmq/backend/cffi/_cdefs.h new file mode 100644 index 00000000..d3300575 --- /dev/null +++ b/src/console/zmq/backend/cffi/_cdefs.h @@ -0,0 +1,68 @@ +void zmq_version(int *major, int *minor, int *patch); + +void* zmq_socket(void *context, int type); +int zmq_close(void *socket); + +int zmq_bind(void *socket, const char *endpoint); +int zmq_connect(void *socket, const char *endpoint); + +int zmq_errno(void); +const char * zmq_strerror(int errnum); + +void* zmq_stopwatch_start(void); +unsigned long zmq_stopwatch_stop(void *watch); +void zmq_sleep(int seconds_); +int zmq_device(int device, void *frontend, void *backend); + +int zmq_unbind(void *socket, const char *endpoint); +int zmq_disconnect(void *socket, const char *endpoint); +void* zmq_ctx_new(); +int zmq_ctx_destroy(void *context); +int zmq_ctx_get(void *context, int opt); +int zmq_ctx_set(void *context, int opt, int optval); +int zmq_proxy(void *frontend, void *backend, void *capture); +int zmq_socket_monitor(void *socket, const char *addr, int events); + +int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key); +int zmq_has (const char *capability); + +typedef struct { ...; } zmq_msg_t; +typedef ... zmq_free_fn; + +int zmq_msg_init(zmq_msg_t *msg); +int zmq_msg_init_size(zmq_msg_t *msg, size_t size); +int zmq_msg_init_data(zmq_msg_t *msg, + void *data, + size_t size, + zmq_free_fn *ffn, + void *hint); + +size_t zmq_msg_size(zmq_msg_t *msg); +void *zmq_msg_data(zmq_msg_t *msg); +int zmq_msg_close(zmq_msg_t *msg); + +int zmq_msg_send(zmq_msg_t *msg, void *socket, int flags); +int zmq_msg_recv(zmq_msg_t *msg, void *socket, int flags); + +int zmq_getsockopt(void *socket, + int option_name, + void *option_value, + size_t *option_len); + +int zmq_setsockopt(void *socket, + int option_name, + const void *option_value, + size_t option_len); +typedef struct +{ + void *socket; + int fd; + short events; + short revents; +} zmq_pollitem_t; + +int zmq_poll(zmq_pollitem_t *items, int nitems, long timeout); + +// miscellany +void * memcpy(void *restrict s1, const void *restrict s2, size_t n); +int get_ipc_path_max_len(void); diff --git a/src/console/zmq/backend/cffi/_cffi.py b/src/console/zmq/backend/cffi/_cffi.py new file mode 100755 index 00000000..c73ebf83 --- /dev/null +++ b/src/console/zmq/backend/cffi/_cffi.py @@ -0,0 +1,127 @@ +# coding: utf-8 +"""The main CFFI wrapping of libzmq""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import json +import os +from os.path import dirname, join +from cffi import FFI + +from zmq.utils.constant_names import all_names, no_prefix + + +base_zmq_version = (3,2,2) + +def load_compiler_config(): + """load pyzmq compiler arguments""" + import zmq + zmq_dir = dirname(zmq.__file__) + zmq_parent = dirname(zmq_dir) + + fname = join(zmq_dir, 'utils', 'compiler.json') + if os.path.exists(fname): + with open(fname) as f: + cfg = json.load(f) + else: + cfg = {} + + cfg.setdefault("include_dirs", []) + cfg.setdefault("library_dirs", []) + cfg.setdefault("runtime_library_dirs", []) + cfg.setdefault("libraries", ["zmq"]) + + # cast to str, because cffi can't handle unicode paths (?!) + cfg['libraries'] = [str(lib) for lib in cfg['libraries']] + for key in ("include_dirs", "library_dirs", "runtime_library_dirs"): + # interpret paths relative to parent of zmq (like source tree) + abs_paths = [] + for p in cfg[key]: + if p.startswith('zmq'): + p = join(zmq_parent, p) + abs_paths.append(str(p)) + cfg[key] = abs_paths + return cfg + + +def zmq_version_info(): + """Get libzmq version as tuple of ints""" + major = ffi.new('int*') + minor = ffi.new('int*') + patch = ffi.new('int*') + + C.zmq_version(major, minor, patch) + + return (int(major[0]), int(minor[0]), int(patch[0])) + + +cfg = load_compiler_config() +ffi = FFI() + +def _make_defines(names): + _names = [] + for name in names: + define_line = "#define %s ..." % (name) + _names.append(define_line) + + return "\n".join(_names) + +c_constant_names = [] +for name in all_names: + if no_prefix(name): + c_constant_names.append(name) + else: + c_constant_names.append("ZMQ_" + name) + +# load ffi definitions +here = os.path.dirname(__file__) +with open(os.path.join(here, '_cdefs.h')) as f: + _cdefs = f.read() + +with open(os.path.join(here, '_verify.c')) as f: + _verify = f.read() + +ffi.cdef(_cdefs) +ffi.cdef(_make_defines(c_constant_names)) + +try: + C = ffi.verify(_verify, + modulename='_cffi_ext', + libraries=cfg['libraries'], + include_dirs=cfg['include_dirs'], + library_dirs=cfg['library_dirs'], + runtime_library_dirs=cfg['runtime_library_dirs'], + ) + _version_info = zmq_version_info() +except Exception as e: + raise ImportError("PyZMQ CFFI backend couldn't find zeromq: %s\n" + "Please check that you have zeromq headers and libraries." % e) + +if _version_info < (3,2,2): + raise ImportError("PyZMQ CFFI backend requires zeromq >= 3.2.2," + " but found %i.%i.%i" % _version_info + ) + +nsp = new_sizet_pointer = lambda length: ffi.new('size_t*', length) + +new_uint64_pointer = lambda: (ffi.new('uint64_t*'), + nsp(ffi.sizeof('uint64_t'))) +new_int64_pointer = lambda: (ffi.new('int64_t*'), + nsp(ffi.sizeof('int64_t'))) +new_int_pointer = lambda: (ffi.new('int*'), + nsp(ffi.sizeof('int'))) +new_binary_data = lambda length: (ffi.new('char[%d]' % (length)), + nsp(ffi.sizeof('char') * length)) + +value_uint64_pointer = lambda val : (ffi.new('uint64_t*', val), + ffi.sizeof('uint64_t')) +value_int64_pointer = lambda val: (ffi.new('int64_t*', val), + ffi.sizeof('int64_t')) +value_int_pointer = lambda val: (ffi.new('int*', val), + ffi.sizeof('int')) +value_binary_data = lambda val, length: (ffi.new('char[%d]' % (length + 1), val), + ffi.sizeof('char') * length) + +IPC_PATH_MAX_LEN = C.get_ipc_path_max_len() diff --git a/src/console/zmq/backend/cffi/_poll.py b/src/console/zmq/backend/cffi/_poll.py new file mode 100755 index 00000000..9bca34ca --- /dev/null +++ b/src/console/zmq/backend/cffi/_poll.py @@ -0,0 +1,56 @@ +# coding: utf-8 +"""zmq poll function""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import C, ffi, zmq_version_info + +from .constants import * + +from zmq.error import _check_rc + + +def _make_zmq_pollitem(socket, flags): + zmq_socket = socket._zmq_socket + zmq_pollitem = ffi.new('zmq_pollitem_t*') + zmq_pollitem.socket = zmq_socket + zmq_pollitem.fd = 0 + zmq_pollitem.events = flags + zmq_pollitem.revents = 0 + return zmq_pollitem[0] + +def _make_zmq_pollitem_fromfd(socket_fd, flags): + zmq_pollitem = ffi.new('zmq_pollitem_t*') + zmq_pollitem.socket = ffi.NULL + zmq_pollitem.fd = socket_fd + zmq_pollitem.events = flags + zmq_pollitem.revents = 0 + return zmq_pollitem[0] + +def zmq_poll(sockets, timeout): + cffi_pollitem_list = [] + low_level_to_socket_obj = {} + for item in sockets: + if isinstance(item[0], int): + low_level_to_socket_obj[item[0]] = item + cffi_pollitem_list.append(_make_zmq_pollitem_fromfd(item[0], item[1])) + else: + low_level_to_socket_obj[item[0]._zmq_socket] = item + cffi_pollitem_list.append(_make_zmq_pollitem(item[0], item[1])) + items = ffi.new('zmq_pollitem_t[]', cffi_pollitem_list) + list_length = ffi.cast('int', len(cffi_pollitem_list)) + c_timeout = ffi.cast('long', timeout) + rc = C.zmq_poll(items, list_length, c_timeout) + _check_rc(rc) + result = [] + for index in range(len(items)): + if not items[index].socket == ffi.NULL: + if items[index].revents > 0: + result.append((low_level_to_socket_obj[items[index].socket][0], + items[index].revents)) + else: + result.append((items[index].fd, items[index].revents)) + return result + +__all__ = ['zmq_poll'] diff --git a/src/console/zmq/backend/cffi/_verify.c b/src/console/zmq/backend/cffi/_verify.c new file mode 100644 index 00000000..547840eb --- /dev/null +++ b/src/console/zmq/backend/cffi/_verify.c @@ -0,0 +1,12 @@ +#include <stdio.h> +#include <sys/un.h> +#include <string.h> + +#include <zmq.h> +#include <zmq_utils.h> +#include "zmq_compat.h" + +int get_ipc_path_max_len(void) { + struct sockaddr_un *dummy; + return sizeof(dummy->sun_path) - 1; +} diff --git a/src/console/zmq/backend/cffi/constants.py b/src/console/zmq/backend/cffi/constants.py new file mode 100755 index 00000000..ee293e74 --- /dev/null +++ b/src/console/zmq/backend/cffi/constants.py @@ -0,0 +1,15 @@ +# coding: utf-8 +"""zmq constants""" + +from ._cffi import C, c_constant_names +from zmq.utils.constant_names import all_names + +g = globals() +for cname in c_constant_names: + if cname.startswith("ZMQ_"): + name = cname[4:] + else: + name = cname + g[name] = getattr(C, cname) + +__all__ = all_names diff --git a/src/console/zmq/backend/cffi/context.py b/src/console/zmq/backend/cffi/context.py new file mode 100755 index 00000000..16a7b257 --- /dev/null +++ b/src/console/zmq/backend/cffi/context.py @@ -0,0 +1,100 @@ +# coding: utf-8 +"""zmq Context class""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import weakref + +from ._cffi import C, ffi + +from .socket import * +from .constants import * + +from zmq.error import ZMQError, _check_rc + +class Context(object): + _zmq_ctx = None + _iothreads = None + _closed = None + _sockets = None + _shadow = False + + def __init__(self, io_threads=1, shadow=None): + + if shadow: + self._zmq_ctx = ffi.cast("void *", shadow) + self._shadow = True + else: + self._shadow = False + if not io_threads >= 0: + raise ZMQError(EINVAL) + + self._zmq_ctx = C.zmq_ctx_new() + if self._zmq_ctx == ffi.NULL: + raise ZMQError(C.zmq_errno()) + if not shadow: + C.zmq_ctx_set(self._zmq_ctx, IO_THREADS, io_threads) + self._closed = False + self._sockets = set() + + @property + def underlying(self): + """The address of the underlying libzmq context""" + return int(ffi.cast('size_t', self._zmq_ctx)) + + @property + def closed(self): + return self._closed + + def _add_socket(self, socket): + ref = weakref.ref(socket) + self._sockets.add(ref) + return ref + + def _rm_socket(self, ref): + if ref in self._sockets: + self._sockets.remove(ref) + + def set(self, option, value): + """set a context option + + see zmq_ctx_set + """ + rc = C.zmq_ctx_set(self._zmq_ctx, option, value) + _check_rc(rc) + + def get(self, option): + """get context option + + see zmq_ctx_get + """ + rc = C.zmq_ctx_get(self._zmq_ctx, option) + _check_rc(rc) + return rc + + def term(self): + if self.closed: + return + + C.zmq_ctx_destroy(self._zmq_ctx) + + self._zmq_ctx = None + self._closed = True + + def destroy(self, linger=None): + if self.closed: + return + + sockets = self._sockets + self._sockets = set() + for s in sockets: + s = s() + if s and not s.closed: + if linger: + s.setsockopt(LINGER, linger) + s.close() + + self.term() + +__all__ = ['Context'] diff --git a/src/console/zmq/backend/cffi/devices.py b/src/console/zmq/backend/cffi/devices.py new file mode 100755 index 00000000..c7a514a8 --- /dev/null +++ b/src/console/zmq/backend/cffi/devices.py @@ -0,0 +1,24 @@ +# coding: utf-8 +"""zmq device functions""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import C, ffi, zmq_version_info +from .socket import Socket +from zmq.error import ZMQError, _check_rc + +def device(device_type, frontend, backend): + rc = C.zmq_proxy(frontend._zmq_socket, backend._zmq_socket, ffi.NULL) + _check_rc(rc) + +def proxy(frontend, backend, capture=None): + if isinstance(capture, Socket): + capture = capture._zmq_socket + else: + capture = ffi.NULL + + rc = C.zmq_proxy(frontend._zmq_socket, backend._zmq_socket, capture) + _check_rc(rc) + +__all__ = ['device', 'proxy'] diff --git a/src/console/zmq/backend/cffi/error.py b/src/console/zmq/backend/cffi/error.py new file mode 100755 index 00000000..3bb64de0 --- /dev/null +++ b/src/console/zmq/backend/cffi/error.py @@ -0,0 +1,13 @@ +"""zmq error functions""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import C, ffi + +def strerror(errno): + return ffi.string(C.zmq_strerror(errno)) + +zmq_errno = C.zmq_errno + +__all__ = ['strerror', 'zmq_errno'] diff --git a/src/console/zmq/backend/cffi/message.py b/src/console/zmq/backend/cffi/message.py new file mode 100755 index 00000000..c35decb6 --- /dev/null +++ b/src/console/zmq/backend/cffi/message.py @@ -0,0 +1,69 @@ +"""Dummy Frame object""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import ffi, C + +import zmq +from zmq.utils.strtypes import unicode + +try: + view = memoryview +except NameError: + view = buffer + +_content = lambda x: x.tobytes() if type(x) == memoryview else x + +class Frame(object): + _data = None + tracker = None + closed = False + more = False + buffer = None + + + def __init__(self, data, track=False): + try: + view(data) + except TypeError: + raise + + self._data = data + + if isinstance(data, unicode): + raise TypeError("Unicode objects not allowed. Only: str/bytes, " + + "buffer interfaces.") + + self.more = False + self.tracker = None + self.closed = False + if track: + self.tracker = zmq.MessageTracker() + + self.buffer = view(self.bytes) + + @property + def bytes(self): + data = _content(self._data) + return data + + def __len__(self): + return len(self.bytes) + + def __eq__(self, other): + return self.bytes == _content(other) + + def __str__(self): + if str is unicode: + return self.bytes.decode() + else: + return self.bytes + + @property + def done(self): + return True + +Message = Frame + +__all__ = ['Frame', 'Message'] diff --git a/src/console/zmq/backend/cffi/socket.py b/src/console/zmq/backend/cffi/socket.py new file mode 100755 index 00000000..3c427739 --- /dev/null +++ b/src/console/zmq/backend/cffi/socket.py @@ -0,0 +1,244 @@ +# coding: utf-8 +"""zmq Socket class""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import random +import codecs + +import errno as errno_mod + +from ._cffi import (C, ffi, new_uint64_pointer, new_int64_pointer, + new_int_pointer, new_binary_data, value_uint64_pointer, + value_int64_pointer, value_int_pointer, value_binary_data, + IPC_PATH_MAX_LEN) + +from .message import Frame +from .constants import * + +import zmq +from zmq.error import ZMQError, _check_rc, _check_version +from zmq.utils.strtypes import unicode + + +def new_pointer_from_opt(option, length=0): + from zmq.sugar.constants import ( + int64_sockopts, bytes_sockopts, + ) + if option in int64_sockopts: + return new_int64_pointer() + elif option in bytes_sockopts: + return new_binary_data(length) + else: + # default + return new_int_pointer() + +def value_from_opt_pointer(option, opt_pointer, length=0): + from zmq.sugar.constants import ( + int64_sockopts, bytes_sockopts, + ) + if option in int64_sockopts: + return int(opt_pointer[0]) + elif option in bytes_sockopts: + return ffi.buffer(opt_pointer, length)[:] + else: + return int(opt_pointer[0]) + +def initialize_opt_pointer(option, value, length=0): + from zmq.sugar.constants import ( + int64_sockopts, bytes_sockopts, + ) + if option in int64_sockopts: + return value_int64_pointer(value) + elif option in bytes_sockopts: + return value_binary_data(value, length) + else: + return value_int_pointer(value) + + +class Socket(object): + context = None + socket_type = None + _zmq_socket = None + _closed = None + _ref = None + _shadow = False + + def __init__(self, context=None, socket_type=None, shadow=None): + self.context = context + if shadow is not None: + self._zmq_socket = ffi.cast("void *", shadow) + self._shadow = True + else: + self._shadow = False + self._zmq_socket = C.zmq_socket(context._zmq_ctx, socket_type) + if self._zmq_socket == ffi.NULL: + raise ZMQError() + self._closed = False + if context: + self._ref = context._add_socket(self) + + @property + def underlying(self): + """The address of the underlying libzmq socket""" + return int(ffi.cast('size_t', self._zmq_socket)) + + @property + def closed(self): + return self._closed + + def close(self, linger=None): + rc = 0 + if not self._closed and hasattr(self, '_zmq_socket'): + if self._zmq_socket is not None: + rc = C.zmq_close(self._zmq_socket) + self._closed = True + if self.context: + self.context._rm_socket(self._ref) + return rc + + def bind(self, address): + if isinstance(address, unicode): + address = address.encode('utf8') + rc = C.zmq_bind(self._zmq_socket, address) + if rc < 0: + if IPC_PATH_MAX_LEN and C.zmq_errno() == errno_mod.ENAMETOOLONG: + # py3compat: address is bytes, but msg wants str + if str is unicode: + address = address.decode('utf-8', 'replace') + path = address.split('://', 1)[-1] + msg = ('ipc path "{0}" is longer than {1} ' + 'characters (sizeof(sockaddr_un.sun_path)).' + .format(path, IPC_PATH_MAX_LEN)) + raise ZMQError(C.zmq_errno(), msg=msg) + else: + _check_rc(rc) + + def unbind(self, address): + _check_version((3,2), "unbind") + if isinstance(address, unicode): + address = address.encode('utf8') + rc = C.zmq_unbind(self._zmq_socket, address) + _check_rc(rc) + + def connect(self, address): + if isinstance(address, unicode): + address = address.encode('utf8') + rc = C.zmq_connect(self._zmq_socket, address) + _check_rc(rc) + + def disconnect(self, address): + _check_version((3,2), "disconnect") + if isinstance(address, unicode): + address = address.encode('utf8') + rc = C.zmq_disconnect(self._zmq_socket, address) + _check_rc(rc) + + def set(self, option, value): + length = None + if isinstance(value, unicode): + raise TypeError("unicode not allowed, use bytes") + + if isinstance(value, bytes): + if option not in zmq.constants.bytes_sockopts: + raise TypeError("not a bytes sockopt: %s" % option) + length = len(value) + + c_data = initialize_opt_pointer(option, value, length) + + c_value_pointer = c_data[0] + c_sizet = c_data[1] + + rc = C.zmq_setsockopt(self._zmq_socket, + option, + ffi.cast('void*', c_value_pointer), + c_sizet) + _check_rc(rc) + + def get(self, option): + c_data = new_pointer_from_opt(option, length=255) + + c_value_pointer = c_data[0] + c_sizet_pointer = c_data[1] + + rc = C.zmq_getsockopt(self._zmq_socket, + option, + c_value_pointer, + c_sizet_pointer) + _check_rc(rc) + + sz = c_sizet_pointer[0] + v = value_from_opt_pointer(option, c_value_pointer, sz) + if option != zmq.IDENTITY and option in zmq.constants.bytes_sockopts and v.endswith(b'\0'): + v = v[:-1] + return v + + def send(self, message, flags=0, copy=False, track=False): + if isinstance(message, unicode): + raise TypeError("Message must be in bytes, not an unicode Object") + + if isinstance(message, Frame): + message = message.bytes + + zmq_msg = ffi.new('zmq_msg_t*') + c_message = ffi.new('char[]', message) + rc = C.zmq_msg_init_size(zmq_msg, len(message)) + C.memcpy(C.zmq_msg_data(zmq_msg), c_message, len(message)) + + rc = C.zmq_msg_send(zmq_msg, self._zmq_socket, flags) + C.zmq_msg_close(zmq_msg) + _check_rc(rc) + + if track: + return zmq.MessageTracker() + + def recv(self, flags=0, copy=True, track=False): + zmq_msg = ffi.new('zmq_msg_t*') + C.zmq_msg_init(zmq_msg) + + rc = C.zmq_msg_recv(zmq_msg, self._zmq_socket, flags) + + if rc < 0: + C.zmq_msg_close(zmq_msg) + _check_rc(rc) + + _buffer = ffi.buffer(C.zmq_msg_data(zmq_msg), C.zmq_msg_size(zmq_msg)) + value = _buffer[:] + C.zmq_msg_close(zmq_msg) + + frame = Frame(value, track=track) + frame.more = self.getsockopt(RCVMORE) + + if copy: + return frame.bytes + else: + return frame + + def monitor(self, addr, events=-1): + """s.monitor(addr, flags) + + Start publishing socket events on inproc. + See libzmq docs for zmq_monitor for details. + + Note: requires libzmq >= 3.2 + + Parameters + ---------- + addr : str + The inproc url used for monitoring. Passing None as + the addr will cause an existing socket monitor to be + deregistered. + events : int [default: zmq.EVENT_ALL] + The zmq event bitmask for which events will be sent to the monitor. + """ + + _check_version((3,2), "monitor") + if events < 0: + events = zmq.EVENT_ALL + if addr is None: + addr = ffi.NULL + rc = C.zmq_socket_monitor(self._zmq_socket, addr, events) + + +__all__ = ['Socket', 'IPC_PATH_MAX_LEN'] diff --git a/src/console/zmq/backend/cffi/utils.py b/src/console/zmq/backend/cffi/utils.py new file mode 100755 index 00000000..fde7827b --- /dev/null +++ b/src/console/zmq/backend/cffi/utils.py @@ -0,0 +1,62 @@ +# coding: utf-8 +"""miscellaneous zmq_utils wrapping""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import ffi, C + +from zmq.error import ZMQError, _check_rc, _check_version +from zmq.utils.strtypes import unicode + +def has(capability): + """Check for zmq capability by name (e.g. 'ipc', 'curve') + + .. versionadded:: libzmq-4.1 + .. versionadded:: 14.1 + """ + _check_version((4,1), 'zmq.has') + if isinstance(capability, unicode): + capability = capability.encode('utf8') + return bool(C.zmq_has(capability)) + +def curve_keypair(): + """generate a Z85 keypair for use with zmq.CURVE security + + Requires libzmq (≥ 4.0) to have been linked with libsodium. + + Returns + ------- + (public, secret) : two bytestrings + The public and private keypair as 40 byte z85-encoded bytestrings. + """ + _check_version((3,2), "monitor") + public = ffi.new('char[64]') + private = ffi.new('char[64]') + rc = C.zmq_curve_keypair(public, private) + _check_rc(rc) + return ffi.buffer(public)[:40], ffi.buffer(private)[:40] + + +class Stopwatch(object): + def __init__(self): + self.watch = ffi.NULL + + def start(self): + if self.watch == ffi.NULL: + self.watch = C.zmq_stopwatch_start() + else: + raise ZMQError('Stopwatch is already runing.') + + def stop(self): + if self.watch == ffi.NULL: + raise ZMQError('Must start the Stopwatch before calling stop.') + else: + time = C.zmq_stopwatch_stop(self.watch) + self.watch = ffi.NULL + return time + + def sleep(self, seconds): + C.zmq_sleep(seconds) + +__all__ = ['has', 'curve_keypair', 'Stopwatch'] diff --git a/src/console/zmq/backend/cython/__init__.py b/src/console/zmq/backend/cython/__init__.py new file mode 100755 index 00000000..e5358185 --- /dev/null +++ b/src/console/zmq/backend/cython/__init__.py @@ -0,0 +1,23 @@ +"""Python bindings for core 0MQ objects.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Lesser GNU Public License (LGPL). + +from . import (constants, error, message, context, + socket, utils, _poll, _version, _device ) + +__all__ = [] +for submod in (constants, error, message, context, + socket, utils, _poll, _version, _device): + __all__.extend(submod.__all__) + +from .constants import * +from .error import * +from .message import * +from .context import * +from .socket import * +from ._poll import * +from .utils import * +from ._device import * +from ._version import * + diff --git a/src/console/zmq/backend/cython/_device.py b/src/console/zmq/backend/cython/_device.py new file mode 100755 index 00000000..3368ca2c --- /dev/null +++ b/src/console/zmq/backend/cython/_device.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'_device.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/_poll.py b/src/console/zmq/backend/cython/_poll.py new file mode 100755 index 00000000..cb1d5d77 --- /dev/null +++ b/src/console/zmq/backend/cython/_poll.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'_poll.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/_version.py b/src/console/zmq/backend/cython/_version.py new file mode 100755 index 00000000..08262706 --- /dev/null +++ b/src/console/zmq/backend/cython/_version.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'_version.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/checkrc.pxd b/src/console/zmq/backend/cython/checkrc.pxd new file mode 100644 index 00000000..3bf69fc3 --- /dev/null +++ b/src/console/zmq/backend/cython/checkrc.pxd @@ -0,0 +1,23 @@ +from libc.errno cimport EINTR, EAGAIN +from cpython cimport PyErr_CheckSignals +from libzmq cimport zmq_errno, ZMQ_ETERM + +cdef inline int _check_rc(int rc) except -1: + """internal utility for checking zmq return condition + + and raising the appropriate Exception class + """ + cdef int errno = zmq_errno() + PyErr_CheckSignals() + if rc < 0: + if errno == EAGAIN: + from zmq.error import Again + raise Again(errno) + elif errno == ZMQ_ETERM: + from zmq.error import ContextTerminated + raise ContextTerminated(errno) + else: + from zmq.error import ZMQError + raise ZMQError(errno) + # return -1 + return 0 diff --git a/src/console/zmq/backend/cython/constants.py b/src/console/zmq/backend/cython/constants.py new file mode 100755 index 00000000..ea772ac0 --- /dev/null +++ b/src/console/zmq/backend/cython/constants.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'constants.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/context.pxd b/src/console/zmq/backend/cython/context.pxd new file mode 100644 index 00000000..9c9267a5 --- /dev/null +++ b/src/console/zmq/backend/cython/context.pxd @@ -0,0 +1,41 @@ +"""0MQ Context class declaration.""" + +# +# Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +cdef class Context: + + cdef object __weakref__ # enable weakref + cdef void *handle # The C handle for the underlying zmq object. + cdef bint _shadow # whether the Context is a shadow wrapper of another + cdef void **_sockets # A C-array containg socket handles + cdef size_t _n_sockets # the number of sockets + cdef size_t _max_sockets # the size of the _sockets array + cdef int _pid # the pid of the process which created me (for fork safety) + + cdef public bint closed # bool property for a closed context. + cdef inline int _term(self) + # helpers for events on _sockets in Socket.__cinit__()/close() + cdef inline void _add_socket(self, void* handle) + cdef inline void _remove_socket(self, void* handle) + diff --git a/src/console/zmq/backend/cython/context.py b/src/console/zmq/backend/cython/context.py new file mode 100755 index 00000000..19f8ec7c --- /dev/null +++ b/src/console/zmq/backend/cython/context.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'context.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/error.py b/src/console/zmq/backend/cython/error.py new file mode 100755 index 00000000..d3a4ea0e --- /dev/null +++ b/src/console/zmq/backend/cython/error.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'error.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/libzmq.pxd b/src/console/zmq/backend/cython/libzmq.pxd new file mode 100644 index 00000000..e42f6d6b --- /dev/null +++ b/src/console/zmq/backend/cython/libzmq.pxd @@ -0,0 +1,110 @@ +"""All the C imports for 0MQ""" + +# +# Copyright (c) 2010 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +#----------------------------------------------------------------------------- +# Import the C header files +#----------------------------------------------------------------------------- + +cdef extern from *: + ctypedef void* const_void_ptr "const void *" + ctypedef char* const_char_ptr "const char *" + +cdef extern from "zmq_compat.h": + ctypedef signed long long int64_t "pyzmq_int64_t" + +include "constant_enums.pxi" + +cdef extern from "zmq.h" nogil: + + void _zmq_version "zmq_version"(int *major, int *minor, int *patch) + + ctypedef int fd_t "ZMQ_FD_T" + + enum: errno + char *zmq_strerror (int errnum) + int zmq_errno() + + void *zmq_ctx_new () + int zmq_ctx_destroy (void *context) + int zmq_ctx_set (void *context, int option, int optval) + int zmq_ctx_get (void *context, int option) + void *zmq_init (int io_threads) + int zmq_term (void *context) + + # blackbox def for zmq_msg_t + ctypedef void * zmq_msg_t "zmq_msg_t" + + ctypedef void zmq_free_fn(void *data, void *hint) + + int zmq_msg_init (zmq_msg_t *msg) + int zmq_msg_init_size (zmq_msg_t *msg, size_t size) + int zmq_msg_init_data (zmq_msg_t *msg, void *data, + size_t size, zmq_free_fn *ffn, void *hint) + int zmq_msg_send (zmq_msg_t *msg, void *s, int flags) + int zmq_msg_recv (zmq_msg_t *msg, void *s, int flags) + int zmq_msg_close (zmq_msg_t *msg) + int zmq_msg_move (zmq_msg_t *dest, zmq_msg_t *src) + int zmq_msg_copy (zmq_msg_t *dest, zmq_msg_t *src) + void *zmq_msg_data (zmq_msg_t *msg) + size_t zmq_msg_size (zmq_msg_t *msg) + int zmq_msg_more (zmq_msg_t *msg) + int zmq_msg_get (zmq_msg_t *msg, int option) + int zmq_msg_set (zmq_msg_t *msg, int option, int optval) + const_char_ptr zmq_msg_gets (zmq_msg_t *msg, const_char_ptr property) + int zmq_has (const_char_ptr capability) + + void *zmq_socket (void *context, int type) + int zmq_close (void *s) + int zmq_setsockopt (void *s, int option, void *optval, size_t optvallen) + int zmq_getsockopt (void *s, int option, void *optval, size_t *optvallen) + int zmq_bind (void *s, char *addr) + int zmq_connect (void *s, char *addr) + int zmq_unbind (void *s, char *addr) + int zmq_disconnect (void *s, char *addr) + + int zmq_socket_monitor (void *s, char *addr, int flags) + + # send/recv + int zmq_sendbuf (void *s, const_void_ptr buf, size_t n, int flags) + int zmq_recvbuf (void *s, void *buf, size_t n, int flags) + + ctypedef struct zmq_pollitem_t: + void *socket + int fd + short events + short revents + + int zmq_poll (zmq_pollitem_t *items, int nitems, long timeout) + + int zmq_device (int device_, void *insocket_, void *outsocket_) + int zmq_proxy (void *frontend, void *backend, void *capture) + +cdef extern from "zmq_utils.h" nogil: + + void *zmq_stopwatch_start () + unsigned long zmq_stopwatch_stop (void *watch_) + void zmq_sleep (int seconds_) + int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key) + diff --git a/src/console/zmq/backend/cython/message.pxd b/src/console/zmq/backend/cython/message.pxd new file mode 100644 index 00000000..4781195f --- /dev/null +++ b/src/console/zmq/backend/cython/message.pxd @@ -0,0 +1,63 @@ +"""0MQ Message related class declarations.""" + +# +# Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from cpython cimport PyBytes_FromStringAndSize + +from libzmq cimport zmq_msg_t, zmq_msg_data, zmq_msg_size + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +cdef class MessageTracker(object): + + cdef set events # Message Event objects to track. + cdef set peers # Other Message or MessageTracker objects. + + +cdef class Frame: + + cdef zmq_msg_t zmq_msg + cdef object _data # The actual message data as a Python object. + cdef object _buffer # A Python Buffer/View of the message contents + cdef object _bytes # A bytes/str copy of the message. + cdef bint _failed_init # Flag to handle failed zmq_msg_init + cdef public object tracker_event # Event for use with zmq_free_fn. + cdef public object tracker # MessageTracker object. + cdef public bint more # whether RCVMORE was set + + cdef Frame fast_copy(self) # Create shallow copy of Message object. + cdef object _getbuffer(self) # Construct self._buffer. + + +cdef inline object copy_zmq_msg_bytes(zmq_msg_t *zmq_msg): + """ Copy the data from a zmq_msg_t """ + cdef char *data_c = NULL + cdef Py_ssize_t data_len_c + data_c = <char *>zmq_msg_data(zmq_msg) + data_len_c = zmq_msg_size(zmq_msg) + return PyBytes_FromStringAndSize(data_c, data_len_c) + + diff --git a/src/console/zmq/backend/cython/message.py b/src/console/zmq/backend/cython/message.py new file mode 100755 index 00000000..5e423b62 --- /dev/null +++ b/src/console/zmq/backend/cython/message.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'message.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/socket.pxd b/src/console/zmq/backend/cython/socket.pxd new file mode 100644 index 00000000..b8a331e2 --- /dev/null +++ b/src/console/zmq/backend/cython/socket.pxd @@ -0,0 +1,47 @@ +"""0MQ Socket class declaration.""" + +# +# Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from context cimport Context + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + + +cdef class Socket: + + cdef object __weakref__ # enable weakref + cdef void *handle # The C handle for the underlying zmq object. + cdef bint _shadow # whether the Socket is a shadow wrapper of another + # Hold on to a reference to the context to make sure it is not garbage + # collected until the socket it done with it. + cdef public Context context # The zmq Context object that owns this. + cdef public bint _closed # bool property for a closed socket. + cdef int _pid # the pid of the process which created me (for fork safety) + + # cpdef methods for direct-cython access: + cpdef object send(self, object data, int flags=*, copy=*, track=*) + cpdef object recv(self, int flags=*, copy=*, track=*) + diff --git a/src/console/zmq/backend/cython/socket.py b/src/console/zmq/backend/cython/socket.py new file mode 100755 index 00000000..faef8bee --- /dev/null +++ b/src/console/zmq/backend/cython/socket.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'socket.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/utils.pxd b/src/console/zmq/backend/cython/utils.pxd new file mode 100644 index 00000000..1d7117f1 --- /dev/null +++ b/src/console/zmq/backend/cython/utils.pxd @@ -0,0 +1,29 @@ +"""Wrap zmq_utils.h""" + +# +# Copyright (c) 2010 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + + +cdef class Stopwatch: + cdef void *watch # The C handle for the underlying zmq object + diff --git a/src/console/zmq/backend/cython/utils.py b/src/console/zmq/backend/cython/utils.py new file mode 100755 index 00000000..fe928300 --- /dev/null +++ b/src/console/zmq/backend/cython/utils.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'utils.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/select.py b/src/console/zmq/backend/select.py new file mode 100755 index 00000000..0a2e09a2 --- /dev/null +++ b/src/console/zmq/backend/select.py @@ -0,0 +1,39 @@ +"""Import basic exposure of libzmq C API as a backend""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +public_api = [ + 'Context', + 'Socket', + 'Frame', + 'Message', + 'Stopwatch', + 'device', + 'proxy', + 'zmq_poll', + 'strerror', + 'zmq_errno', + 'has', + 'curve_keypair', + 'constants', + 'zmq_version_info', + 'IPC_PATH_MAX_LEN', +] + +def select_backend(name): + """Select the pyzmq backend""" + try: + mod = __import__(name, fromlist=public_api) + except ImportError: + raise + except Exception as e: + import sys + from zmq.utils.sixcerpt import reraise + exc_info = sys.exc_info() + reraise(ImportError, ImportError("Importing %s failed with %s" % (name, e)), exc_info[2]) + + ns = {} + for key in public_api: + ns[key] = getattr(mod, key) + return ns diff --git a/src/console/zmq/devices/__init__.py b/src/console/zmq/devices/__init__.py new file mode 100755 index 00000000..23715963 --- /dev/null +++ b/src/console/zmq/devices/__init__.py @@ -0,0 +1,16 @@ +"""0MQ Device classes for running in background threads or processes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq import device +from zmq.devices import basedevice, proxydevice, monitoredqueue, monitoredqueuedevice + +from zmq.devices.basedevice import * +from zmq.devices.proxydevice import * +from zmq.devices.monitoredqueue import * +from zmq.devices.monitoredqueuedevice import * + +__all__ = ['device'] +for submod in (basedevice, proxydevice, monitoredqueue, monitoredqueuedevice): + __all__.extend(submod.__all__) diff --git a/src/console/zmq/devices/basedevice.py b/src/console/zmq/devices/basedevice.py new file mode 100755 index 00000000..7ba1b7ac --- /dev/null +++ b/src/console/zmq/devices/basedevice.py @@ -0,0 +1,229 @@ +"""Classes for running 0MQ Devices in the background.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import time +from threading import Thread +from multiprocessing import Process + +from zmq import device, QUEUE, Context, ETERM, ZMQError + + +class Device: + """A 0MQ Device to be run in the background. + + You do not pass Socket instances to this, but rather Socket types:: + + Device(device_type, in_socket_type, out_socket_type) + + For instance:: + + dev = Device(zmq.QUEUE, zmq.DEALER, zmq.ROUTER) + + Similar to zmq.device, but socket types instead of sockets themselves are + passed, and the sockets are created in the work thread, to avoid issues + with thread safety. As a result, additional bind_{in|out} and + connect_{in|out} methods and setsockopt_{in|out} allow users to specify + connections for the sockets. + + Parameters + ---------- + device_type : int + The 0MQ Device type + {in|out}_type : int + zmq socket types, to be passed later to context.socket(). e.g. + zmq.PUB, zmq.SUB, zmq.REQ. If out_type is < 0, then in_socket is used + for both in_socket and out_socket. + + Methods + ------- + bind_{in_out}(iface) + passthrough for ``{in|out}_socket.bind(iface)``, to be called in the thread + connect_{in_out}(iface) + passthrough for ``{in|out}_socket.connect(iface)``, to be called in the + thread + setsockopt_{in_out}(opt,value) + passthrough for ``{in|out}_socket.setsockopt(opt, value)``, to be called in + the thread + + Attributes + ---------- + daemon : int + sets whether the thread should be run as a daemon + Default is true, because if it is false, the thread will not + exit unless it is killed + context_factory : callable (class attribute) + Function for creating the Context. This will be Context.instance + in ThreadDevices, and Context in ProcessDevices. The only reason + it is not instance() in ProcessDevices is that there may be a stale + Context instance already initialized, and the forked environment + should *never* try to use it. + """ + + context_factory = Context.instance + """Callable that returns a context. Typically either Context.instance or Context, + depending on whether the device should share the global instance or not. + """ + + def __init__(self, device_type=QUEUE, in_type=None, out_type=None): + self.device_type = device_type + if in_type is None: + raise TypeError("in_type must be specified") + if out_type is None: + raise TypeError("out_type must be specified") + self.in_type = in_type + self.out_type = out_type + self._in_binds = [] + self._in_connects = [] + self._in_sockopts = [] + self._out_binds = [] + self._out_connects = [] + self._out_sockopts = [] + self.daemon = True + self.done = False + + def bind_in(self, addr): + """Enqueue ZMQ address for binding on in_socket. + + See zmq.Socket.bind for details. + """ + self._in_binds.append(addr) + + def connect_in(self, addr): + """Enqueue ZMQ address for connecting on in_socket. + + See zmq.Socket.connect for details. + """ + self._in_connects.append(addr) + + def setsockopt_in(self, opt, value): + """Enqueue setsockopt(opt, value) for in_socket + + See zmq.Socket.setsockopt for details. + """ + self._in_sockopts.append((opt, value)) + + def bind_out(self, addr): + """Enqueue ZMQ address for binding on out_socket. + + See zmq.Socket.bind for details. + """ + self._out_binds.append(addr) + + def connect_out(self, addr): + """Enqueue ZMQ address for connecting on out_socket. + + See zmq.Socket.connect for details. + """ + self._out_connects.append(addr) + + def setsockopt_out(self, opt, value): + """Enqueue setsockopt(opt, value) for out_socket + + See zmq.Socket.setsockopt for details. + """ + self._out_sockopts.append((opt, value)) + + def _setup_sockets(self): + ctx = self.context_factory() + + self._context = ctx + + # create the sockets + ins = ctx.socket(self.in_type) + if self.out_type < 0: + outs = ins + else: + outs = ctx.socket(self.out_type) + + # set sockopts (must be done first, in case of zmq.IDENTITY) + for opt,value in self._in_sockopts: + ins.setsockopt(opt, value) + for opt,value in self._out_sockopts: + outs.setsockopt(opt, value) + + for iface in self._in_binds: + ins.bind(iface) + for iface in self._out_binds: + outs.bind(iface) + + for iface in self._in_connects: + ins.connect(iface) + for iface in self._out_connects: + outs.connect(iface) + + return ins,outs + + def run_device(self): + """The runner method. + + Do not call me directly, instead call ``self.start()``, just like a Thread. + """ + ins,outs = self._setup_sockets() + device(self.device_type, ins, outs) + + def run(self): + """wrap run_device in try/catch ETERM""" + try: + self.run_device() + except ZMQError as e: + if e.errno == ETERM: + # silence TERM errors, because this should be a clean shutdown + pass + else: + raise + finally: + self.done = True + + def start(self): + """Start the device. Override me in subclass for other launchers.""" + return self.run() + + def join(self,timeout=None): + """wait for me to finish, like Thread.join. + + Reimplemented appropriately by subclasses.""" + tic = time.time() + toc = tic + while not self.done and not (timeout is not None and toc-tic > timeout): + time.sleep(.001) + toc = time.time() + + +class BackgroundDevice(Device): + """Base class for launching Devices in background processes and threads.""" + + launcher=None + _launch_class=None + + def start(self): + self.launcher = self._launch_class(target=self.run) + self.launcher.daemon = self.daemon + return self.launcher.start() + + def join(self, timeout=None): + return self.launcher.join(timeout=timeout) + + +class ThreadDevice(BackgroundDevice): + """A Device that will be run in a background Thread. + + See Device for details. + """ + _launch_class=Thread + +class ProcessDevice(BackgroundDevice): + """A Device that will be run in a background Process. + + See Device for details. + """ + _launch_class=Process + context_factory = Context + """Callable that returns a context. Typically either Context.instance or Context, + depending on whether the device should share the global instance or not. + """ + + +__all__ = ['Device', 'ThreadDevice', 'ProcessDevice'] diff --git a/src/console/zmq/devices/monitoredqueue.pxd b/src/console/zmq/devices/monitoredqueue.pxd new file mode 100644 index 00000000..1e26ed86 --- /dev/null +++ b/src/console/zmq/devices/monitoredqueue.pxd @@ -0,0 +1,177 @@ +"""MonitoredQueue class declarations. + +Authors +------- +* MinRK +* Brian Granger +""" + +# +# Copyright (c) 2010 Min Ragan-Kelley, Brian Granger +# +# This file is part of pyzmq, but is derived and adapted from zmq_queue.cpp +# originally from libzmq-2.1.6, used under LGPLv3 +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from libzmq cimport * + +#----------------------------------------------------------------------------- +# MonitoredQueue C functions +#----------------------------------------------------------------------------- + +cdef inline int _relay(void *insocket_, void *outsocket_, void *sidesocket_, + zmq_msg_t msg, zmq_msg_t side_msg, zmq_msg_t id_msg, + bint swap_ids) nogil: + cdef int rc + cdef int64_t flag_2 + cdef int flag_3 + cdef int flags + cdef bint more + cdef size_t flagsz + cdef void * flag_ptr + + if ZMQ_VERSION_MAJOR < 3: + flagsz = sizeof (int64_t) + flag_ptr = &flag_2 + else: + flagsz = sizeof (int) + flag_ptr = &flag_3 + + if swap_ids:# both router, must send second identity first + # recv two ids into msg, id_msg + rc = zmq_msg_recv(&msg, insocket_, 0) + if rc < 0: return rc + + rc = zmq_msg_recv(&id_msg, insocket_, 0) + if rc < 0: return rc + + # send second id (id_msg) first + #!!!! always send a copy before the original !!!! + rc = zmq_msg_copy(&side_msg, &id_msg) + if rc < 0: return rc + rc = zmq_msg_send(&side_msg, outsocket_, ZMQ_SNDMORE) + if rc < 0: return rc + rc = zmq_msg_send(&id_msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + # send first id (msg) second + rc = zmq_msg_copy(&side_msg, &msg) + if rc < 0: return rc + rc = zmq_msg_send(&side_msg, outsocket_, ZMQ_SNDMORE) + if rc < 0: return rc + rc = zmq_msg_send(&msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + while (True): + rc = zmq_msg_recv(&msg, insocket_, 0) + if rc < 0: return rc + # assert (rc == 0) + rc = zmq_getsockopt (insocket_, ZMQ_RCVMORE, flag_ptr, &flagsz) + if rc < 0: return rc + flags = 0 + if ZMQ_VERSION_MAJOR < 3: + if flag_2: + flags |= ZMQ_SNDMORE + else: + if flag_3: + flags |= ZMQ_SNDMORE + # LABEL has been removed: + # rc = zmq_getsockopt (insocket_, ZMQ_RCVLABEL, flag_ptr, &flagsz) + # if flag_3: + # flags |= ZMQ_SNDLABEL + # assert (rc == 0) + + rc = zmq_msg_copy(&side_msg, &msg) + if rc < 0: return rc + if flags: + rc = zmq_msg_send(&side_msg, outsocket_, flags) + if rc < 0: return rc + # only SNDMORE for side-socket + rc = zmq_msg_send(&msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + else: + rc = zmq_msg_send(&side_msg, outsocket_, 0) + if rc < 0: return rc + rc = zmq_msg_send(&msg, sidesocket_, 0) + if rc < 0: return rc + break + return rc + +# the MonitoredQueue C function, adapted from zmq::queue.cpp : +cdef inline int c_monitored_queue (void *insocket_, void *outsocket_, + void *sidesocket_, zmq_msg_t *in_msg_ptr, + zmq_msg_t *out_msg_ptr, int swap_ids) nogil: + """The actual C function for a monitored queue device. + + See ``monitored_queue()`` for details. + """ + + cdef zmq_msg_t msg + cdef int rc = zmq_msg_init (&msg) + cdef zmq_msg_t id_msg + rc = zmq_msg_init (&id_msg) + if rc < 0: return rc + cdef zmq_msg_t side_msg + rc = zmq_msg_init (&side_msg) + if rc < 0: return rc + + cdef zmq_pollitem_t items [2] + items [0].socket = insocket_ + items [0].fd = 0 + items [0].events = ZMQ_POLLIN + items [0].revents = 0 + items [1].socket = outsocket_ + items [1].fd = 0 + items [1].events = ZMQ_POLLIN + items [1].revents = 0 + # I don't think sidesocket should be polled? + # items [2].socket = sidesocket_ + # items [2].fd = 0 + # items [2].events = ZMQ_POLLIN + # items [2].revents = 0 + + while (True): + + # // Wait while there are either requests or replies to process. + rc = zmq_poll (&items [0], 2, -1) + if rc < 0: return rc + # // The algorithm below asumes ratio of request and replies processed + # // under full load to be 1:1. Although processing requests replies + # // first is tempting it is suspectible to DoS attacks (overloading + # // the system with unsolicited replies). + # + # // Process a request. + if (items [0].revents & ZMQ_POLLIN): + # send in_prefix to side socket + rc = zmq_msg_copy(&side_msg, in_msg_ptr) + if rc < 0: return rc + rc = zmq_msg_send(&side_msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + # relay the rest of the message + rc = _relay(insocket_, outsocket_, sidesocket_, msg, side_msg, id_msg, swap_ids) + if rc < 0: return rc + if (items [1].revents & ZMQ_POLLIN): + # send out_prefix to side socket + rc = zmq_msg_copy(&side_msg, out_msg_ptr) + if rc < 0: return rc + rc = zmq_msg_send(&side_msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + # relay the rest of the message + rc = _relay(outsocket_, insocket_, sidesocket_, msg, side_msg, id_msg, swap_ids) + if rc < 0: return rc + return rc diff --git a/src/console/zmq/devices/monitoredqueue.py b/src/console/zmq/devices/monitoredqueue.py new file mode 100755 index 00000000..6d714e51 --- /dev/null +++ b/src/console/zmq/devices/monitoredqueue.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'monitoredqueue.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/devices/monitoredqueuedevice.py b/src/console/zmq/devices/monitoredqueuedevice.py new file mode 100755 index 00000000..9723f866 --- /dev/null +++ b/src/console/zmq/devices/monitoredqueuedevice.py @@ -0,0 +1,66 @@ +"""MonitoredQueue classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from zmq import ZMQError, PUB +from zmq.devices.proxydevice import ProxyBase, Proxy, ThreadProxy, ProcessProxy +from zmq.devices.monitoredqueue import monitored_queue + + +class MonitoredQueueBase(ProxyBase): + """Base class for overriding methods.""" + + _in_prefix = b'' + _out_prefix = b'' + + def __init__(self, in_type, out_type, mon_type=PUB, in_prefix=b'in', out_prefix=b'out'): + + ProxyBase.__init__(self, in_type=in_type, out_type=out_type, mon_type=mon_type) + + self._in_prefix = in_prefix + self._out_prefix = out_prefix + + def run_device(self): + ins,outs,mons = self._setup_sockets() + monitored_queue(ins, outs, mons, self._in_prefix, self._out_prefix) + + +class MonitoredQueue(MonitoredQueueBase, Proxy): + """Class for running monitored_queue in the background. + + See zmq.devices.Device for most of the spec. MonitoredQueue differs from Proxy, + only in that it adds a ``prefix`` to messages sent on the monitor socket, + with a different prefix for each direction. + + MQ also supports ROUTER on both sides, which zmq.proxy does not. + + If a message arrives on `in_sock`, it will be prefixed with `in_prefix` on the monitor socket. + If it arrives on out_sock, it will be prefixed with `out_prefix`. + + A PUB socket is the most logical choice for the mon_socket, but it is not required. + """ + pass + + +class ThreadMonitoredQueue(MonitoredQueueBase, ThreadProxy): + """Run zmq.monitored_queue in a background thread. + + See MonitoredQueue and Proxy for details. + """ + pass + + +class ProcessMonitoredQueue(MonitoredQueueBase, ProcessProxy): + """Run zmq.monitored_queue in a background thread. + + See MonitoredQueue and Proxy for details. + """ + + +__all__ = [ + 'MonitoredQueue', + 'ThreadMonitoredQueue', + 'ProcessMonitoredQueue' +] diff --git a/src/console/zmq/devices/proxydevice.py b/src/console/zmq/devices/proxydevice.py new file mode 100755 index 00000000..68be3f15 --- /dev/null +++ b/src/console/zmq/devices/proxydevice.py @@ -0,0 +1,90 @@ +"""Proxy classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq +from zmq.devices.basedevice import Device, ThreadDevice, ProcessDevice + + +class ProxyBase(object): + """Base class for overriding methods.""" + + def __init__(self, in_type, out_type, mon_type=zmq.PUB): + + Device.__init__(self, in_type=in_type, out_type=out_type) + self.mon_type = mon_type + self._mon_binds = [] + self._mon_connects = [] + self._mon_sockopts = [] + + def bind_mon(self, addr): + """Enqueue ZMQ address for binding on mon_socket. + + See zmq.Socket.bind for details. + """ + self._mon_binds.append(addr) + + def connect_mon(self, addr): + """Enqueue ZMQ address for connecting on mon_socket. + + See zmq.Socket.bind for details. + """ + self._mon_connects.append(addr) + + def setsockopt_mon(self, opt, value): + """Enqueue setsockopt(opt, value) for mon_socket + + See zmq.Socket.setsockopt for details. + """ + self._mon_sockopts.append((opt, value)) + + def _setup_sockets(self): + ins,outs = Device._setup_sockets(self) + ctx = self._context + mons = ctx.socket(self.mon_type) + + # set sockopts (must be done first, in case of zmq.IDENTITY) + for opt,value in self._mon_sockopts: + mons.setsockopt(opt, value) + + for iface in self._mon_binds: + mons.bind(iface) + + for iface in self._mon_connects: + mons.connect(iface) + + return ins,outs,mons + + def run_device(self): + ins,outs,mons = self._setup_sockets() + zmq.proxy(ins, outs, mons) + +class Proxy(ProxyBase, Device): + """Threadsafe Proxy object. + + See zmq.devices.Device for most of the spec. This subclass adds a + <method>_mon version of each <method>_{in|out} method, for configuring the + monitor socket. + + A Proxy is a 3-socket ZMQ Device that functions just like a + QUEUE, except each message is also sent out on the monitor socket. + + A PUB socket is the most logical choice for the mon_socket, but it is not required. + """ + pass + +class ThreadProxy(ProxyBase, ThreadDevice): + """Proxy in a Thread. See Proxy for more.""" + pass + +class ProcessProxy(ProxyBase, ProcessDevice): + """Proxy in a Process. See Proxy for more.""" + pass + + +__all__ = [ + 'Proxy', + 'ThreadProxy', + 'ProcessProxy', +] diff --git a/src/console/zmq/error.py b/src/console/zmq/error.py new file mode 100755 index 00000000..48cdaafa --- /dev/null +++ b/src/console/zmq/error.py @@ -0,0 +1,164 @@ +"""0MQ Error classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +class ZMQBaseError(Exception): + """Base exception class for 0MQ errors in Python.""" + pass + +class ZMQError(ZMQBaseError): + """Wrap an errno style error. + + Parameters + ---------- + errno : int + The ZMQ errno or None. If None, then ``zmq_errno()`` is called and + used. + msg : string + Description of the error or None. + """ + errno = None + + def __init__(self, errno=None, msg=None): + """Wrap an errno style error. + + Parameters + ---------- + errno : int + The ZMQ errno or None. If None, then ``zmq_errno()`` is called and + used. + msg : string + Description of the error or None. + """ + from zmq.backend import strerror, zmq_errno + if errno is None: + errno = zmq_errno() + if isinstance(errno, int): + self.errno = errno + if msg is None: + self.strerror = strerror(errno) + else: + self.strerror = msg + else: + if msg is None: + self.strerror = str(errno) + else: + self.strerror = msg + # flush signals, because there could be a SIGINT + # waiting to pounce, resulting in uncaught exceptions. + # Doing this here means getting SIGINT during a blocking + # libzmq call will raise a *catchable* KeyboardInterrupt + # PyErr_CheckSignals() + + def __str__(self): + return self.strerror + + def __repr__(self): + return "ZMQError('%s')"%self.strerror + + +class ZMQBindError(ZMQBaseError): + """An error for ``Socket.bind_to_random_port()``. + + See Also + -------- + .Socket.bind_to_random_port + """ + pass + + +class NotDone(ZMQBaseError): + """Raised when timeout is reached while waiting for 0MQ to finish with a Message + + See Also + -------- + .MessageTracker.wait : object for tracking when ZeroMQ is done + """ + pass + + +class ContextTerminated(ZMQError): + """Wrapper for zmq.ETERM + + .. versionadded:: 13.0 + """ + pass + + +class Again(ZMQError): + """Wrapper for zmq.EAGAIN + + .. versionadded:: 13.0 + """ + pass + + +def _check_rc(rc, errno=None): + """internal utility for checking zmq return condition + + and raising the appropriate Exception class + """ + if rc < 0: + from zmq.backend import zmq_errno + if errno is None: + errno = zmq_errno() + from zmq import EAGAIN, ETERM + if errno == EAGAIN: + raise Again(errno) + elif errno == ETERM: + raise ContextTerminated(errno) + else: + raise ZMQError(errno) + +_zmq_version_info = None +_zmq_version = None + +class ZMQVersionError(NotImplementedError): + """Raised when a feature is not provided by the linked version of libzmq. + + .. versionadded:: 14.2 + """ + min_version = None + def __init__(self, min_version, msg='Feature'): + global _zmq_version + if _zmq_version is None: + from zmq import zmq_version + _zmq_version = zmq_version() + self.msg = msg + self.min_version = min_version + self.version = _zmq_version + + def __repr__(self): + return "ZMQVersionError('%s')" % str(self) + + def __str__(self): + return "%s requires libzmq >= %s, have %s" % (self.msg, self.min_version, self.version) + + +def _check_version(min_version_info, msg='Feature'): + """Check for libzmq + + raises ZMQVersionError if current zmq version is not at least min_version + + min_version_info is a tuple of integers, and will be compared against zmq.zmq_version_info(). + """ + global _zmq_version_info + if _zmq_version_info is None: + from zmq import zmq_version_info + _zmq_version_info = zmq_version_info() + if _zmq_version_info < min_version_info: + min_version = '.'.join(str(v) for v in min_version_info) + raise ZMQVersionError(min_version, msg) + + +__all__ = [ + 'ZMQBaseError', + 'ZMQBindError', + 'ZMQError', + 'NotDone', + 'ContextTerminated', + 'Again', + 'ZMQVersionError', +] diff --git a/src/console/zmq/eventloop/__init__.py b/src/console/zmq/eventloop/__init__.py new file mode 100755 index 00000000..568e8e8d --- /dev/null +++ b/src/console/zmq/eventloop/__init__.py @@ -0,0 +1,5 @@ +"""A Tornado based event loop for PyZMQ.""" + +from zmq.eventloop.ioloop import IOLoop + +__all__ = ['IOLoop']
\ No newline at end of file diff --git a/src/console/zmq/eventloop/ioloop.py b/src/console/zmq/eventloop/ioloop.py new file mode 100755 index 00000000..35f4c418 --- /dev/null +++ b/src/console/zmq/eventloop/ioloop.py @@ -0,0 +1,193 @@ +# coding: utf-8 +"""tornado IOLoop API with zmq compatibility + +If you have tornado ≥ 3.0, this is a subclass of tornado's IOLoop, +otherwise we ship a minimal subset of tornado in zmq.eventloop.minitornado. + +The minimal shipped version of tornado's IOLoop does not include +support for concurrent futures - this will only be available if you +have tornado ≥ 3.0. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import absolute_import, division, with_statement + +import os +import time +import warnings + +from zmq import ( + Poller, + POLLIN, POLLOUT, POLLERR, + ZMQError, ETERM, +) + +try: + import tornado + tornado_version = tornado.version_info +except (ImportError, AttributeError): + tornado_version = () + +try: + # tornado ≥ 3 + from tornado.ioloop import PollIOLoop, PeriodicCallback + from tornado.log import gen_log +except ImportError: + from .minitornado.ioloop import PollIOLoop, PeriodicCallback + from .minitornado.log import gen_log + + +class DelayedCallback(PeriodicCallback): + """Schedules the given callback to be called once. + + The callback is called once, after callback_time milliseconds. + + `start` must be called after the DelayedCallback is created. + + The timeout is calculated from when `start` is called. + """ + def __init__(self, callback, callback_time, io_loop=None): + # PeriodicCallback require callback_time to be positive + warnings.warn("""DelayedCallback is deprecated. + Use loop.add_timeout instead.""", DeprecationWarning) + callback_time = max(callback_time, 1e-3) + super(DelayedCallback, self).__init__(callback, callback_time, io_loop) + + def start(self): + """Starts the timer.""" + self._running = True + self._firstrun = True + self._next_timeout = time.time() + self.callback_time / 1000.0 + self.io_loop.add_timeout(self._next_timeout, self._run) + + def _run(self): + if not self._running: return + self._running = False + try: + self.callback() + except Exception: + gen_log.error("Error in delayed callback", exc_info=True) + + +class ZMQPoller(object): + """A poller that can be used in the tornado IOLoop. + + This simply wraps a regular zmq.Poller, scaling the timeout + by 1000, so that it is in seconds rather than milliseconds. + """ + + def __init__(self): + self._poller = Poller() + + @staticmethod + def _map_events(events): + """translate IOLoop.READ/WRITE/ERROR event masks into zmq.POLLIN/OUT/ERR""" + z_events = 0 + if events & IOLoop.READ: + z_events |= POLLIN + if events & IOLoop.WRITE: + z_events |= POLLOUT + if events & IOLoop.ERROR: + z_events |= POLLERR + return z_events + + @staticmethod + def _remap_events(z_events): + """translate zmq.POLLIN/OUT/ERR event masks into IOLoop.READ/WRITE/ERROR""" + events = 0 + if z_events & POLLIN: + events |= IOLoop.READ + if z_events & POLLOUT: + events |= IOLoop.WRITE + if z_events & POLLERR: + events |= IOLoop.ERROR + return events + + def register(self, fd, events): + return self._poller.register(fd, self._map_events(events)) + + def modify(self, fd, events): + return self._poller.modify(fd, self._map_events(events)) + + def unregister(self, fd): + return self._poller.unregister(fd) + + def poll(self, timeout): + """poll in seconds rather than milliseconds. + + Event masks will be IOLoop.READ/WRITE/ERROR + """ + z_events = self._poller.poll(1000*timeout) + return [ (fd,self._remap_events(evt)) for (fd,evt) in z_events ] + + def close(self): + pass + + +class ZMQIOLoop(PollIOLoop): + """ZMQ subclass of tornado's IOLoop""" + def initialize(self, impl=None, **kwargs): + impl = ZMQPoller() if impl is None else impl + super(ZMQIOLoop, self).initialize(impl=impl, **kwargs) + + @staticmethod + def instance(): + """Returns a global `IOLoop` instance. + + Most applications have a single, global `IOLoop` running on the + main thread. Use this method to get this instance from + another thread. To get the current thread's `IOLoop`, use `current()`. + """ + # install ZMQIOLoop as the active IOLoop implementation + # when using tornado 3 + if tornado_version >= (3,): + PollIOLoop.configure(ZMQIOLoop) + return PollIOLoop.instance() + + def start(self): + try: + super(ZMQIOLoop, self).start() + except ZMQError as e: + if e.errno == ETERM: + # quietly return on ETERM + pass + else: + raise e + + +if tornado_version >= (3,0) and tornado_version < (3,1): + def backport_close(self, all_fds=False): + """backport IOLoop.close to 3.0 from 3.1 (supports fd.close() method)""" + from zmq.eventloop.minitornado.ioloop import PollIOLoop as mini_loop + return mini_loop.close.__get__(self)(all_fds) + ZMQIOLoop.close = backport_close + + +# public API name +IOLoop = ZMQIOLoop + + +def install(): + """set the tornado IOLoop instance with the pyzmq IOLoop. + + After calling this function, tornado's IOLoop.instance() and pyzmq's + IOLoop.instance() will return the same object. + + An assertion error will be raised if tornado's IOLoop has been initialized + prior to calling this function. + """ + from tornado import ioloop + # check if tornado's IOLoop is already initialized to something other + # than the pyzmq IOLoop instance: + assert (not ioloop.IOLoop.initialized()) or \ + ioloop.IOLoop.instance() is IOLoop.instance(), "tornado IOLoop already initialized" + + if tornado_version >= (3,): + # tornado 3 has an official API for registering new defaults, yay! + ioloop.IOLoop.configure(ZMQIOLoop) + else: + # we have to set the global instance explicitly + ioloop.IOLoop._instance = IOLoop.instance() + diff --git a/src/console/zmq/eventloop/minitornado/__init__.py b/src/console/zmq/eventloop/minitornado/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/__init__.py diff --git a/src/console/zmq/eventloop/minitornado/concurrent.py b/src/console/zmq/eventloop/minitornado/concurrent.py new file mode 100755 index 00000000..519b23d5 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/concurrent.py @@ -0,0 +1,11 @@ +"""pyzmq does not ship tornado's futures, +this just raises informative NotImplementedErrors to avoid having to change too much code. +""" + +class NotImplementedFuture(object): + def __init__(self, *args, **kwargs): + raise NotImplementedError("pyzmq does not ship tornado's Futures, " + "install tornado >= 3.0 for future support." + ) + +Future = TracebackFuture = NotImplementedFuture diff --git a/src/console/zmq/eventloop/minitornado/ioloop.py b/src/console/zmq/eventloop/minitornado/ioloop.py new file mode 100755 index 00000000..710a3ecb --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/ioloop.py @@ -0,0 +1,829 @@ +#!/usr/bin/env python +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""An I/O event loop for non-blocking sockets. + +Typical applications will use a single `IOLoop` object, in the +`IOLoop.instance` singleton. The `IOLoop.start` method should usually +be called at the end of the ``main()`` function. Atypical applications may +use more than one `IOLoop`, such as one `IOLoop` per thread, or per `unittest` +case. + +In addition to I/O events, the `IOLoop` can also schedule time-based events. +`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import datetime +import errno +import functools +import heapq +import logging +import numbers +import os +import select +import sys +import threading +import time +import traceback + +from .concurrent import Future, TracebackFuture +from .log import app_log, gen_log +from . import stack_context +from .util import Configurable + +try: + import signal +except ImportError: + signal = None + +try: + import thread # py2 +except ImportError: + import _thread as thread # py3 + +from .platform.auto import set_close_exec, Waker + + +class TimeoutError(Exception): + pass + + +class IOLoop(Configurable): + """A level-triggered I/O loop. + + We use ``epoll`` (Linux) or ``kqueue`` (BSD and Mac OS X) if they + are available, or else we fall back on select(). If you are + implementing a system that needs to handle thousands of + simultaneous connections, you should use a system that supports + either ``epoll`` or ``kqueue``. + + Example usage for a simple TCP server:: + + import errno + import functools + import ioloop + import socket + + def connection_ready(sock, fd, events): + while True: + try: + connection, address = sock.accept() + except socket.error, e: + if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN): + raise + return + connection.setblocking(0) + handle_connection(connection, address) + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(0) + sock.bind(("", port)) + sock.listen(128) + + io_loop = ioloop.IOLoop.instance() + callback = functools.partial(connection_ready, sock) + io_loop.add_handler(sock.fileno(), callback, io_loop.READ) + io_loop.start() + + """ + # Constants from the epoll module + _EPOLLIN = 0x001 + _EPOLLPRI = 0x002 + _EPOLLOUT = 0x004 + _EPOLLERR = 0x008 + _EPOLLHUP = 0x010 + _EPOLLRDHUP = 0x2000 + _EPOLLONESHOT = (1 << 30) + _EPOLLET = (1 << 31) + + # Our events map exactly to the epoll events + NONE = 0 + READ = _EPOLLIN + WRITE = _EPOLLOUT + ERROR = _EPOLLERR | _EPOLLHUP + + # Global lock for creating global IOLoop instance + _instance_lock = threading.Lock() + + _current = threading.local() + + @staticmethod + def instance(): + """Returns a global `IOLoop` instance. + + Most applications have a single, global `IOLoop` running on the + main thread. Use this method to get this instance from + another thread. To get the current thread's `IOLoop`, use `current()`. + """ + if not hasattr(IOLoop, "_instance"): + with IOLoop._instance_lock: + if not hasattr(IOLoop, "_instance"): + # New instance after double check + IOLoop._instance = IOLoop() + return IOLoop._instance + + @staticmethod + def initialized(): + """Returns true if the singleton instance has been created.""" + return hasattr(IOLoop, "_instance") + + def install(self): + """Installs this `IOLoop` object as the singleton instance. + + This is normally not necessary as `instance()` will create + an `IOLoop` on demand, but you may want to call `install` to use + a custom subclass of `IOLoop`. + """ + assert not IOLoop.initialized() + IOLoop._instance = self + + @staticmethod + def current(): + """Returns the current thread's `IOLoop`. + + If an `IOLoop` is currently running or has been marked as current + by `make_current`, returns that instance. Otherwise returns + `IOLoop.instance()`, i.e. the main thread's `IOLoop`. + + A common pattern for classes that depend on ``IOLoops`` is to use + a default argument to enable programs with multiple ``IOLoops`` + but not require the argument for simpler applications:: + + class MyClass(object): + def __init__(self, io_loop=None): + self.io_loop = io_loop or IOLoop.current() + + In general you should use `IOLoop.current` as the default when + constructing an asynchronous object, and use `IOLoop.instance` + when you mean to communicate to the main thread from a different + one. + """ + current = getattr(IOLoop._current, "instance", None) + if current is None: + return IOLoop.instance() + return current + + def make_current(self): + """Makes this the `IOLoop` for the current thread. + + An `IOLoop` automatically becomes current for its thread + when it is started, but it is sometimes useful to call + `make_current` explictly before starting the `IOLoop`, + so that code run at startup time can find the right + instance. + """ + IOLoop._current.instance = self + + @staticmethod + def clear_current(): + IOLoop._current.instance = None + + @classmethod + def configurable_base(cls): + return IOLoop + + @classmethod + def configurable_default(cls): + # this is the only patch to IOLoop: + from zmq.eventloop.ioloop import ZMQIOLoop + return ZMQIOLoop + # the remainder of this method is unused, + # but left for preservation reasons + if hasattr(select, "epoll"): + from tornado.platform.epoll import EPollIOLoop + return EPollIOLoop + if hasattr(select, "kqueue"): + # Python 2.6+ on BSD or Mac + from tornado.platform.kqueue import KQueueIOLoop + return KQueueIOLoop + from tornado.platform.select import SelectIOLoop + return SelectIOLoop + + def initialize(self): + pass + + def close(self, all_fds=False): + """Closes the `IOLoop`, freeing any resources used. + + If ``all_fds`` is true, all file descriptors registered on the + IOLoop will be closed (not just the ones created by the + `IOLoop` itself). + + Many applications will only use a single `IOLoop` that runs for the + entire lifetime of the process. In that case closing the `IOLoop` + is not necessary since everything will be cleaned up when the + process exits. `IOLoop.close` is provided mainly for scenarios + such as unit tests, which create and destroy a large number of + ``IOLoops``. + + An `IOLoop` must be completely stopped before it can be closed. This + means that `IOLoop.stop()` must be called *and* `IOLoop.start()` must + be allowed to return before attempting to call `IOLoop.close()`. + Therefore the call to `close` will usually appear just after + the call to `start` rather than near the call to `stop`. + + .. versionchanged:: 3.1 + If the `IOLoop` implementation supports non-integer objects + for "file descriptors", those objects will have their + ``close`` method when ``all_fds`` is true. + """ + raise NotImplementedError() + + def add_handler(self, fd, handler, events): + """Registers the given handler to receive the given events for fd. + + The ``events`` argument is a bitwise or of the constants + ``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``. + + When an event occurs, ``handler(fd, events)`` will be run. + """ + raise NotImplementedError() + + def update_handler(self, fd, events): + """Changes the events we listen for fd.""" + raise NotImplementedError() + + def remove_handler(self, fd): + """Stop listening for events on fd.""" + raise NotImplementedError() + + def set_blocking_signal_threshold(self, seconds, action): + """Sends a signal if the `IOLoop` is blocked for more than + ``s`` seconds. + + Pass ``seconds=None`` to disable. Requires Python 2.6 on a unixy + platform. + + The action parameter is a Python signal handler. Read the + documentation for the `signal` module for more information. + If ``action`` is None, the process will be killed if it is + blocked for too long. + """ + raise NotImplementedError() + + def set_blocking_log_threshold(self, seconds): + """Logs a stack trace if the `IOLoop` is blocked for more than + ``s`` seconds. + + Equivalent to ``set_blocking_signal_threshold(seconds, + self.log_stack)`` + """ + self.set_blocking_signal_threshold(seconds, self.log_stack) + + def log_stack(self, signal, frame): + """Signal handler to log the stack trace of the current thread. + + For use with `set_blocking_signal_threshold`. + """ + gen_log.warning('IOLoop blocked for %f seconds in\n%s', + self._blocking_signal_threshold, + ''.join(traceback.format_stack(frame))) + + def start(self): + """Starts the I/O loop. + + The loop will run until one of the callbacks calls `stop()`, which + will make the loop stop after the current event iteration completes. + """ + raise NotImplementedError() + + def stop(self): + """Stop the I/O loop. + + If the event loop is not currently running, the next call to `start()` + will return immediately. + + To use asynchronous methods from otherwise-synchronous code (such as + unit tests), you can start and stop the event loop like this:: + + ioloop = IOLoop() + async_method(ioloop=ioloop, callback=ioloop.stop) + ioloop.start() + + ``ioloop.start()`` will return after ``async_method`` has run + its callback, whether that callback was invoked before or + after ``ioloop.start``. + + Note that even after `stop` has been called, the `IOLoop` is not + completely stopped until `IOLoop.start` has also returned. + Some work that was scheduled before the call to `stop` may still + be run before the `IOLoop` shuts down. + """ + raise NotImplementedError() + + def run_sync(self, func, timeout=None): + """Starts the `IOLoop`, runs the given function, and stops the loop. + + If the function returns a `.Future`, the `IOLoop` will run + until the future is resolved. If it raises an exception, the + `IOLoop` will stop and the exception will be re-raised to the + caller. + + The keyword-only argument ``timeout`` may be used to set + a maximum duration for the function. If the timeout expires, + a `TimeoutError` is raised. + + This method is useful in conjunction with `tornado.gen.coroutine` + to allow asynchronous calls in a ``main()`` function:: + + @gen.coroutine + def main(): + # do stuff... + + if __name__ == '__main__': + IOLoop.instance().run_sync(main) + """ + future_cell = [None] + + def run(): + try: + result = func() + except Exception: + future_cell[0] = TracebackFuture() + future_cell[0].set_exc_info(sys.exc_info()) + else: + if isinstance(result, Future): + future_cell[0] = result + else: + future_cell[0] = Future() + future_cell[0].set_result(result) + self.add_future(future_cell[0], lambda future: self.stop()) + self.add_callback(run) + if timeout is not None: + timeout_handle = self.add_timeout(self.time() + timeout, self.stop) + self.start() + if timeout is not None: + self.remove_timeout(timeout_handle) + if not future_cell[0].done(): + raise TimeoutError('Operation timed out after %s seconds' % timeout) + return future_cell[0].result() + + def time(self): + """Returns the current time according to the `IOLoop`'s clock. + + The return value is a floating-point number relative to an + unspecified time in the past. + + By default, the `IOLoop`'s time function is `time.time`. However, + it may be configured to use e.g. `time.monotonic` instead. + Calls to `add_timeout` that pass a number instead of a + `datetime.timedelta` should use this function to compute the + appropriate time, so they can work no matter what time function + is chosen. + """ + return time.time() + + def add_timeout(self, deadline, callback): + """Runs the ``callback`` at the time ``deadline`` from the I/O loop. + + Returns an opaque handle that may be passed to + `remove_timeout` to cancel. + + ``deadline`` may be a number denoting a time (on the same + scale as `IOLoop.time`, normally `time.time`), or a + `datetime.timedelta` object for a deadline relative to the + current time. + + Note that it is not safe to call `add_timeout` from other threads. + Instead, you must use `add_callback` to transfer control to the + `IOLoop`'s thread, and then call `add_timeout` from there. + """ + raise NotImplementedError() + + def remove_timeout(self, timeout): + """Cancels a pending timeout. + + The argument is a handle as returned by `add_timeout`. It is + safe to call `remove_timeout` even if the callback has already + been run. + """ + raise NotImplementedError() + + def add_callback(self, callback, *args, **kwargs): + """Calls the given callback on the next I/O loop iteration. + + It is safe to call this method from any thread at any time, + except from a signal handler. Note that this is the **only** + method in `IOLoop` that makes this thread-safety guarantee; all + other interaction with the `IOLoop` must be done from that + `IOLoop`'s thread. `add_callback()` may be used to transfer + control from other threads to the `IOLoop`'s thread. + + To add a callback from a signal handler, see + `add_callback_from_signal`. + """ + raise NotImplementedError() + + def add_callback_from_signal(self, callback, *args, **kwargs): + """Calls the given callback on the next I/O loop iteration. + + Safe for use from a Python signal handler; should not be used + otherwise. + + Callbacks added with this method will be run without any + `.stack_context`, to avoid picking up the context of the function + that was interrupted by the signal. + """ + raise NotImplementedError() + + def add_future(self, future, callback): + """Schedules a callback on the ``IOLoop`` when the given + `.Future` is finished. + + The callback is invoked with one argument, the + `.Future`. + """ + assert isinstance(future, Future) + callback = stack_context.wrap(callback) + future.add_done_callback( + lambda future: self.add_callback(callback, future)) + + def _run_callback(self, callback): + """Runs a callback with error handling. + + For use in subclasses. + """ + try: + callback() + except Exception: + self.handle_callback_exception(callback) + + def handle_callback_exception(self, callback): + """This method is called whenever a callback run by the `IOLoop` + throws an exception. + + By default simply logs the exception as an error. Subclasses + may override this method to customize reporting of exceptions. + + The exception itself is not passed explicitly, but is available + in `sys.exc_info`. + """ + app_log.error("Exception in callback %r", callback, exc_info=True) + + +class PollIOLoop(IOLoop): + """Base class for IOLoops built around a select-like function. + + For concrete implementations, see `tornado.platform.epoll.EPollIOLoop` + (Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or + `tornado.platform.select.SelectIOLoop` (all platforms). + """ + def initialize(self, impl, time_func=None): + super(PollIOLoop, self).initialize() + self._impl = impl + if hasattr(self._impl, 'fileno'): + set_close_exec(self._impl.fileno()) + self.time_func = time_func or time.time + self._handlers = {} + self._events = {} + self._callbacks = [] + self._callback_lock = threading.Lock() + self._timeouts = [] + self._cancellations = 0 + self._running = False + self._stopped = False + self._closing = False + self._thread_ident = None + self._blocking_signal_threshold = None + + # Create a pipe that we send bogus data to when we want to wake + # the I/O loop when it is idle + self._waker = Waker() + self.add_handler(self._waker.fileno(), + lambda fd, events: self._waker.consume(), + self.READ) + + def close(self, all_fds=False): + with self._callback_lock: + self._closing = True + self.remove_handler(self._waker.fileno()) + if all_fds: + for fd in self._handlers.keys(): + try: + close_method = getattr(fd, 'close', None) + if close_method is not None: + close_method() + else: + os.close(fd) + except Exception: + gen_log.debug("error closing fd %s", fd, exc_info=True) + self._waker.close() + self._impl.close() + + def add_handler(self, fd, handler, events): + self._handlers[fd] = stack_context.wrap(handler) + self._impl.register(fd, events | self.ERROR) + + def update_handler(self, fd, events): + self._impl.modify(fd, events | self.ERROR) + + def remove_handler(self, fd): + self._handlers.pop(fd, None) + self._events.pop(fd, None) + try: + self._impl.unregister(fd) + except Exception: + gen_log.debug("Error deleting fd from IOLoop", exc_info=True) + + def set_blocking_signal_threshold(self, seconds, action): + if not hasattr(signal, "setitimer"): + gen_log.error("set_blocking_signal_threshold requires a signal module " + "with the setitimer method") + return + self._blocking_signal_threshold = seconds + if seconds is not None: + signal.signal(signal.SIGALRM, + action if action is not None else signal.SIG_DFL) + + def start(self): + if not logging.getLogger().handlers: + # The IOLoop catches and logs exceptions, so it's + # important that log output be visible. However, python's + # default behavior for non-root loggers (prior to python + # 3.2) is to print an unhelpful "no handlers could be + # found" message rather than the actual log entry, so we + # must explicitly configure logging if we've made it this + # far without anything. + logging.basicConfig() + if self._stopped: + self._stopped = False + return + old_current = getattr(IOLoop._current, "instance", None) + IOLoop._current.instance = self + self._thread_ident = thread.get_ident() + self._running = True + + # signal.set_wakeup_fd closes a race condition in event loops: + # a signal may arrive at the beginning of select/poll/etc + # before it goes into its interruptible sleep, so the signal + # will be consumed without waking the select. The solution is + # for the (C, synchronous) signal handler to write to a pipe, + # which will then be seen by select. + # + # In python's signal handling semantics, this only matters on the + # main thread (fortunately, set_wakeup_fd only works on the main + # thread and will raise a ValueError otherwise). + # + # If someone has already set a wakeup fd, we don't want to + # disturb it. This is an issue for twisted, which does its + # SIGCHILD processing in response to its own wakeup fd being + # written to. As long as the wakeup fd is registered on the IOLoop, + # the loop will still wake up and everything should work. + old_wakeup_fd = None + if hasattr(signal, 'set_wakeup_fd') and os.name == 'posix': + # requires python 2.6+, unix. set_wakeup_fd exists but crashes + # the python process on windows. + try: + old_wakeup_fd = signal.set_wakeup_fd(self._waker.write_fileno()) + if old_wakeup_fd != -1: + # Already set, restore previous value. This is a little racy, + # but there's no clean get_wakeup_fd and in real use the + # IOLoop is just started once at the beginning. + signal.set_wakeup_fd(old_wakeup_fd) + old_wakeup_fd = None + except ValueError: # non-main thread + pass + + while True: + poll_timeout = 3600.0 + + # Prevent IO event starvation by delaying new callbacks + # to the next iteration of the event loop. + with self._callback_lock: + callbacks = self._callbacks + self._callbacks = [] + for callback in callbacks: + self._run_callback(callback) + + if self._timeouts: + now = self.time() + while self._timeouts: + if self._timeouts[0].callback is None: + # the timeout was cancelled + heapq.heappop(self._timeouts) + self._cancellations -= 1 + elif self._timeouts[0].deadline <= now: + timeout = heapq.heappop(self._timeouts) + self._run_callback(timeout.callback) + else: + seconds = self._timeouts[0].deadline - now + poll_timeout = min(seconds, poll_timeout) + break + if (self._cancellations > 512 + and self._cancellations > (len(self._timeouts) >> 1)): + # Clean up the timeout queue when it gets large and it's + # more than half cancellations. + self._cancellations = 0 + self._timeouts = [x for x in self._timeouts + if x.callback is not None] + heapq.heapify(self._timeouts) + + if self._callbacks: + # If any callbacks or timeouts called add_callback, + # we don't want to wait in poll() before we run them. + poll_timeout = 0.0 + + if not self._running: + break + + if self._blocking_signal_threshold is not None: + # clear alarm so it doesn't fire while poll is waiting for + # events. + signal.setitimer(signal.ITIMER_REAL, 0, 0) + + try: + event_pairs = self._impl.poll(poll_timeout) + except Exception as e: + # Depending on python version and IOLoop implementation, + # different exception types may be thrown and there are + # two ways EINTR might be signaled: + # * e.errno == errno.EINTR + # * e.args is like (errno.EINTR, 'Interrupted system call') + if (getattr(e, 'errno', None) == errno.EINTR or + (isinstance(getattr(e, 'args', None), tuple) and + len(e.args) == 2 and e.args[0] == errno.EINTR)): + continue + else: + raise + + if self._blocking_signal_threshold is not None: + signal.setitimer(signal.ITIMER_REAL, + self._blocking_signal_threshold, 0) + + # Pop one fd at a time from the set of pending fds and run + # its handler. Since that handler may perform actions on + # other file descriptors, there may be reentrant calls to + # this IOLoop that update self._events + self._events.update(event_pairs) + while self._events: + fd, events = self._events.popitem() + try: + self._handlers[fd](fd, events) + except (OSError, IOError) as e: + if e.args[0] == errno.EPIPE: + # Happens when the client closes the connection + pass + else: + app_log.error("Exception in I/O handler for fd %s", + fd, exc_info=True) + except Exception: + app_log.error("Exception in I/O handler for fd %s", + fd, exc_info=True) + # reset the stopped flag so another start/stop pair can be issued + self._stopped = False + if self._blocking_signal_threshold is not None: + signal.setitimer(signal.ITIMER_REAL, 0, 0) + IOLoop._current.instance = old_current + if old_wakeup_fd is not None: + signal.set_wakeup_fd(old_wakeup_fd) + + def stop(self): + self._running = False + self._stopped = True + self._waker.wake() + + def time(self): + return self.time_func() + + def add_timeout(self, deadline, callback): + timeout = _Timeout(deadline, stack_context.wrap(callback), self) + heapq.heappush(self._timeouts, timeout) + return timeout + + def remove_timeout(self, timeout): + # Removing from a heap is complicated, so just leave the defunct + # timeout object in the queue (see discussion in + # http://docs.python.org/library/heapq.html). + # If this turns out to be a problem, we could add a garbage + # collection pass whenever there are too many dead timeouts. + timeout.callback = None + self._cancellations += 1 + + def add_callback(self, callback, *args, **kwargs): + with self._callback_lock: + if self._closing: + raise RuntimeError("IOLoop is closing") + list_empty = not self._callbacks + self._callbacks.append(functools.partial( + stack_context.wrap(callback), *args, **kwargs)) + if list_empty and thread.get_ident() != self._thread_ident: + # If we're in the IOLoop's thread, we know it's not currently + # polling. If we're not, and we added the first callback to an + # empty list, we may need to wake it up (it may wake up on its + # own, but an occasional extra wake is harmless). Waking + # up a polling IOLoop is relatively expensive, so we try to + # avoid it when we can. + self._waker.wake() + + def add_callback_from_signal(self, callback, *args, **kwargs): + with stack_context.NullContext(): + if thread.get_ident() != self._thread_ident: + # if the signal is handled on another thread, we can add + # it normally (modulo the NullContext) + self.add_callback(callback, *args, **kwargs) + else: + # If we're on the IOLoop's thread, we cannot use + # the regular add_callback because it may deadlock on + # _callback_lock. Blindly insert into self._callbacks. + # This is safe because the GIL makes list.append atomic. + # One subtlety is that if the signal interrupted the + # _callback_lock block in IOLoop.start, we may modify + # either the old or new version of self._callbacks, + # but either way will work. + self._callbacks.append(functools.partial( + stack_context.wrap(callback), *args, **kwargs)) + + +class _Timeout(object): + """An IOLoop timeout, a UNIX timestamp and a callback""" + + # Reduce memory overhead when there are lots of pending callbacks + __slots__ = ['deadline', 'callback'] + + def __init__(self, deadline, callback, io_loop): + if isinstance(deadline, numbers.Real): + self.deadline = deadline + elif isinstance(deadline, datetime.timedelta): + self.deadline = io_loop.time() + _Timeout.timedelta_to_seconds(deadline) + else: + raise TypeError("Unsupported deadline %r" % deadline) + self.callback = callback + + @staticmethod + def timedelta_to_seconds(td): + """Equivalent to td.total_seconds() (introduced in python 2.7).""" + return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6) + + # Comparison methods to sort by deadline, with object id as a tiebreaker + # to guarantee a consistent ordering. The heapq module uses __le__ + # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons + # use __lt__). + def __lt__(self, other): + return ((self.deadline, id(self)) < + (other.deadline, id(other))) + + def __le__(self, other): + return ((self.deadline, id(self)) <= + (other.deadline, id(other))) + + +class PeriodicCallback(object): + """Schedules the given callback to be called periodically. + + The callback is called every ``callback_time`` milliseconds. + + `start` must be called after the `PeriodicCallback` is created. + """ + def __init__(self, callback, callback_time, io_loop=None): + self.callback = callback + if callback_time <= 0: + raise ValueError("Periodic callback must have a positive callback_time") + self.callback_time = callback_time + self.io_loop = io_loop or IOLoop.current() + self._running = False + self._timeout = None + + def start(self): + """Starts the timer.""" + self._running = True + self._next_timeout = self.io_loop.time() + self._schedule_next() + + def stop(self): + """Stops the timer.""" + self._running = False + if self._timeout is not None: + self.io_loop.remove_timeout(self._timeout) + self._timeout = None + + def _run(self): + if not self._running: + return + try: + self.callback() + except Exception: + app_log.error("Error in periodic callback", exc_info=True) + self._schedule_next() + + def _schedule_next(self): + if self._running: + current_time = self.io_loop.time() + while self._next_timeout <= current_time: + self._next_timeout += self.callback_time / 1000.0 + self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run) diff --git a/src/console/zmq/eventloop/minitornado/log.py b/src/console/zmq/eventloop/minitornado/log.py new file mode 100755 index 00000000..49051e89 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/log.py @@ -0,0 +1,6 @@ +"""minimal subset of tornado.log for zmq.eventloop.minitornado""" + +import logging + +app_log = logging.getLogger("tornado.application") +gen_log = logging.getLogger("tornado.general") diff --git a/src/console/zmq/eventloop/minitornado/platform/__init__.py b/src/console/zmq/eventloop/minitornado/platform/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/__init__.py diff --git a/src/console/zmq/eventloop/minitornado/platform/auto.py b/src/console/zmq/eventloop/minitornado/platform/auto.py new file mode 100755 index 00000000..b40ccd94 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/auto.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Implementation of platform-specific functionality. + +For each function or class described in `tornado.platform.interface`, +the appropriate platform-specific implementation exists in this module. +Most code that needs access to this functionality should do e.g.:: + + from tornado.platform.auto import set_close_exec +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import os + +if os.name == 'nt': + from .common import Waker + from .windows import set_close_exec +else: + from .posix import set_close_exec, Waker + +try: + # monotime monkey-patches the time module to have a monotonic function + # in versions of python before 3.3. + import monotime +except ImportError: + pass +try: + from time import monotonic as monotonic_time +except ImportError: + monotonic_time = None diff --git a/src/console/zmq/eventloop/minitornado/platform/common.py b/src/console/zmq/eventloop/minitornado/platform/common.py new file mode 100755 index 00000000..2d75dc1e --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/common.py @@ -0,0 +1,91 @@ +"""Lowest-common-denominator implementations of platform functionality.""" +from __future__ import absolute_import, division, print_function, with_statement + +import errno +import socket + +from . import interface + + +class Waker(interface.Waker): + """Create an OS independent asynchronous pipe. + + For use on platforms that don't have os.pipe() (or where pipes cannot + be passed to select()), but do have sockets. This includes Windows + and Jython. + """ + def __init__(self): + # Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py + + self.writer = socket.socket() + # Disable buffering -- pulling the trigger sends 1 byte, + # and we want that sent immediately, to wake up ASAP. + self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + count = 0 + while 1: + count += 1 + # Bind to a local port; for efficiency, let the OS pick + # a free port for us. + # Unfortunately, stress tests showed that we may not + # be able to connect to that port ("Address already in + # use") despite that the OS picked it. This appears + # to be a race bug in the Windows socket implementation. + # So we loop until a connect() succeeds (almost always + # on the first try). See the long thread at + # http://mail.zope.org/pipermail/zope/2005-July/160433.html + # for hideous details. + a = socket.socket() + a.bind(("127.0.0.1", 0)) + a.listen(1) + connect_address = a.getsockname() # assigned (host, port) pair + try: + self.writer.connect(connect_address) + break # success + except socket.error as detail: + if (not hasattr(errno, 'WSAEADDRINUSE') or + detail[0] != errno.WSAEADDRINUSE): + # "Address already in use" is the only error + # I've seen on two WinXP Pro SP2 boxes, under + # Pythons 2.3.5 and 2.4.1. + raise + # (10048, 'Address already in use') + # assert count <= 2 # never triggered in Tim's tests + if count >= 10: # I've never seen it go above 2 + a.close() + self.writer.close() + raise socket.error("Cannot bind trigger!") + # Close `a` and try again. Note: I originally put a short + # sleep() here, but it didn't appear to help or hurt. + a.close() + + self.reader, addr = a.accept() + self.reader.setblocking(0) + self.writer.setblocking(0) + a.close() + self.reader_fd = self.reader.fileno() + + def fileno(self): + return self.reader.fileno() + + def write_fileno(self): + return self.writer.fileno() + + def wake(self): + try: + self.writer.send(b"x") + except (IOError, socket.error): + pass + + def consume(self): + try: + while True: + result = self.reader.recv(1024) + if not result: + break + except (IOError, socket.error): + pass + + def close(self): + self.reader.close() + self.writer.close() diff --git a/src/console/zmq/eventloop/minitornado/platform/interface.py b/src/console/zmq/eventloop/minitornado/platform/interface.py new file mode 100755 index 00000000..07da6bab --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/interface.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Interfaces for platform-specific functionality. + +This module exists primarily for documentation purposes and as base classes +for other tornado.platform modules. Most code should import the appropriate +implementation from `tornado.platform.auto`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + + +def set_close_exec(fd): + """Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor.""" + raise NotImplementedError() + + +class Waker(object): + """A socket-like object that can wake another thread from ``select()``. + + The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to + its ``select`` (or ``epoll`` or ``kqueue``) calls. When another + thread wants to wake up the loop, it calls `wake`. Once it has woken + up, it will call `consume` to do any necessary per-wake cleanup. When + the ``IOLoop`` is closed, it closes its waker too. + """ + def fileno(self): + """Returns the read file descriptor for this waker. + + Must be suitable for use with ``select()`` or equivalent on the + local platform. + """ + raise NotImplementedError() + + def write_fileno(self): + """Returns the write file descriptor for this waker.""" + raise NotImplementedError() + + def wake(self): + """Triggers activity on the waker's file descriptor.""" + raise NotImplementedError() + + def consume(self): + """Called after the listen has woken up to do any necessary cleanup.""" + raise NotImplementedError() + + def close(self): + """Closes the waker's file descriptor(s).""" + raise NotImplementedError() diff --git a/src/console/zmq/eventloop/minitornado/platform/posix.py b/src/console/zmq/eventloop/minitornado/platform/posix.py new file mode 100755 index 00000000..ccffbb66 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/posix.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Posix implementations of platform-specific functionality.""" + +from __future__ import absolute_import, division, print_function, with_statement + +import fcntl +import os + +from . import interface + + +def set_close_exec(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + +class Waker(interface.Waker): + def __init__(self): + r, w = os.pipe() + _set_nonblocking(r) + _set_nonblocking(w) + set_close_exec(r) + set_close_exec(w) + self.reader = os.fdopen(r, "rb", 0) + self.writer = os.fdopen(w, "wb", 0) + + def fileno(self): + return self.reader.fileno() + + def write_fileno(self): + return self.writer.fileno() + + def wake(self): + try: + self.writer.write(b"x") + except IOError: + pass + + def consume(self): + try: + while True: + result = self.reader.read() + if not result: + break + except IOError: + pass + + def close(self): + self.reader.close() + self.writer.close() diff --git a/src/console/zmq/eventloop/minitornado/platform/windows.py b/src/console/zmq/eventloop/minitornado/platform/windows.py new file mode 100755 index 00000000..817bdca1 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/windows.py @@ -0,0 +1,20 @@ +# NOTE: win32 support is currently experimental, and not recommended +# for production use. + + +from __future__ import absolute_import, division, print_function, with_statement +import ctypes +import ctypes.wintypes + +# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx +SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation +SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD) +SetHandleInformation.restype = ctypes.wintypes.BOOL + +HANDLE_FLAG_INHERIT = 0x00000001 + + +def set_close_exec(fd): + success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0) + if not success: + raise ctypes.GetLastError() diff --git a/src/console/zmq/eventloop/minitornado/stack_context.py b/src/console/zmq/eventloop/minitornado/stack_context.py new file mode 100755 index 00000000..226d8042 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/stack_context.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python +# +# Copyright 2010 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""`StackContext` allows applications to maintain threadlocal-like state +that follows execution as it moves to other execution contexts. + +The motivating examples are to eliminate the need for explicit +``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to +allow some additional context to be kept for logging. + +This is slightly magic, but it's an extension of the idea that an +exception handler is a kind of stack-local state and when that stack +is suspended and resumed in a new context that state needs to be +preserved. `StackContext` shifts the burden of restoring that state +from each call site (e.g. wrapping each `.AsyncHTTPClient` callback +in ``async_callback``) to the mechanisms that transfer control from +one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`, +thread pools, etc). + +Example usage:: + + @contextlib.contextmanager + def die_on_error(): + try: + yield + except Exception: + logging.error("exception in asynchronous operation",exc_info=True) + sys.exit(1) + + with StackContext(die_on_error): + # Any exception thrown here *or in callback and its desendents* + # will cause the process to exit instead of spinning endlessly + # in the ioloop. + http_client.fetch(url, callback) + ioloop.start() + +Most applications shouln't have to work with `StackContext` directly. +Here are a few rules of thumb for when it's necessary: + +* If you're writing an asynchronous library that doesn't rely on a + stack_context-aware library like `tornado.ioloop` or `tornado.iostream` + (for example, if you're writing a thread pool), use + `.stack_context.wrap()` before any asynchronous operations to capture the + stack context from where the operation was started. + +* If you're writing an asynchronous library that has some shared + resources (such as a connection pool), create those shared resources + within a ``with stack_context.NullContext():`` block. This will prevent + ``StackContexts`` from leaking from one request to another. + +* If you want to write something like an exception handler that will + persist across asynchronous calls, create a new `StackContext` (or + `ExceptionStackContext`), and make your asynchronous calls in a ``with`` + block that references your `StackContext`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import sys +import threading + +from .util import raise_exc_info + + +class StackContextInconsistentError(Exception): + pass + + +class _State(threading.local): + def __init__(self): + self.contexts = (tuple(), None) +_state = _State() + + +class StackContext(object): + """Establishes the given context as a StackContext that will be transferred. + + Note that the parameter is a callable that returns a context + manager, not the context itself. That is, where for a + non-transferable context manager you would say:: + + with my_context(): + + StackContext takes the function itself rather than its result:: + + with StackContext(my_context): + + The result of ``with StackContext() as cb:`` is a deactivation + callback. Run this callback when the StackContext is no longer + needed to ensure that it is not propagated any further (note that + deactivating a context does not affect any instances of that + context that are currently pending). This is an advanced feature + and not necessary in most applications. + """ + def __init__(self, context_factory): + self.context_factory = context_factory + self.contexts = [] + self.active = True + + def _deactivate(self): + self.active = False + + # StackContext protocol + def enter(self): + context = self.context_factory() + self.contexts.append(context) + context.__enter__() + + def exit(self, type, value, traceback): + context = self.contexts.pop() + context.__exit__(type, value, traceback) + + # Note that some of this code is duplicated in ExceptionStackContext + # below. ExceptionStackContext is more common and doesn't need + # the full generality of this class. + def __enter__(self): + self.old_contexts = _state.contexts + self.new_contexts = (self.old_contexts[0] + (self,), self) + _state.contexts = self.new_contexts + + try: + self.enter() + except: + _state.contexts = self.old_contexts + raise + + return self._deactivate + + def __exit__(self, type, value, traceback): + try: + self.exit(type, value, traceback) + finally: + final_contexts = _state.contexts + _state.contexts = self.old_contexts + + # Generator coroutines and with-statements with non-local + # effects interact badly. Check here for signs of + # the stack getting out of sync. + # Note that this check comes after restoring _state.context + # so that if it fails things are left in a (relatively) + # consistent state. + if final_contexts is not self.new_contexts: + raise StackContextInconsistentError( + 'stack_context inconsistency (may be caused by yield ' + 'within a "with StackContext" block)') + + # Break up a reference to itself to allow for faster GC on CPython. + self.new_contexts = None + + +class ExceptionStackContext(object): + """Specialization of StackContext for exception handling. + + The supplied ``exception_handler`` function will be called in the + event of an uncaught exception in this context. The semantics are + similar to a try/finally clause, and intended use cases are to log + an error, close a socket, or similar cleanup actions. The + ``exc_info`` triple ``(type, value, traceback)`` will be passed to the + exception_handler function. + + If the exception handler returns true, the exception will be + consumed and will not be propagated to other exception handlers. + """ + def __init__(self, exception_handler): + self.exception_handler = exception_handler + self.active = True + + def _deactivate(self): + self.active = False + + def exit(self, type, value, traceback): + if type is not None: + return self.exception_handler(type, value, traceback) + + def __enter__(self): + self.old_contexts = _state.contexts + self.new_contexts = (self.old_contexts[0], self) + _state.contexts = self.new_contexts + + return self._deactivate + + def __exit__(self, type, value, traceback): + try: + if type is not None: + return self.exception_handler(type, value, traceback) + finally: + final_contexts = _state.contexts + _state.contexts = self.old_contexts + + if final_contexts is not self.new_contexts: + raise StackContextInconsistentError( + 'stack_context inconsistency (may be caused by yield ' + 'within a "with StackContext" block)') + + # Break up a reference to itself to allow for faster GC on CPython. + self.new_contexts = None + + +class NullContext(object): + """Resets the `StackContext`. + + Useful when creating a shared resource on demand (e.g. an + `.AsyncHTTPClient`) where the stack that caused the creating is + not relevant to future operations. + """ + def __enter__(self): + self.old_contexts = _state.contexts + _state.contexts = (tuple(), None) + + def __exit__(self, type, value, traceback): + _state.contexts = self.old_contexts + + +def _remove_deactivated(contexts): + """Remove deactivated handlers from the chain""" + # Clean ctx handlers + stack_contexts = tuple([h for h in contexts[0] if h.active]) + + # Find new head + head = contexts[1] + while head is not None and not head.active: + head = head.old_contexts[1] + + # Process chain + ctx = head + while ctx is not None: + parent = ctx.old_contexts[1] + + while parent is not None: + if parent.active: + break + ctx.old_contexts = parent.old_contexts + parent = parent.old_contexts[1] + + ctx = parent + + return (stack_contexts, head) + + +def wrap(fn): + """Returns a callable object that will restore the current `StackContext` + when executed. + + Use this whenever saving a callback to be executed later in a + different execution context (either in a different thread or + asynchronously in the same thread). + """ + # Check if function is already wrapped + if fn is None or hasattr(fn, '_wrapped'): + return fn + + # Capture current stack head + # TODO: Any other better way to store contexts and update them in wrapped function? + cap_contexts = [_state.contexts] + + def wrapped(*args, **kwargs): + ret = None + try: + # Capture old state + current_state = _state.contexts + + # Remove deactivated items + cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0]) + + # Force new state + _state.contexts = contexts + + # Current exception + exc = (None, None, None) + top = None + + # Apply stack contexts + last_ctx = 0 + stack = contexts[0] + + # Apply state + for n in stack: + try: + n.enter() + last_ctx += 1 + except: + # Exception happened. Record exception info and store top-most handler + exc = sys.exc_info() + top = n.old_contexts[1] + + # Execute callback if no exception happened while restoring state + if top is None: + try: + ret = fn(*args, **kwargs) + except: + exc = sys.exc_info() + top = contexts[1] + + # If there was exception, try to handle it by going through the exception chain + if top is not None: + exc = _handle_exception(top, exc) + else: + # Otherwise take shorter path and run stack contexts in reverse order + while last_ctx > 0: + last_ctx -= 1 + c = stack[last_ctx] + + try: + c.exit(*exc) + except: + exc = sys.exc_info() + top = c.old_contexts[1] + break + else: + top = None + + # If if exception happened while unrolling, take longer exception handler path + if top is not None: + exc = _handle_exception(top, exc) + + # If exception was not handled, raise it + if exc != (None, None, None): + raise_exc_info(exc) + finally: + _state.contexts = current_state + return ret + + wrapped._wrapped = True + return wrapped + + +def _handle_exception(tail, exc): + while tail is not None: + try: + if tail.exit(*exc): + exc = (None, None, None) + except: + exc = sys.exc_info() + + tail = tail.old_contexts[1] + + return exc + + +def run_with_stack_context(context, func): + """Run a coroutine ``func`` in the given `StackContext`. + + It is not safe to have a ``yield`` statement within a ``with StackContext`` + block, so it is difficult to use stack context with `.gen.coroutine`. + This helper function runs the function in the correct context while + keeping the ``yield`` and ``with`` statements syntactically separate. + + Example:: + + @gen.coroutine + def incorrect(): + with StackContext(ctx): + # ERROR: this will raise StackContextInconsistentError + yield other_coroutine() + + @gen.coroutine + def correct(): + yield run_with_stack_context(StackContext(ctx), other_coroutine) + + .. versionadded:: 3.1 + """ + with context: + return func() diff --git a/src/console/zmq/eventloop/minitornado/util.py b/src/console/zmq/eventloop/minitornado/util.py new file mode 100755 index 00000000..c1e2eb95 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/util.py @@ -0,0 +1,184 @@ +"""Miscellaneous utility functions and classes. + +This module is used internally by Tornado. It is not necessarily expected +that the functions and classes defined here will be useful to other +applications, but they are documented here in case they are. + +The one public-facing part of this module is the `Configurable` class +and its `~Configurable.configure` method, which becomes a part of the +interface of its subclasses, including `.AsyncHTTPClient`, `.IOLoop`, +and `.Resolver`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import sys + + +def import_object(name): + """Imports an object by name. + + import_object('x') is equivalent to 'import x'. + import_object('x.y.z') is equivalent to 'from x.y import z'. + + >>> import tornado.escape + >>> import_object('tornado.escape') is tornado.escape + True + >>> import_object('tornado.escape.utf8') is tornado.escape.utf8 + True + >>> import_object('tornado') is tornado + True + >>> import_object('tornado.missing_module') + Traceback (most recent call last): + ... + ImportError: No module named missing_module + """ + if name.count('.') == 0: + return __import__(name, None, None) + + parts = name.split('.') + obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0) + try: + return getattr(obj, parts[-1]) + except AttributeError: + raise ImportError("No module named %s" % parts[-1]) + + +# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for +# literal strings, and alternative solutions like "from __future__ import +# unicode_literals" have other problems (see PEP 414). u() can be applied +# to ascii strings that include \u escapes (but they must not contain +# literal non-ascii characters). +if type('') is not type(b''): + def u(s): + return s + bytes_type = bytes + unicode_type = str + basestring_type = str +else: + def u(s): + return s.decode('unicode_escape') + bytes_type = str + unicode_type = unicode + basestring_type = basestring + + +if sys.version_info > (3,): + exec(""" +def raise_exc_info(exc_info): + raise exc_info[1].with_traceback(exc_info[2]) + +def exec_in(code, glob, loc=None): + if isinstance(code, str): + code = compile(code, '<string>', 'exec', dont_inherit=True) + exec(code, glob, loc) +""") +else: + exec(""" +def raise_exc_info(exc_info): + raise exc_info[0], exc_info[1], exc_info[2] + +def exec_in(code, glob, loc=None): + if isinstance(code, basestring): + # exec(string) inherits the caller's future imports; compile + # the string first to prevent that. + code = compile(code, '<string>', 'exec', dont_inherit=True) + exec code in glob, loc +""") + + +class Configurable(object): + """Base class for configurable interfaces. + + A configurable interface is an (abstract) class whose constructor + acts as a factory function for one of its implementation subclasses. + The implementation subclass as well as optional keyword arguments to + its initializer can be set globally at runtime with `configure`. + + By using the constructor as the factory method, the interface + looks like a normal class, `isinstance` works as usual, etc. This + pattern is most useful when the choice of implementation is likely + to be a global decision (e.g. when `~select.epoll` is available, + always use it instead of `~select.select`), or when a + previously-monolithic class has been split into specialized + subclasses. + + Configurable subclasses must define the class methods + `configurable_base` and `configurable_default`, and use the instance + method `initialize` instead of ``__init__``. + """ + __impl_class = None + __impl_kwargs = None + + def __new__(cls, **kwargs): + base = cls.configurable_base() + args = {} + if cls is base: + impl = cls.configured_class() + if base.__impl_kwargs: + args.update(base.__impl_kwargs) + else: + impl = cls + args.update(kwargs) + instance = super(Configurable, cls).__new__(impl) + # initialize vs __init__ chosen for compatiblity with AsyncHTTPClient + # singleton magic. If we get rid of that we can switch to __init__ + # here too. + instance.initialize(**args) + return instance + + @classmethod + def configurable_base(cls): + """Returns the base class of a configurable hierarchy. + + This will normally return the class in which it is defined. + (which is *not* necessarily the same as the cls classmethod parameter). + """ + raise NotImplementedError() + + @classmethod + def configurable_default(cls): + """Returns the implementation class to be used if none is configured.""" + raise NotImplementedError() + + def initialize(self): + """Initialize a `Configurable` subclass instance. + + Configurable classes should use `initialize` instead of ``__init__``. + """ + + @classmethod + def configure(cls, impl, **kwargs): + """Sets the class to use when the base class is instantiated. + + Keyword arguments will be saved and added to the arguments passed + to the constructor. This can be used to set global defaults for + some parameters. + """ + base = cls.configurable_base() + if isinstance(impl, (unicode_type, bytes_type)): + impl = import_object(impl) + if impl is not None and not issubclass(impl, cls): + raise ValueError("Invalid subclass of %s" % cls) + base.__impl_class = impl + base.__impl_kwargs = kwargs + + @classmethod + def configured_class(cls): + """Returns the currently configured class.""" + base = cls.configurable_base() + if cls.__impl_class is None: + base.__impl_class = cls.configurable_default() + return base.__impl_class + + @classmethod + def _save_configuration(cls): + base = cls.configurable_base() + return (base.__impl_class, base.__impl_kwargs) + + @classmethod + def _restore_configuration(cls, saved): + base = cls.configurable_base() + base.__impl_class = saved[0] + base.__impl_kwargs = saved[1] + diff --git a/src/console/zmq/eventloop/zmqstream.py b/src/console/zmq/eventloop/zmqstream.py new file mode 100755 index 00000000..86a97e44 --- /dev/null +++ b/src/console/zmq/eventloop/zmqstream.py @@ -0,0 +1,529 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""A utility class to send to and recv from a non-blocking socket.""" + +from __future__ import with_statement + +import sys + +import zmq +from zmq.utils import jsonapi + +try: + import cPickle as pickle +except ImportError: + import pickle + +from .ioloop import IOLoop + +try: + # gen_log will only import from >= 3.0 + from tornado.log import gen_log + from tornado import stack_context +except ImportError: + from .minitornado.log import gen_log + from .minitornado import stack_context + +try: + from queue import Queue +except ImportError: + from Queue import Queue + +from zmq.utils.strtypes import bytes, unicode, basestring + +try: + callable +except NameError: + callable = lambda obj: hasattr(obj, '__call__') + + +class ZMQStream(object): + """A utility class to register callbacks when a zmq socket sends and receives + + For use with zmq.eventloop.ioloop + + There are three main methods + + Methods: + + * **on_recv(callback, copy=True):** + register a callback to be run every time the socket has something to receive + * **on_send(callback):** + register a callback to be run every time you call send + * **send(self, msg, flags=0, copy=False, callback=None):** + perform a send that will trigger the callback + if callback is passed, on_send is also called. + + There are also send_multipart(), send_json(), send_pyobj() + + Three other methods for deactivating the callbacks: + + * **stop_on_recv():** + turn off the recv callback + * **stop_on_send():** + turn off the send callback + + which simply call ``on_<evt>(None)``. + + The entire socket interface, excluding direct recv methods, is also + provided, primarily through direct-linking the methods. + e.g. + + >>> stream.bind is stream.socket.bind + True + + """ + + socket = None + io_loop = None + poller = None + + def __init__(self, socket, io_loop=None): + self.socket = socket + self.io_loop = io_loop or IOLoop.instance() + self.poller = zmq.Poller() + + self._send_queue = Queue() + self._recv_callback = None + self._send_callback = None + self._close_callback = None + self._recv_copy = False + self._flushed = False + + self._state = self.io_loop.ERROR + self._init_io_state() + + # shortcircuit some socket methods + self.bind = self.socket.bind + self.bind_to_random_port = self.socket.bind_to_random_port + self.connect = self.socket.connect + self.setsockopt = self.socket.setsockopt + self.getsockopt = self.socket.getsockopt + self.setsockopt_string = self.socket.setsockopt_string + self.getsockopt_string = self.socket.getsockopt_string + self.setsockopt_unicode = self.socket.setsockopt_unicode + self.getsockopt_unicode = self.socket.getsockopt_unicode + + + def stop_on_recv(self): + """Disable callback and automatic receiving.""" + return self.on_recv(None) + + def stop_on_send(self): + """Disable callback on sending.""" + return self.on_send(None) + + def stop_on_err(self): + """DEPRECATED, does nothing""" + gen_log.warn("on_err does nothing, and will be removed") + + def on_err(self, callback): + """DEPRECATED, does nothing""" + gen_log.warn("on_err does nothing, and will be removed") + + def on_recv(self, callback, copy=True): + """Register a callback for when a message is ready to recv. + + There can be only one callback registered at a time, so each + call to `on_recv` replaces previously registered callbacks. + + on_recv(None) disables recv event polling. + + Use on_recv_stream(callback) instead, to register a callback that will receive + both this ZMQStream and the message, instead of just the message. + + Parameters + ---------- + + callback : callable + callback must take exactly one argument, which will be a + list, as returned by socket.recv_multipart() + if callback is None, recv callbacks are disabled. + copy : bool + copy is passed directly to recv, so if copy is False, + callback will receive Message objects. If copy is True, + then callback will receive bytes/str objects. + + Returns : None + """ + + self._check_closed() + assert callback is None or callable(callback) + self._recv_callback = stack_context.wrap(callback) + self._recv_copy = copy + if callback is None: + self._drop_io_state(self.io_loop.READ) + else: + self._add_io_state(self.io_loop.READ) + + def on_recv_stream(self, callback, copy=True): + """Same as on_recv, but callback will get this stream as first argument + + callback must take exactly two arguments, as it will be called as:: + + callback(stream, msg) + + Useful when a single callback should be used with multiple streams. + """ + if callback is None: + self.stop_on_recv() + else: + self.on_recv(lambda msg: callback(self, msg), copy=copy) + + def on_send(self, callback): + """Register a callback to be called on each send + + There will be two arguments:: + + callback(msg, status) + + * `msg` will be the list of sendable objects that was just sent + * `status` will be the return result of socket.send_multipart(msg) - + MessageTracker or None. + + Non-copying sends return a MessageTracker object whose + `done` attribute will be True when the send is complete. + This allows users to track when an object is safe to write to + again. + + The second argument will always be None if copy=True + on the send. + + Use on_send_stream(callback) to register a callback that will be passed + this ZMQStream as the first argument, in addition to the other two. + + on_send(None) disables recv event polling. + + Parameters + ---------- + + callback : callable + callback must take exactly two arguments, which will be + the message being sent (always a list), + and the return result of socket.send_multipart(msg) - + MessageTracker or None. + + if callback is None, send callbacks are disabled. + """ + + self._check_closed() + assert callback is None or callable(callback) + self._send_callback = stack_context.wrap(callback) + + + def on_send_stream(self, callback): + """Same as on_send, but callback will get this stream as first argument + + Callback will be passed three arguments:: + + callback(stream, msg, status) + + Useful when a single callback should be used with multiple streams. + """ + if callback is None: + self.stop_on_send() + else: + self.on_send(lambda msg, status: callback(self, msg, status)) + + + def send(self, msg, flags=0, copy=True, track=False, callback=None): + """Send a message, optionally also register a new callback for sends. + See zmq.socket.send for details. + """ + return self.send_multipart([msg], flags=flags, copy=copy, track=track, callback=callback) + + def send_multipart(self, msg, flags=0, copy=True, track=False, callback=None): + """Send a multipart message, optionally also register a new callback for sends. + See zmq.socket.send_multipart for details. + """ + kwargs = dict(flags=flags, copy=copy, track=track) + self._send_queue.put((msg, kwargs)) + callback = callback or self._send_callback + if callback is not None: + self.on_send(callback) + else: + # noop callback + self.on_send(lambda *args: None) + self._add_io_state(self.io_loop.WRITE) + + def send_string(self, u, flags=0, encoding='utf-8', callback=None): + """Send a unicode message with an encoding. + See zmq.socket.send_unicode for details. + """ + if not isinstance(u, basestring): + raise TypeError("unicode/str objects only") + return self.send(u.encode(encoding), flags=flags, callback=callback) + + send_unicode = send_string + + def send_json(self, obj, flags=0, callback=None): + """Send json-serialized version of an object. + See zmq.socket.send_json for details. + """ + if jsonapi is None: + raise ImportError('jsonlib{1,2}, json or simplejson library is required.') + else: + msg = jsonapi.dumps(obj) + return self.send(msg, flags=flags, callback=callback) + + def send_pyobj(self, obj, flags=0, protocol=-1, callback=None): + """Send a Python object as a message using pickle to serialize. + + See zmq.socket.send_json for details. + """ + msg = pickle.dumps(obj, protocol) + return self.send(msg, flags, callback=callback) + + def _finish_flush(self): + """callback for unsetting _flushed flag.""" + self._flushed = False + + def flush(self, flag=zmq.POLLIN|zmq.POLLOUT, limit=None): + """Flush pending messages. + + This method safely handles all pending incoming and/or outgoing messages, + bypassing the inner loop, passing them to the registered callbacks. + + A limit can be specified, to prevent blocking under high load. + + flush will return the first time ANY of these conditions are met: + * No more events matching the flag are pending. + * the total number of events handled reaches the limit. + + Note that if ``flag|POLLIN != 0``, recv events will be flushed even if no callback + is registered, unlike normal IOLoop operation. This allows flush to be + used to remove *and ignore* incoming messages. + + Parameters + ---------- + flag : int, default=POLLIN|POLLOUT + 0MQ poll flags. + If flag|POLLIN, recv events will be flushed. + If flag|POLLOUT, send events will be flushed. + Both flags can be set at once, which is the default. + limit : None or int, optional + The maximum number of messages to send or receive. + Both send and recv count against this limit. + + Returns + ------- + int : count of events handled (both send and recv) + """ + self._check_closed() + # unset self._flushed, so callbacks will execute, in case flush has + # already been called this iteration + already_flushed = self._flushed + self._flushed = False + # initialize counters + count = 0 + def update_flag(): + """Update the poll flag, to prevent registering POLLOUT events + if we don't have pending sends.""" + return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT) + flag = update_flag() + if not flag: + # nothing to do + return 0 + self.poller.register(self.socket, flag) + events = self.poller.poll(0) + while events and (not limit or count < limit): + s,event = events[0] + if event & zmq.POLLIN: # receiving + self._handle_recv() + count += 1 + if self.socket is None: + # break if socket was closed during callback + break + if event & zmq.POLLOUT and self.sending(): + self._handle_send() + count += 1 + if self.socket is None: + # break if socket was closed during callback + break + + flag = update_flag() + if flag: + self.poller.register(self.socket, flag) + events = self.poller.poll(0) + else: + events = [] + if count: # only bypass loop if we actually flushed something + # skip send/recv callbacks this iteration + self._flushed = True + # reregister them at the end of the loop + if not already_flushed: # don't need to do it again + self.io_loop.add_callback(self._finish_flush) + elif already_flushed: + self._flushed = True + + # update ioloop poll state, which may have changed + self._rebuild_io_state() + return count + + def set_close_callback(self, callback): + """Call the given callback when the stream is closed.""" + self._close_callback = stack_context.wrap(callback) + + def close(self, linger=None): + """Close this stream.""" + if self.socket is not None: + self.io_loop.remove_handler(self.socket) + self.socket.close(linger) + self.socket = None + if self._close_callback: + self._run_callback(self._close_callback) + + def receiving(self): + """Returns True if we are currently receiving from the stream.""" + return self._recv_callback is not None + + def sending(self): + """Returns True if we are currently sending to the stream.""" + return not self._send_queue.empty() + + def closed(self): + return self.socket is None + + def _run_callback(self, callback, *args, **kwargs): + """Wrap running callbacks in try/except to allow us to + close our socket.""" + try: + # Use a NullContext to ensure that all StackContexts are run + # inside our blanket exception handler rather than outside. + with stack_context.NullContext(): + callback(*args, **kwargs) + except: + gen_log.error("Uncaught exception, closing connection.", + exc_info=True) + # Close the socket on an uncaught exception from a user callback + # (It would eventually get closed when the socket object is + # gc'd, but we don't want to rely on gc happening before we + # run out of file descriptors) + self.close() + # Re-raise the exception so that IOLoop.handle_callback_exception + # can see it and log the error + raise + + def _handle_events(self, fd, events): + """This method is the actual handler for IOLoop, that gets called whenever + an event on my socket is posted. It dispatches to _handle_recv, etc.""" + # print "handling events" + if not self.socket: + gen_log.warning("Got events for closed stream %s", fd) + return + try: + # dispatch events: + if events & IOLoop.ERROR: + gen_log.error("got POLLERR event on ZMQStream, which doesn't make sense") + return + if events & IOLoop.READ: + self._handle_recv() + if not self.socket: + return + if events & IOLoop.WRITE: + self._handle_send() + if not self.socket: + return + + # rebuild the poll state + self._rebuild_io_state() + except: + gen_log.error("Uncaught exception, closing connection.", + exc_info=True) + self.close() + raise + + def _handle_recv(self): + """Handle a recv event.""" + if self._flushed: + return + try: + msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy) + except zmq.ZMQError as e: + if e.errno == zmq.EAGAIN: + # state changed since poll event + pass + else: + gen_log.error("RECV Error: %s"%zmq.strerror(e.errno)) + else: + if self._recv_callback: + callback = self._recv_callback + # self._recv_callback = None + self._run_callback(callback, msg) + + # self.update_state() + + + def _handle_send(self): + """Handle a send event.""" + if self._flushed: + return + if not self.sending(): + gen_log.error("Shouldn't have handled a send event") + return + + msg, kwargs = self._send_queue.get() + try: + status = self.socket.send_multipart(msg, **kwargs) + except zmq.ZMQError as e: + gen_log.error("SEND Error: %s", e) + status = e + if self._send_callback: + callback = self._send_callback + self._run_callback(callback, msg, status) + + # self.update_state() + + def _check_closed(self): + if not self.socket: + raise IOError("Stream is closed") + + def _rebuild_io_state(self): + """rebuild io state based on self.sending() and receiving()""" + if self.socket is None: + return + state = self.io_loop.ERROR + if self.receiving(): + state |= self.io_loop.READ + if self.sending(): + state |= self.io_loop.WRITE + if state != self._state: + self._state = state + self._update_handler(state) + + def _add_io_state(self, state): + """Add io_state to poller.""" + if not self._state & state: + self._state = self._state | state + self._update_handler(self._state) + + def _drop_io_state(self, state): + """Stop poller from watching an io_state.""" + if self._state & state: + self._state = self._state & (~state) + self._update_handler(self._state) + + def _update_handler(self, state): + """Update IOLoop handler with state.""" + if self.socket is None: + return + self.io_loop.update_handler(self.socket, state) + + def _init_io_state(self): + """initialize the ioloop event handler""" + with stack_context.NullContext(): + self.io_loop.add_handler(self.socket, self._handle_events, self._state) + diff --git a/src/console/zmq/green/__init__.py b/src/console/zmq/green/__init__.py new file mode 100755 index 00000000..ff7e5965 --- /dev/null +++ b/src/console/zmq/green/__init__.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +#----------------------------------------------------------------------------- +# Copyright (C) 2011-2012 Travis Cline +# +# This file is part of pyzmq +# It is adapted from upstream project zeromq_gevent under the New BSD License +# +# Distributed under the terms of the New BSD License. The full license is in +# the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + +"""zmq.green - gevent compatibility with zeromq. + +Usage +----- + +Instead of importing zmq directly, do so in the following manner: + +.. + + import zmq.green as zmq + + +Any calls that would have blocked the current thread will now only block the +current green thread. + +This compatibility is accomplished by ensuring the nonblocking flag is set +before any blocking operation and the ØMQ file descriptor is polled internally +to trigger needed events. +""" + +from zmq import * +from zmq.green.core import _Context, _Socket +from zmq.green.poll import _Poller +Context = _Context +Socket = _Socket +Poller = _Poller + +from zmq.green.device import device + diff --git a/src/console/zmq/green/core.py b/src/console/zmq/green/core.py new file mode 100755 index 00000000..9fc73e32 --- /dev/null +++ b/src/console/zmq/green/core.py @@ -0,0 +1,287 @@ +#----------------------------------------------------------------------------- +# Copyright (C) 2011-2012 Travis Cline +# +# This file is part of pyzmq +# It is adapted from upstream project zeromq_gevent under the New BSD License +# +# Distributed under the terms of the New BSD License. The full license is in +# the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + +"""This module wraps the :class:`Socket` and :class:`Context` found in :mod:`pyzmq <zmq>` to be non blocking +""" + +from __future__ import print_function + +import sys +import time +import warnings + +import zmq + +from zmq import Context as _original_Context +from zmq import Socket as _original_Socket +from .poll import _Poller + +import gevent +from gevent.event import AsyncResult +from gevent.hub import get_hub + +if hasattr(zmq, 'RCVTIMEO'): + TIMEOS = (zmq.RCVTIMEO, zmq.SNDTIMEO) +else: + TIMEOS = () + +def _stop(evt): + """simple wrapper for stopping an Event, allowing for method rename in gevent 1.0""" + try: + evt.stop() + except AttributeError as e: + # gevent<1.0 compat + evt.cancel() + +class _Socket(_original_Socket): + """Green version of :class:`zmq.Socket` + + The following methods are overridden: + + * send + * recv + + To ensure that the ``zmq.NOBLOCK`` flag is set and that sending or receiving + is deferred to the hub if a ``zmq.EAGAIN`` (retry) error is raised. + + The `__state_changed` method is triggered when the zmq.FD for the socket is + marked as readable and triggers the necessary read and write events (which + are waited for in the recv and send methods). + + Some double underscore prefixes are used to minimize pollution of + :class:`zmq.Socket`'s namespace. + """ + __in_send_multipart = False + __in_recv_multipart = False + __writable = None + __readable = None + _state_event = None + _gevent_bug_timeout = 11.6 # timeout for not trusting gevent + _debug_gevent = False # turn on if you think gevent is missing events + _poller_class = _Poller + + def __init__(self, context, socket_type): + _original_Socket.__init__(self, context, socket_type) + self.__in_send_multipart = False + self.__in_recv_multipart = False + self.__setup_events() + + + def __del__(self): + self.close() + + def close(self, linger=None): + super(_Socket, self).close(linger) + self.__cleanup_events() + + def __cleanup_events(self): + # close the _state_event event, keeps the number of active file descriptors down + if getattr(self, '_state_event', None): + _stop(self._state_event) + self._state_event = None + # if the socket has entered a close state resume any waiting greenlets + self.__writable.set() + self.__readable.set() + + def __setup_events(self): + self.__readable = AsyncResult() + self.__writable = AsyncResult() + self.__readable.set() + self.__writable.set() + + try: + self._state_event = get_hub().loop.io(self.getsockopt(zmq.FD), 1) # read state watcher + self._state_event.start(self.__state_changed) + except AttributeError: + # for gevent<1.0 compatibility + from gevent.core import read_event + self._state_event = read_event(self.getsockopt(zmq.FD), self.__state_changed, persist=True) + + def __state_changed(self, event=None, _evtype=None): + if self.closed: + self.__cleanup_events() + return + try: + # avoid triggering __state_changed from inside __state_changed + events = super(_Socket, self).getsockopt(zmq.EVENTS) + except zmq.ZMQError as exc: + self.__writable.set_exception(exc) + self.__readable.set_exception(exc) + else: + if events & zmq.POLLOUT: + self.__writable.set() + if events & zmq.POLLIN: + self.__readable.set() + + def _wait_write(self): + assert self.__writable.ready(), "Only one greenlet can be waiting on this event" + self.__writable = AsyncResult() + # timeout is because libzmq cannot be trusted to properly signal a new send event: + # this is effectively a maximum poll interval of 1s + tic = time.time() + dt = self._gevent_bug_timeout + if dt: + timeout = gevent.Timeout(seconds=dt) + else: + timeout = None + try: + if timeout: + timeout.start() + self.__writable.get(block=True) + except gevent.Timeout as t: + if t is not timeout: + raise + toc = time.time() + # gevent bug: get can raise timeout even on clean return + # don't display zmq bug warning for gevent bug (this is getting ridiculous) + if self._debug_gevent and timeout and toc-tic > dt and \ + self.getsockopt(zmq.EVENTS) & zmq.POLLOUT: + print("BUG: gevent may have missed a libzmq send event on %i!" % self.FD, file=sys.stderr) + finally: + if timeout: + timeout.cancel() + self.__writable.set() + + def _wait_read(self): + assert self.__readable.ready(), "Only one greenlet can be waiting on this event" + self.__readable = AsyncResult() + # timeout is because libzmq cannot always be trusted to play nice with libevent. + # I can only confirm that this actually happens for send, but lets be symmetrical + # with our dirty hacks. + # this is effectively a maximum poll interval of 1s + tic = time.time() + dt = self._gevent_bug_timeout + if dt: + timeout = gevent.Timeout(seconds=dt) + else: + timeout = None + try: + if timeout: + timeout.start() + self.__readable.get(block=True) + except gevent.Timeout as t: + if t is not timeout: + raise + toc = time.time() + # gevent bug: get can raise timeout even on clean return + # don't display zmq bug warning for gevent bug (this is getting ridiculous) + if self._debug_gevent and timeout and toc-tic > dt and \ + self.getsockopt(zmq.EVENTS) & zmq.POLLIN: + print("BUG: gevent may have missed a libzmq recv event on %i!" % self.FD, file=sys.stderr) + finally: + if timeout: + timeout.cancel() + self.__readable.set() + + def send(self, data, flags=0, copy=True, track=False): + """send, which will only block current greenlet + + state_changed always fires exactly once (success or fail) at the + end of this method. + """ + + # if we're given the NOBLOCK flag act as normal and let the EAGAIN get raised + if flags & zmq.NOBLOCK: + try: + msg = super(_Socket, self).send(data, flags, copy, track) + finally: + if not self.__in_send_multipart: + self.__state_changed() + return msg + # ensure the zmq.NOBLOCK flag is part of flags + flags |= zmq.NOBLOCK + while True: # Attempt to complete this operation indefinitely, blocking the current greenlet + try: + # attempt the actual call + msg = super(_Socket, self).send(data, flags, copy, track) + except zmq.ZMQError as e: + # if the raised ZMQError is not EAGAIN, reraise + if e.errno != zmq.EAGAIN: + if not self.__in_send_multipart: + self.__state_changed() + raise + else: + if not self.__in_send_multipart: + self.__state_changed() + return msg + # defer to the event loop until we're notified the socket is writable + self._wait_write() + + def recv(self, flags=0, copy=True, track=False): + """recv, which will only block current greenlet + + state_changed always fires exactly once (success or fail) at the + end of this method. + """ + if flags & zmq.NOBLOCK: + try: + msg = super(_Socket, self).recv(flags, copy, track) + finally: + if not self.__in_recv_multipart: + self.__state_changed() + return msg + + flags |= zmq.NOBLOCK + while True: + try: + msg = super(_Socket, self).recv(flags, copy, track) + except zmq.ZMQError as e: + if e.errno != zmq.EAGAIN: + if not self.__in_recv_multipart: + self.__state_changed() + raise + else: + if not self.__in_recv_multipart: + self.__state_changed() + return msg + self._wait_read() + + def send_multipart(self, *args, **kwargs): + """wrap send_multipart to prevent state_changed on each partial send""" + self.__in_send_multipart = True + try: + msg = super(_Socket, self).send_multipart(*args, **kwargs) + finally: + self.__in_send_multipart = False + self.__state_changed() + return msg + + def recv_multipart(self, *args, **kwargs): + """wrap recv_multipart to prevent state_changed on each partial recv""" + self.__in_recv_multipart = True + try: + msg = super(_Socket, self).recv_multipart(*args, **kwargs) + finally: + self.__in_recv_multipart = False + self.__state_changed() + return msg + + def get(self, opt): + """trigger state_changed on getsockopt(EVENTS)""" + if opt in TIMEOS: + warnings.warn("TIMEO socket options have no effect in zmq.green", UserWarning) + optval = super(_Socket, self).get(opt) + if opt == zmq.EVENTS: + self.__state_changed() + return optval + + def set(self, opt, val): + """set socket option""" + if opt in TIMEOS: + warnings.warn("TIMEO socket options have no effect in zmq.green", UserWarning) + return super(_Socket, self).set(opt, val) + + +class _Context(_original_Context): + """Replacement for :class:`zmq.Context` + + Ensures that the greened Socket above is used in calls to `socket`. + """ + _socket_class = _Socket diff --git a/src/console/zmq/green/device.py b/src/console/zmq/green/device.py new file mode 100755 index 00000000..4b070237 --- /dev/null +++ b/src/console/zmq/green/device.py @@ -0,0 +1,32 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq +from zmq.green import Poller + +def device(device_type, isocket, osocket): + """Start a zeromq device (gevent-compatible). + + Unlike the true zmq.device, this does not release the GIL. + + Parameters + ---------- + device_type : (QUEUE, FORWARDER, STREAMER) + The type of device to start (ignored). + isocket : Socket + The Socket instance for the incoming traffic. + osocket : Socket + The Socket instance for the outbound traffic. + """ + p = Poller() + if osocket == -1: + osocket = isocket + p.register(isocket, zmq.POLLIN) + p.register(osocket, zmq.POLLIN) + + while True: + events = dict(p.poll()) + if isocket in events: + osocket.send_multipart(isocket.recv_multipart()) + if osocket in events: + isocket.send_multipart(osocket.recv_multipart()) diff --git a/src/console/zmq/green/eventloop/__init__.py b/src/console/zmq/green/eventloop/__init__.py new file mode 100755 index 00000000..c5150efe --- /dev/null +++ b/src/console/zmq/green/eventloop/__init__.py @@ -0,0 +1,3 @@ +from zmq.green.eventloop.ioloop import IOLoop + +__all__ = ['IOLoop']
\ No newline at end of file diff --git a/src/console/zmq/green/eventloop/ioloop.py b/src/console/zmq/green/eventloop/ioloop.py new file mode 100755 index 00000000..e12fd5e9 --- /dev/null +++ b/src/console/zmq/green/eventloop/ioloop.py @@ -0,0 +1,33 @@ +from zmq.eventloop.ioloop import * +from zmq.green import Poller + +RealIOLoop = IOLoop +RealZMQPoller = ZMQPoller + +class IOLoop(RealIOLoop): + + def initialize(self, impl=None): + impl = _poll() if impl is None else impl + super(IOLoop, self).initialize(impl) + + @staticmethod + def instance(): + """Returns a global `IOLoop` instance. + + Most applications have a single, global `IOLoop` running on the + main thread. Use this method to get this instance from + another thread. To get the current thread's `IOLoop`, use `current()`. + """ + # install this class as the active IOLoop implementation + # when using tornado 3 + if tornado_version >= (3,): + PollIOLoop.configure(IOLoop) + return PollIOLoop.instance() + + +class ZMQPoller(RealZMQPoller): + """gevent-compatible version of ioloop.ZMQPoller""" + def __init__(self): + self._poller = Poller() + +_poll = ZMQPoller diff --git a/src/console/zmq/green/eventloop/zmqstream.py b/src/console/zmq/green/eventloop/zmqstream.py new file mode 100755 index 00000000..90fbd1f5 --- /dev/null +++ b/src/console/zmq/green/eventloop/zmqstream.py @@ -0,0 +1,11 @@ +from zmq.eventloop.zmqstream import * + +from zmq.green.eventloop.ioloop import IOLoop + +RealZMQStream = ZMQStream + +class ZMQStream(RealZMQStream): + + def __init__(self, socket, io_loop=None): + io_loop = io_loop or IOLoop.instance() + super(ZMQStream, self).__init__(socket, io_loop=io_loop) diff --git a/src/console/zmq/green/poll.py b/src/console/zmq/green/poll.py new file mode 100755 index 00000000..8f016129 --- /dev/null +++ b/src/console/zmq/green/poll.py @@ -0,0 +1,95 @@ +import zmq +import gevent +from gevent import select + +from zmq import Poller as _original_Poller + + +class _Poller(_original_Poller): + """Replacement for :class:`zmq.Poller` + + Ensures that the greened Poller below is used in calls to + :meth:`zmq.Poller.poll`. + """ + _gevent_bug_timeout = 1.33 # minimum poll interval, for working around gevent bug + + def _get_descriptors(self): + """Returns three elements tuple with socket descriptors ready + for gevent.select.select + """ + rlist = [] + wlist = [] + xlist = [] + + for socket, flags in self.sockets: + if isinstance(socket, zmq.Socket): + rlist.append(socket.getsockopt(zmq.FD)) + continue + elif isinstance(socket, int): + fd = socket + elif hasattr(socket, 'fileno'): + try: + fd = int(socket.fileno()) + except: + raise ValueError('fileno() must return an valid integer fd') + else: + raise TypeError('Socket must be a 0MQ socket, an integer fd ' + 'or have a fileno() method: %r' % socket) + + if flags & zmq.POLLIN: + rlist.append(fd) + if flags & zmq.POLLOUT: + wlist.append(fd) + if flags & zmq.POLLERR: + xlist.append(fd) + + return (rlist, wlist, xlist) + + def poll(self, timeout=-1): + """Overridden method to ensure that the green version of + Poller is used. + + Behaves the same as :meth:`zmq.core.Poller.poll` + """ + + if timeout is None: + timeout = -1 + + if timeout < 0: + timeout = -1 + + rlist = None + wlist = None + xlist = None + + if timeout > 0: + tout = gevent.Timeout.start_new(timeout/1000.0) + + try: + # Loop until timeout or events available + rlist, wlist, xlist = self._get_descriptors() + while True: + events = super(_Poller, self).poll(0) + if events or timeout == 0: + return events + + # wait for activity on sockets in a green way + # set a minimum poll frequency, + # because gevent < 1.0 cannot be trusted to catch edge-triggered FD events + _bug_timeout = gevent.Timeout.start_new(self._gevent_bug_timeout) + try: + select.select(rlist, wlist, xlist) + except gevent.Timeout as t: + if t is not _bug_timeout: + raise + finally: + _bug_timeout.cancel() + + except gevent.Timeout as t: + if t is not tout: + raise + return [] + finally: + if timeout > 0: + tout.cancel() + diff --git a/src/console/zmq/log/__init__.py b/src/console/zmq/log/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/console/zmq/log/__init__.py diff --git a/src/console/zmq/log/handlers.py b/src/console/zmq/log/handlers.py new file mode 100755 index 00000000..5ff21bf3 --- /dev/null +++ b/src/console/zmq/log/handlers.py @@ -0,0 +1,146 @@ +"""pyzmq logging handlers. + +This mainly defines the PUBHandler object for publishing logging messages over +a zmq.PUB socket. + +The PUBHandler can be used with the regular logging module, as in:: + + >>> import logging + >>> handler = PUBHandler('tcp://127.0.0.1:12345') + >>> handler.root_topic = 'foo' + >>> logger = logging.getLogger('foobar') + >>> logger.setLevel(logging.DEBUG) + >>> logger.addHandler(handler) + +After this point, all messages logged by ``logger`` will be published on the +PUB socket. + +Code adapted from StarCluster: + + http://github.com/jtriley/StarCluster/blob/master/starcluster/logger.py +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import logging +from logging import INFO, DEBUG, WARN, ERROR, FATAL + +import zmq +from zmq.utils.strtypes import bytes, unicode, cast_bytes + + +TOPIC_DELIM="::" # delimiter for splitting topics on the receiving end. + + +class PUBHandler(logging.Handler): + """A basic logging handler that emits log messages through a PUB socket. + + Takes a PUB socket already bound to interfaces or an interface to bind to. + + Example:: + + sock = context.socket(zmq.PUB) + sock.bind('inproc://log') + handler = PUBHandler(sock) + + Or:: + + handler = PUBHandler('inproc://loc') + + These are equivalent. + + Log messages handled by this handler are broadcast with ZMQ topics + ``this.root_topic`` comes first, followed by the log level + (DEBUG,INFO,etc.), followed by any additional subtopics specified in the + message by: log.debug("subtopic.subsub::the real message") + """ + root_topic="" + socket = None + + formatters = { + logging.DEBUG: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"), + logging.INFO: logging.Formatter("%(message)s\n"), + logging.WARN: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"), + logging.ERROR: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s - %(exc_info)s\n"), + logging.CRITICAL: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n")} + + def __init__(self, interface_or_socket, context=None): + logging.Handler.__init__(self) + if isinstance(interface_or_socket, zmq.Socket): + self.socket = interface_or_socket + self.ctx = self.socket.context + else: + self.ctx = context or zmq.Context() + self.socket = self.ctx.socket(zmq.PUB) + self.socket.bind(interface_or_socket) + + def format(self,record): + """Format a record.""" + return self.formatters[record.levelno].format(record) + + def emit(self, record): + """Emit a log message on my socket.""" + try: + topic, record.msg = record.msg.split(TOPIC_DELIM,1) + except Exception: + topic = "" + try: + bmsg = cast_bytes(self.format(record)) + except Exception: + self.handleError(record) + return + + topic_list = [] + + if self.root_topic: + topic_list.append(self.root_topic) + + topic_list.append(record.levelname) + + if topic: + topic_list.append(topic) + + btopic = b'.'.join(cast_bytes(t) for t in topic_list) + + self.socket.send_multipart([btopic, bmsg]) + + +class TopicLogger(logging.Logger): + """A simple wrapper that takes an additional argument to log methods. + + All the regular methods exist, but instead of one msg argument, two + arguments: topic, msg are passed. + + That is:: + + logger.debug('msg') + + Would become:: + + logger.debug('topic.sub', 'msg') + """ + def log(self, level, topic, msg, *args, **kwargs): + """Log 'msg % args' with level and topic. + + To pass exception information, use the keyword argument exc_info + with a True value:: + + logger.log(level, "zmq.fun", "We have a %s", + "mysterious problem", exc_info=1) + """ + logging.Logger.log(self, level, '%s::%s'%(topic,msg), *args, **kwargs) + +# Generate the methods of TopicLogger, since they are just adding a +# topic prefix to a message. +for name in "debug warn warning error critical fatal".split(): + meth = getattr(logging.Logger,name) + setattr(TopicLogger, name, + lambda self, level, topic, msg, *args, **kwargs: + meth(self, level, topic+TOPIC_DELIM+msg,*args, **kwargs)) + diff --git a/src/console/zmq/ssh/__init__.py b/src/console/zmq/ssh/__init__.py new file mode 100755 index 00000000..57f09568 --- /dev/null +++ b/src/console/zmq/ssh/__init__.py @@ -0,0 +1 @@ +from zmq.ssh.tunnel import * diff --git a/src/console/zmq/ssh/forward.py b/src/console/zmq/ssh/forward.py new file mode 100755 index 00000000..2d619462 --- /dev/null +++ b/src/console/zmq/ssh/forward.py @@ -0,0 +1,91 @@ +# +# This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1. +# Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com> +# Edits Copyright (C) 2010 The IPython Team +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA. + +""" +Sample script showing how to do local port forwarding over paramiko. + +This script connects to the requested SSH server and sets up local port +forwarding (the openssh -L option) from a local port through a tunneled +connection to a destination reachable from the SSH server machine. +""" + +from __future__ import print_function + +import logging +import select +try: # Python 3 + import socketserver +except ImportError: # Python 2 + import SocketServer as socketserver + +logger = logging.getLogger('ssh') + +class ForwardServer (socketserver.ThreadingTCPServer): + daemon_threads = True + allow_reuse_address = True + + +class Handler (socketserver.BaseRequestHandler): + + def handle(self): + try: + chan = self.ssh_transport.open_channel('direct-tcpip', + (self.chain_host, self.chain_port), + self.request.getpeername()) + except Exception as e: + logger.debug('Incoming request to %s:%d failed: %s' % (self.chain_host, + self.chain_port, + repr(e))) + return + if chan is None: + logger.debug('Incoming request to %s:%d was rejected by the SSH server.' % + (self.chain_host, self.chain_port)) + return + + logger.debug('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(), + chan.getpeername(), (self.chain_host, self.chain_port))) + while True: + r, w, x = select.select([self.request, chan], [], []) + if self.request in r: + data = self.request.recv(1024) + if len(data) == 0: + break + chan.send(data) + if chan in r: + data = chan.recv(1024) + if len(data) == 0: + break + self.request.send(data) + chan.close() + self.request.close() + logger.debug('Tunnel closed ') + + +def forward_tunnel(local_port, remote_host, remote_port, transport): + # this is a little convoluted, but lets me configure things for the Handler + # object. (SocketServer doesn't give Handlers any way to access the outer + # server normally.) + class SubHander (Handler): + chain_host = remote_host + chain_port = remote_port + ssh_transport = transport + ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever() + + +__all__ = ['forward_tunnel'] diff --git a/src/console/zmq/ssh/tunnel.py b/src/console/zmq/ssh/tunnel.py new file mode 100755 index 00000000..5a0c5433 --- /dev/null +++ b/src/console/zmq/ssh/tunnel.py @@ -0,0 +1,376 @@ +"""Basic ssh tunnel utilities, and convenience functions for tunneling +zeromq connections. +""" + +# Copyright (C) 2010-2011 IPython Development Team +# Copyright (C) 2011- PyZMQ Developers +# +# Redistributed from IPython under the terms of the BSD License. + + +from __future__ import print_function + +import atexit +import os +import signal +import socket +import sys +import warnings +from getpass import getpass, getuser +from multiprocessing import Process + +try: + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + import paramiko + SSHException = paramiko.ssh_exception.SSHException +except ImportError: + paramiko = None + class SSHException(Exception): + pass +else: + from .forward import forward_tunnel + +try: + import pexpect +except ImportError: + pexpect = None + + +_random_ports = set() + +def select_random_ports(n): + """Selects and return n random ports that are available.""" + ports = [] + for i in range(n): + sock = socket.socket() + sock.bind(('', 0)) + while sock.getsockname()[1] in _random_ports: + sock.close() + sock = socket.socket() + sock.bind(('', 0)) + ports.append(sock) + for i, sock in enumerate(ports): + port = sock.getsockname()[1] + sock.close() + ports[i] = port + _random_ports.add(port) + return ports + + +#----------------------------------------------------------------------------- +# Check for passwordless login +#----------------------------------------------------------------------------- + +def try_passwordless_ssh(server, keyfile, paramiko=None): + """Attempt to make an ssh connection without a password. + This is mainly used for requiring password input only once + when many tunnels may be connected to the same server. + + If paramiko is None, the default for the platform is chosen. + """ + if paramiko is None: + paramiko = sys.platform == 'win32' + if not paramiko: + f = _try_passwordless_openssh + else: + f = _try_passwordless_paramiko + return f(server, keyfile) + +def _try_passwordless_openssh(server, keyfile): + """Try passwordless login with shell ssh command.""" + if pexpect is None: + raise ImportError("pexpect unavailable, use paramiko") + cmd = 'ssh -f '+ server + if keyfile: + cmd += ' -i ' + keyfile + cmd += ' exit' + + # pop SSH_ASKPASS from env + env = os.environ.copy() + env.pop('SSH_ASKPASS', None) + + ssh_newkey = 'Are you sure you want to continue connecting' + p = pexpect.spawn(cmd, env=env) + while True: + try: + i = p.expect([ssh_newkey, '[Pp]assword:'], timeout=.1) + if i==0: + raise SSHException('The authenticity of the host can\'t be established.') + except pexpect.TIMEOUT: + continue + except pexpect.EOF: + return True + else: + return False + +def _try_passwordless_paramiko(server, keyfile): + """Try passwordless login with paramiko.""" + if paramiko is None: + msg = "Paramiko unavaliable, " + if sys.platform == 'win32': + msg += "Paramiko is required for ssh tunneled connections on Windows." + else: + msg += "use OpenSSH." + raise ImportError(msg) + username, server, port = _split_server(server) + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + try: + client.connect(server, port, username=username, key_filename=keyfile, + look_for_keys=True) + except paramiko.AuthenticationException: + return False + else: + client.close() + return True + + +def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramiko=None, timeout=60): + """Connect a socket to an address via an ssh tunnel. + + This is a wrapper for socket.connect(addr), when addr is not accessible + from the local machine. It simply creates an ssh tunnel using the remaining args, + and calls socket.connect('tcp://localhost:lport') where lport is the randomly + selected local port of the tunnel. + + """ + new_url, tunnel = open_tunnel(addr, server, keyfile=keyfile, password=password, paramiko=paramiko, timeout=timeout) + socket.connect(new_url) + return tunnel + + +def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeout=60): + """Open a tunneled connection from a 0MQ url. + + For use inside tunnel_connection. + + Returns + ------- + + (url, tunnel) : (str, object) + The 0MQ url that has been forwarded, and the tunnel object + """ + + lport = select_random_ports(1)[0] + transport, addr = addr.split('://') + ip,rport = addr.split(':') + rport = int(rport) + if paramiko is None: + paramiko = sys.platform == 'win32' + if paramiko: + tunnelf = paramiko_tunnel + else: + tunnelf = openssh_tunnel + + tunnel = tunnelf(lport, rport, server, remoteip=ip, keyfile=keyfile, password=password, timeout=timeout) + return 'tcp://127.0.0.1:%i'%lport, tunnel + +def openssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60): + """Create an ssh tunnel using command-line ssh that connects port lport + on this machine to localhost:rport on server. The tunnel + will automatically close when not in use, remaining open + for a minimum of timeout seconds for an initial connection. + + This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, + as seen from `server`. + + keyfile and password may be specified, but ssh config is checked for defaults. + + Parameters + ---------- + + lport : int + local port for connecting to the tunnel from this machine. + rport : int + port on the remote machine to connect to. + server : str + The ssh server to connect to. The full ssh server string will be parsed. + user@server:port + remoteip : str [Default: 127.0.0.1] + The remote ip, specifying the destination of the tunnel. + Default is localhost, which means that the tunnel would redirect + localhost:lport on this machine to localhost:rport on the *server*. + + keyfile : str; path to public key file + This specifies a key to be used in ssh login, default None. + Regular default ssh keys will be used without specifying this argument. + password : str; + Your ssh password to the ssh server. Note that if this is left None, + you will be prompted for it if passwordless key based login is unavailable. + timeout : int [default: 60] + The time (in seconds) after which no activity will result in the tunnel + closing. This prevents orphaned tunnels from running forever. + """ + if pexpect is None: + raise ImportError("pexpect unavailable, use paramiko_tunnel") + ssh="ssh " + if keyfile: + ssh += "-i " + keyfile + + if ':' in server: + server, port = server.split(':') + ssh += " -p %s" % port + + cmd = "%s -O check %s" % (ssh, server) + (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) + if not exitstatus: + pid = int(output[output.find("(pid=")+5:output.find(")")]) + cmd = "%s -O forward -L 127.0.0.1:%i:%s:%i %s" % ( + ssh, lport, remoteip, rport, server) + (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) + if not exitstatus: + atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1)) + return pid + cmd = "%s -f -S none -L 127.0.0.1:%i:%s:%i %s sleep %i" % ( + ssh, lport, remoteip, rport, server, timeout) + + # pop SSH_ASKPASS from env + env = os.environ.copy() + env.pop('SSH_ASKPASS', None) + + ssh_newkey = 'Are you sure you want to continue connecting' + tunnel = pexpect.spawn(cmd, env=env) + failed = False + while True: + try: + i = tunnel.expect([ssh_newkey, '[Pp]assword:'], timeout=.1) + if i==0: + raise SSHException('The authenticity of the host can\'t be established.') + except pexpect.TIMEOUT: + continue + except pexpect.EOF: + if tunnel.exitstatus: + print(tunnel.exitstatus) + print(tunnel.before) + print(tunnel.after) + raise RuntimeError("tunnel '%s' failed to start"%(cmd)) + else: + return tunnel.pid + else: + if failed: + print("Password rejected, try again") + password=None + if password is None: + password = getpass("%s's password: "%(server)) + tunnel.sendline(password) + failed = True + +def _stop_tunnel(cmd): + pexpect.run(cmd) + +def _split_server(server): + if '@' in server: + username,server = server.split('@', 1) + else: + username = getuser() + if ':' in server: + server, port = server.split(':') + port = int(port) + else: + port = 22 + return username, server, port + +def paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60): + """launch a tunner with paramiko in a subprocess. This should only be used + when shell ssh is unavailable (e.g. Windows). + + This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, + as seen from `server`. + + If you are familiar with ssh tunnels, this creates the tunnel: + + ssh server -L localhost:lport:remoteip:rport + + keyfile and password may be specified, but ssh config is checked for defaults. + + + Parameters + ---------- + + lport : int + local port for connecting to the tunnel from this machine. + rport : int + port on the remote machine to connect to. + server : str + The ssh server to connect to. The full ssh server string will be parsed. + user@server:port + remoteip : str [Default: 127.0.0.1] + The remote ip, specifying the destination of the tunnel. + Default is localhost, which means that the tunnel would redirect + localhost:lport on this machine to localhost:rport on the *server*. + + keyfile : str; path to public key file + This specifies a key to be used in ssh login, default None. + Regular default ssh keys will be used without specifying this argument. + password : str; + Your ssh password to the ssh server. Note that if this is left None, + you will be prompted for it if passwordless key based login is unavailable. + timeout : int [default: 60] + The time (in seconds) after which no activity will result in the tunnel + closing. This prevents orphaned tunnels from running forever. + + """ + if paramiko is None: + raise ImportError("Paramiko not available") + + if password is None: + if not _try_passwordless_paramiko(server, keyfile): + password = getpass("%s's password: "%(server)) + + p = Process(target=_paramiko_tunnel, + args=(lport, rport, server, remoteip), + kwargs=dict(keyfile=keyfile, password=password)) + p.daemon=False + p.start() + atexit.register(_shutdown_process, p) + return p + +def _shutdown_process(p): + if p.is_alive(): + p.terminate() + +def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None): + """Function for actually starting a paramiko tunnel, to be passed + to multiprocessing.Process(target=this), and not called directly. + """ + username, server, port = _split_server(server) + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + + try: + client.connect(server, port, username=username, key_filename=keyfile, + look_for_keys=True, password=password) +# except paramiko.AuthenticationException: +# if password is None: +# password = getpass("%s@%s's password: "%(username, server)) +# client.connect(server, port, username=username, password=password) +# else: +# raise + except Exception as e: + print('*** Failed to connect to %s:%d: %r' % (server, port, e)) + sys.exit(1) + + # Don't let SIGINT kill the tunnel subprocess + signal.signal(signal.SIGINT, signal.SIG_IGN) + + try: + forward_tunnel(lport, remoteip, rport, client.get_transport()) + except KeyboardInterrupt: + print('SIGINT: Port forwarding stopped cleanly') + sys.exit(0) + except Exception as e: + print("Port forwarding stopped uncleanly: %s"%e) + sys.exit(255) + +if sys.platform == 'win32': + ssh_tunnel = paramiko_tunnel +else: + ssh_tunnel = openssh_tunnel + + +__all__ = ['tunnel_connection', 'ssh_tunnel', 'openssh_tunnel', 'paramiko_tunnel', 'try_passwordless_ssh'] + + diff --git a/src/console/zmq/sugar/__init__.py b/src/console/zmq/sugar/__init__.py new file mode 100755 index 00000000..d0510a44 --- /dev/null +++ b/src/console/zmq/sugar/__init__.py @@ -0,0 +1,27 @@ +"""pure-Python sugar wrappers for core 0MQ objects.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from zmq.sugar import ( + constants, context, frame, poll, socket, tracker, version +) +from zmq import error + +__all__ = ['constants'] +for submod in ( + constants, context, error, frame, poll, socket, tracker, version +): + __all__.extend(submod.__all__) + +from zmq.error import * +from zmq.sugar.context import * +from zmq.sugar.tracker import * +from zmq.sugar.socket import * +from zmq.sugar.constants import * +from zmq.sugar.frame import * +from zmq.sugar.poll import * +# from zmq.sugar.stopwatch import * +# from zmq.sugar._device import * +from zmq.sugar.version import * diff --git a/src/console/zmq/sugar/attrsettr.py b/src/console/zmq/sugar/attrsettr.py new file mode 100755 index 00000000..4bbd36d6 --- /dev/null +++ b/src/console/zmq/sugar/attrsettr.py @@ -0,0 +1,52 @@ +# coding: utf-8 +"""Mixin for mapping set/getattr to self.set/get""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from . import constants + +class AttributeSetter(object): + + def __setattr__(self, key, value): + """set zmq options by attribute""" + + # regular setattr only allowed for class-defined attributes + for obj in [self] + self.__class__.mro(): + if key in obj.__dict__: + object.__setattr__(self, key, value) + return + + upper_key = key.upper() + try: + opt = getattr(constants, upper_key) + except AttributeError: + raise AttributeError("%s has no such option: %s" % ( + self.__class__.__name__, upper_key) + ) + else: + self._set_attr_opt(upper_key, opt, value) + + def _set_attr_opt(self, name, opt, value): + """override if setattr should do something other than call self.set""" + self.set(opt, value) + + def __getattr__(self, key): + """get zmq options by attribute""" + upper_key = key.upper() + try: + opt = getattr(constants, upper_key) + except AttributeError: + raise AttributeError("%s has no such option: %s" % ( + self.__class__.__name__, upper_key) + ) + else: + return self._get_attr_opt(upper_key, opt) + + def _get_attr_opt(self, name, opt): + """override if getattr should do something other than call self.get""" + return self.get(opt) + + +__all__ = ['AttributeSetter'] diff --git a/src/console/zmq/sugar/constants.py b/src/console/zmq/sugar/constants.py new file mode 100755 index 00000000..88281176 --- /dev/null +++ b/src/console/zmq/sugar/constants.py @@ -0,0 +1,98 @@ +"""0MQ Constants.""" + +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +from zmq.backend import constants +from zmq.utils.constant_names import ( + base_names, + switched_sockopt_names, + int_sockopt_names, + int64_sockopt_names, + bytes_sockopt_names, + fd_sockopt_names, + ctx_opt_names, + msg_opt_names, +) + +#----------------------------------------------------------------------------- +# Python module level constants +#----------------------------------------------------------------------------- + +__all__ = [ + 'int_sockopts', + 'int64_sockopts', + 'bytes_sockopts', + 'ctx_opts', + 'ctx_opt_names', + ] + +int_sockopts = set() +int64_sockopts = set() +bytes_sockopts = set() +fd_sockopts = set() +ctx_opts = set() +msg_opts = set() + + +if constants.VERSION < 30000: + int64_sockopt_names.extend(switched_sockopt_names) +else: + int_sockopt_names.extend(switched_sockopt_names) + +_UNDEFINED = -9999 + +def _add_constant(name, container=None): + """add a constant to be defined + + optionally add it to one of the sets for use in get/setopt checkers + """ + c = getattr(constants, name, _UNDEFINED) + if c == _UNDEFINED: + return + globals()[name] = c + __all__.append(name) + if container is not None: + container.add(c) + return c + +for name in base_names: + _add_constant(name) + +for name in int_sockopt_names: + _add_constant(name, int_sockopts) + +for name in int64_sockopt_names: + _add_constant(name, int64_sockopts) + +for name in bytes_sockopt_names: + _add_constant(name, bytes_sockopts) + +for name in fd_sockopt_names: + _add_constant(name, fd_sockopts) + +for name in ctx_opt_names: + _add_constant(name, ctx_opts) + +for name in msg_opt_names: + _add_constant(name, msg_opts) + +# ensure some aliases are always defined +aliases = [ + ('DONTWAIT', 'NOBLOCK'), + ('XREQ', 'DEALER'), + ('XREP', 'ROUTER'), +] +for group in aliases: + undefined = set() + found = None + for name in group: + value = getattr(constants, name, -1) + if value != -1: + found = value + else: + undefined.add(name) + if found is not None: + for name in undefined: + globals()[name] = found + __all__.append(name) diff --git a/src/console/zmq/sugar/context.py b/src/console/zmq/sugar/context.py new file mode 100755 index 00000000..86a9c5dc --- /dev/null +++ b/src/console/zmq/sugar/context.py @@ -0,0 +1,192 @@ +# coding: utf-8 +"""Python bindings for 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import atexit +import weakref + +from zmq.backend import Context as ContextBase +from . import constants +from .attrsettr import AttributeSetter +from .constants import ENOTSUP, ctx_opt_names +from .socket import Socket +from zmq.error import ZMQError + +from zmq.utils.interop import cast_int_addr + + +class Context(ContextBase, AttributeSetter): + """Create a zmq Context + + A zmq Context creates sockets via its ``ctx.socket`` method. + """ + sockopts = None + _instance = None + _shadow = False + _exiting = False + + def __init__(self, io_threads=1, **kwargs): + super(Context, self).__init__(io_threads=io_threads, **kwargs) + if kwargs.get('shadow', False): + self._shadow = True + else: + self._shadow = False + self.sockopts = {} + + self._exiting = False + if not self._shadow: + ctx_ref = weakref.ref(self) + def _notify_atexit(): + ctx = ctx_ref() + if ctx is not None: + ctx._exiting = True + atexit.register(_notify_atexit) + + def __del__(self): + """deleting a Context should terminate it, without trying non-threadsafe destroy""" + if not self._shadow and not self._exiting: + self.term() + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.term() + + @classmethod + def shadow(cls, address): + """Shadow an existing libzmq context + + address is the integer address of the libzmq context + or an FFI pointer to it. + + .. versionadded:: 14.1 + """ + address = cast_int_addr(address) + return cls(shadow=address) + + @classmethod + def shadow_pyczmq(cls, ctx): + """Shadow an existing pyczmq context + + ctx is the FFI `zctx_t *` pointer + + .. versionadded:: 14.1 + """ + from pyczmq import zctx + + underlying = zctx.underlying(ctx) + address = cast_int_addr(underlying) + return cls(shadow=address) + + # static method copied from tornado IOLoop.instance + @classmethod + def instance(cls, io_threads=1): + """Returns a global Context instance. + + Most single-threaded applications have a single, global Context. + Use this method instead of passing around Context instances + throughout your code. + + A common pattern for classes that depend on Contexts is to use + a default argument to enable programs with multiple Contexts + but not require the argument for simpler applications: + + class MyClass(object): + def __init__(self, context=None): + self.context = context or Context.instance() + """ + if cls._instance is None or cls._instance.closed: + cls._instance = cls(io_threads=io_threads) + return cls._instance + + #------------------------------------------------------------------------- + # Hooks for ctxopt completion + #------------------------------------------------------------------------- + + def __dir__(self): + keys = dir(self.__class__) + + for collection in ( + ctx_opt_names, + ): + keys.extend(collection) + return keys + + #------------------------------------------------------------------------- + # Creating Sockets + #------------------------------------------------------------------------- + + @property + def _socket_class(self): + return Socket + + def socket(self, socket_type): + """Create a Socket associated with this Context. + + Parameters + ---------- + socket_type : int + The socket type, which can be any of the 0MQ socket types: + REQ, REP, PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, etc. + """ + if self.closed: + raise ZMQError(ENOTSUP) + s = self._socket_class(self, socket_type) + for opt, value in self.sockopts.items(): + try: + s.setsockopt(opt, value) + except ZMQError: + # ignore ZMQErrors, which are likely for socket options + # that do not apply to a particular socket type, e.g. + # SUBSCRIBE for non-SUB sockets. + pass + return s + + def setsockopt(self, opt, value): + """set default socket options for new sockets created by this Context + + .. versionadded:: 13.0 + """ + self.sockopts[opt] = value + + def getsockopt(self, opt): + """get default socket options for new sockets created by this Context + + .. versionadded:: 13.0 + """ + return self.sockopts[opt] + + def _set_attr_opt(self, name, opt, value): + """set default sockopts as attributes""" + if name in constants.ctx_opt_names: + return self.set(opt, value) + else: + self.sockopts[opt] = value + + def _get_attr_opt(self, name, opt): + """get default sockopts as attributes""" + if name in constants.ctx_opt_names: + return self.get(opt) + else: + if opt not in self.sockopts: + raise AttributeError(name) + else: + return self.sockopts[opt] + + def __delattr__(self, key): + """delete default sockopts as attributes""" + key = key.upper() + try: + opt = getattr(constants, key) + except AttributeError: + raise AttributeError("no such socket option: %s" % key) + else: + if opt not in self.sockopts: + raise AttributeError(key) + else: + del self.sockopts[opt] + +__all__ = ['Context'] diff --git a/src/console/zmq/sugar/frame.py b/src/console/zmq/sugar/frame.py new file mode 100755 index 00000000..9f556c86 --- /dev/null +++ b/src/console/zmq/sugar/frame.py @@ -0,0 +1,19 @@ +# coding: utf-8 +"""0MQ Frame pure Python methods.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from .attrsettr import AttributeSetter +from zmq.backend import Frame as FrameBase + + +class Frame(FrameBase, AttributeSetter): + def __getitem__(self, key): + # map Frame['User-Id'] to Frame.get('User-Id') + return self.get(key) + +# keep deprecated alias +Message = Frame +__all__ = ['Frame', 'Message']
\ No newline at end of file diff --git a/src/console/zmq/sugar/poll.py b/src/console/zmq/sugar/poll.py new file mode 100755 index 00000000..c7b1d1bb --- /dev/null +++ b/src/console/zmq/sugar/poll.py @@ -0,0 +1,161 @@ +"""0MQ polling related functions and classes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import zmq +from zmq.backend import zmq_poll +from .constants import POLLIN, POLLOUT, POLLERR + +#----------------------------------------------------------------------------- +# Polling related methods +#----------------------------------------------------------------------------- + + +class Poller(object): + """A stateful poll interface that mirrors Python's built-in poll.""" + sockets = None + _map = {} + + def __init__(self): + self.sockets = [] + self._map = {} + + def __contains__(self, socket): + return socket in self._map + + def register(self, socket, flags=POLLIN|POLLOUT): + """p.register(socket, flags=POLLIN|POLLOUT) + + Register a 0MQ socket or native fd for I/O monitoring. + + register(s,0) is equivalent to unregister(s). + + Parameters + ---------- + socket : zmq.Socket or native socket + A zmq.Socket or any Python object having a ``fileno()`` + method that returns a valid file descriptor. + flags : int + The events to watch for. Can be POLLIN, POLLOUT or POLLIN|POLLOUT. + If `flags=0`, socket will be unregistered. + """ + if flags: + if socket in self._map: + idx = self._map[socket] + self.sockets[idx] = (socket, flags) + else: + idx = len(self.sockets) + self.sockets.append((socket, flags)) + self._map[socket] = idx + elif socket in self._map: + # uregister sockets registered with no events + self.unregister(socket) + else: + # ignore new sockets with no events + pass + + def modify(self, socket, flags=POLLIN|POLLOUT): + """Modify the flags for an already registered 0MQ socket or native fd.""" + self.register(socket, flags) + + def unregister(self, socket): + """Remove a 0MQ socket or native fd for I/O monitoring. + + Parameters + ---------- + socket : Socket + The socket instance to stop polling. + """ + idx = self._map.pop(socket) + self.sockets.pop(idx) + # shift indices after deletion + for socket, flags in self.sockets[idx:]: + self._map[socket] -= 1 + + def poll(self, timeout=None): + """Poll the registered 0MQ or native fds for I/O. + + Parameters + ---------- + timeout : float, int + The timeout in milliseconds. If None, no `timeout` (infinite). This + is in milliseconds to be compatible with ``select.poll()``. The + underlying zmq_poll uses microseconds and we convert to that in + this function. + + Returns + ------- + events : list of tuples + The list of events that are ready to be processed. + This is a list of tuples of the form ``(socket, event)``, where the 0MQ Socket + or integer fd is the first element, and the poll event mask (POLLIN, POLLOUT) is the second. + It is common to call ``events = dict(poller.poll())``, + which turns the list of tuples into a mapping of ``socket : event``. + """ + if timeout is None or timeout < 0: + timeout = -1 + elif isinstance(timeout, float): + timeout = int(timeout) + return zmq_poll(self.sockets, timeout=timeout) + + +def select(rlist, wlist, xlist, timeout=None): + """select(rlist, wlist, xlist, timeout=None) -> (rlist, wlist, xlist) + + Return the result of poll as a lists of sockets ready for r/w/exception. + + This has the same interface as Python's built-in ``select.select()`` function. + + Parameters + ---------- + timeout : float, int, optional + The timeout in seconds. If None, no timeout (infinite). This is in seconds to be + compatible with ``select.select()``. The underlying zmq_poll uses microseconds + and we convert to that in this function. + rlist : list of sockets/FDs + sockets/FDs to be polled for read events + wlist : list of sockets/FDs + sockets/FDs to be polled for write events + xlist : list of sockets/FDs + sockets/FDs to be polled for error events + + Returns + ------- + (rlist, wlist, xlist) : tuple of lists of sockets (length 3) + Lists correspond to sockets available for read/write/error events respectively. + """ + if timeout is None: + timeout = -1 + # Convert from sec -> us for zmq_poll. + # zmq_poll accepts 3.x style timeout in ms + timeout = int(timeout*1000.0) + if timeout < 0: + timeout = -1 + sockets = [] + for s in set(rlist + wlist + xlist): + flags = 0 + if s in rlist: + flags |= POLLIN + if s in wlist: + flags |= POLLOUT + if s in xlist: + flags |= POLLERR + sockets.append((s, flags)) + return_sockets = zmq_poll(sockets, timeout) + rlist, wlist, xlist = [], [], [] + for s, flags in return_sockets: + if flags & POLLIN: + rlist.append(s) + if flags & POLLOUT: + wlist.append(s) + if flags & POLLERR: + xlist.append(s) + return rlist, wlist, xlist + +#----------------------------------------------------------------------------- +# Symbols to export +#----------------------------------------------------------------------------- + +__all__ = [ 'Poller', 'select' ] diff --git a/src/console/zmq/sugar/socket.py b/src/console/zmq/sugar/socket.py new file mode 100755 index 00000000..c91589d7 --- /dev/null +++ b/src/console/zmq/sugar/socket.py @@ -0,0 +1,495 @@ +# coding: utf-8 +"""0MQ Socket pure Python methods.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import codecs +import random +import warnings + +import zmq +from zmq.backend import Socket as SocketBase +from .poll import Poller +from . import constants +from .attrsettr import AttributeSetter +from zmq.error import ZMQError, ZMQBindError +from zmq.utils import jsonapi +from zmq.utils.strtypes import bytes,unicode,basestring +from zmq.utils.interop import cast_int_addr + +from .constants import ( + SNDMORE, ENOTSUP, POLLIN, + int64_sockopt_names, + int_sockopt_names, + bytes_sockopt_names, + fd_sockopt_names, +) +try: + import cPickle + pickle = cPickle +except: + cPickle = None + import pickle + +try: + DEFAULT_PROTOCOL = pickle.DEFAULT_PROTOCOL +except AttributeError: + DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL + + +class Socket(SocketBase, AttributeSetter): + """The ZMQ socket object + + To create a Socket, first create a Context:: + + ctx = zmq.Context.instance() + + then call ``ctx.socket(socket_type)``:: + + s = ctx.socket(zmq.ROUTER) + + """ + _shadow = False + + def __del__(self): + if not self._shadow: + self.close() + + # socket as context manager: + def __enter__(self): + """Sockets are context managers + + .. versionadded:: 14.4 + """ + return self + + def __exit__(self, *args, **kwargs): + self.close() + + #------------------------------------------------------------------------- + # Socket creation + #------------------------------------------------------------------------- + + @classmethod + def shadow(cls, address): + """Shadow an existing libzmq socket + + address is the integer address of the libzmq socket + or an FFI pointer to it. + + .. versionadded:: 14.1 + """ + address = cast_int_addr(address) + return cls(shadow=address) + + #------------------------------------------------------------------------- + # Deprecated aliases + #------------------------------------------------------------------------- + + @property + def socket_type(self): + warnings.warn("Socket.socket_type is deprecated, use Socket.type", + DeprecationWarning + ) + return self.type + + #------------------------------------------------------------------------- + # Hooks for sockopt completion + #------------------------------------------------------------------------- + + def __dir__(self): + keys = dir(self.__class__) + for collection in ( + bytes_sockopt_names, + int_sockopt_names, + int64_sockopt_names, + fd_sockopt_names, + ): + keys.extend(collection) + return keys + + #------------------------------------------------------------------------- + # Getting/Setting options + #------------------------------------------------------------------------- + setsockopt = SocketBase.set + getsockopt = SocketBase.get + + def set_string(self, option, optval, encoding='utf-8'): + """set socket options with a unicode object + + This is simply a wrapper for setsockopt to protect from encoding ambiguity. + + See the 0MQ documentation for details on specific options. + + Parameters + ---------- + option : int + The name of the option to set. Can be any of: SUBSCRIBE, + UNSUBSCRIBE, IDENTITY + optval : unicode string (unicode on py2, str on py3) + The value of the option to set. + encoding : str + The encoding to be used, default is utf8 + """ + if not isinstance(optval, unicode): + raise TypeError("unicode strings only") + return self.set(option, optval.encode(encoding)) + + setsockopt_unicode = setsockopt_string = set_string + + def get_string(self, option, encoding='utf-8'): + """get the value of a socket option + + See the 0MQ documentation for details on specific options. + + Parameters + ---------- + option : int + The option to retrieve. + + Returns + ------- + optval : unicode string (unicode on py2, str on py3) + The value of the option as a unicode string. + """ + + if option not in constants.bytes_sockopts: + raise TypeError("option %i will not return a string to be decoded"%option) + return self.getsockopt(option).decode(encoding) + + getsockopt_unicode = getsockopt_string = get_string + + def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=100): + """bind this socket to a random port in a range + + Parameters + ---------- + addr : str + The address string without the port to pass to ``Socket.bind()``. + min_port : int, optional + The minimum port in the range of ports to try (inclusive). + max_port : int, optional + The maximum port in the range of ports to try (exclusive). + max_tries : int, optional + The maximum number of bind attempts to make. + + Returns + ------- + port : int + The port the socket was bound to. + + Raises + ------ + ZMQBindError + if `max_tries` reached before successful bind + """ + for i in range(max_tries): + try: + port = random.randrange(min_port, max_port) + self.bind('%s:%s' % (addr, port)) + except ZMQError as exception: + if not exception.errno == zmq.EADDRINUSE: + raise + else: + return port + raise ZMQBindError("Could not bind socket to random port.") + + def get_hwm(self): + """get the High Water Mark + + On libzmq ≥ 3, this gets SNDHWM if available, otherwise RCVHWM + """ + major = zmq.zmq_version_info()[0] + if major >= 3: + # return sndhwm, fallback on rcvhwm + try: + return self.getsockopt(zmq.SNDHWM) + except zmq.ZMQError as e: + pass + + return self.getsockopt(zmq.RCVHWM) + else: + return self.getsockopt(zmq.HWM) + + def set_hwm(self, value): + """set the High Water Mark + + On libzmq ≥ 3, this sets both SNDHWM and RCVHWM + """ + major = zmq.zmq_version_info()[0] + if major >= 3: + raised = None + try: + self.sndhwm = value + except Exception as e: + raised = e + try: + self.rcvhwm = value + except Exception: + raised = e + + if raised: + raise raised + else: + return self.setsockopt(zmq.HWM, value) + + hwm = property(get_hwm, set_hwm, + """property for High Water Mark + + Setting hwm sets both SNDHWM and RCVHWM as appropriate. + It gets SNDHWM if available, otherwise RCVHWM. + """ + ) + + #------------------------------------------------------------------------- + # Sending and receiving messages + #------------------------------------------------------------------------- + + def send_multipart(self, msg_parts, flags=0, copy=True, track=False): + """send a sequence of buffers as a multipart message + + The zmq.SNDMORE flag is added to all msg parts before the last. + + Parameters + ---------- + msg_parts : iterable + A sequence of objects to send as a multipart message. Each element + can be any sendable object (Frame, bytes, buffer-providers) + flags : int, optional + SNDMORE is handled automatically for frames before the last. + copy : bool, optional + Should the frame(s) be sent in a copying or non-copying manner. + track : bool, optional + Should the frame(s) be tracked for notification that ZMQ has + finished with it (ignored if copy=True). + + Returns + ------- + None : if copy or not track + MessageTracker : if track and not copy + a MessageTracker object, whose `pending` property will + be True until the last send is completed. + """ + for msg in msg_parts[:-1]: + self.send(msg, SNDMORE|flags, copy=copy, track=track) + # Send the last part without the extra SNDMORE flag. + return self.send(msg_parts[-1], flags, copy=copy, track=track) + + def recv_multipart(self, flags=0, copy=True, track=False): + """receive a multipart message as a list of bytes or Frame objects + + Parameters + ---------- + flags : int, optional + Any supported flag: NOBLOCK. If NOBLOCK is set, this method + will raise a ZMQError with EAGAIN if a message is not ready. + If NOBLOCK is not set, then this method will block until a + message arrives. + copy : bool, optional + Should the message frame(s) be received in a copying or non-copying manner? + If False a Frame object is returned for each part, if True a copy of + the bytes is made for each frame. + track : bool, optional + Should the message frame(s) be tracked for notification that ZMQ has + finished with it? (ignored if copy=True) + + Returns + ------- + msg_parts : list + A list of frames in the multipart message; either Frames or bytes, + depending on `copy`. + + """ + parts = [self.recv(flags, copy=copy, track=track)] + # have first part already, only loop while more to receive + while self.getsockopt(zmq.RCVMORE): + part = self.recv(flags, copy=copy, track=track) + parts.append(part) + + return parts + + def send_string(self, u, flags=0, copy=True, encoding='utf-8'): + """send a Python unicode string as a message with an encoding + + 0MQ communicates with raw bytes, so you must encode/decode + text (unicode on py2, str on py3) around 0MQ. + + Parameters + ---------- + u : Python unicode string (unicode on py2, str on py3) + The unicode string to send. + flags : int, optional + Any valid send flag. + encoding : str [default: 'utf-8'] + The encoding to be used + """ + if not isinstance(u, basestring): + raise TypeError("unicode/str objects only") + return self.send(u.encode(encoding), flags=flags, copy=copy) + + send_unicode = send_string + + def recv_string(self, flags=0, encoding='utf-8'): + """receive a unicode string, as sent by send_string + + Parameters + ---------- + flags : int + Any valid recv flag. + encoding : str [default: 'utf-8'] + The encoding to be used + + Returns + ------- + s : unicode string (unicode on py2, str on py3) + The Python unicode string that arrives as encoded bytes. + """ + b = self.recv(flags=flags) + return b.decode(encoding) + + recv_unicode = recv_string + + def send_pyobj(self, obj, flags=0, protocol=DEFAULT_PROTOCOL): + """send a Python object as a message using pickle to serialize + + Parameters + ---------- + obj : Python object + The Python object to send. + flags : int + Any valid send flag. + protocol : int + The pickle protocol number to use. The default is pickle.DEFAULT_PROTOCOl + where defined, and pickle.HIGHEST_PROTOCOL elsewhere. + """ + msg = pickle.dumps(obj, protocol) + return self.send(msg, flags) + + def recv_pyobj(self, flags=0): + """receive a Python object as a message using pickle to serialize + + Parameters + ---------- + flags : int + Any valid recv flag. + + Returns + ------- + obj : Python object + The Python object that arrives as a message. + """ + s = self.recv(flags) + return pickle.loads(s) + + def send_json(self, obj, flags=0, **kwargs): + """send a Python object as a message using json to serialize + + Keyword arguments are passed on to json.dumps + + Parameters + ---------- + obj : Python object + The Python object to send + flags : int + Any valid send flag + """ + msg = jsonapi.dumps(obj, **kwargs) + return self.send(msg, flags) + + def recv_json(self, flags=0, **kwargs): + """receive a Python object as a message using json to serialize + + Keyword arguments are passed on to json.loads + + Parameters + ---------- + flags : int + Any valid recv flag. + + Returns + ------- + obj : Python object + The Python object that arrives as a message. + """ + msg = self.recv(flags) + return jsonapi.loads(msg, **kwargs) + + _poller_class = Poller + + def poll(self, timeout=None, flags=POLLIN): + """poll the socket for events + + The default is to poll forever for incoming + events. Timeout is in milliseconds, if specified. + + Parameters + ---------- + timeout : int [default: None] + The timeout (in milliseconds) to wait for an event. If unspecified + (or specified None), will wait forever for an event. + flags : bitfield (int) [default: POLLIN] + The event flags to poll for (any combination of POLLIN|POLLOUT). + The default is to check for incoming events (POLLIN). + + Returns + ------- + events : bitfield (int) + The events that are ready and waiting. Will be 0 if no events were ready + by the time timeout was reached. + """ + + if self.closed: + raise ZMQError(ENOTSUP) + + p = self._poller_class() + p.register(self, flags) + evts = dict(p.poll(timeout)) + # return 0 if no events, otherwise return event bitfield + return evts.get(self, 0) + + def get_monitor_socket(self, events=None, addr=None): + """Return a connected PAIR socket ready to receive the event notifications. + + .. versionadded:: libzmq-4.0 + .. versionadded:: 14.0 + + Parameters + ---------- + events : bitfield (int) [default: ZMQ_EVENTS_ALL] + The bitmask defining which events are wanted. + addr : string [default: None] + The optional endpoint for the monitoring sockets. + + Returns + ------- + socket : (PAIR) + The socket is already connected and ready to receive messages. + """ + # safe-guard, method only available on libzmq >= 4 + if zmq.zmq_version_info() < (4,): + raise NotImplementedError("get_monitor_socket requires libzmq >= 4, have %s" % zmq.zmq_version()) + if addr is None: + # create endpoint name from internal fd + addr = "inproc://monitor.s-%d" % self.FD + if events is None: + # use all events + events = zmq.EVENT_ALL + # attach monitoring socket + self.monitor(addr, events) + # create new PAIR socket and connect it + ret = self.context.socket(zmq.PAIR) + ret.connect(addr) + return ret + + def disable_monitor(self): + """Shutdown the PAIR socket (created using get_monitor_socket) + that is serving socket events. + + .. versionadded:: 14.4 + """ + self.monitor(None, 0) + + +__all__ = ['Socket'] diff --git a/src/console/zmq/sugar/tracker.py b/src/console/zmq/sugar/tracker.py new file mode 100755 index 00000000..fb8c007f --- /dev/null +++ b/src/console/zmq/sugar/tracker.py @@ -0,0 +1,120 @@ +"""Tracker for zero-copy messages with 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time + +try: + # below 3.3 + from threading import _Event as Event +except (ImportError, AttributeError): + # python throws ImportError, cython throws AttributeError + from threading import Event + +from zmq.error import NotDone +from zmq.backend import Frame + +class MessageTracker(object): + """MessageTracker(*towatch) + + A class for tracking if 0MQ is done using one or more messages. + + When you send a 0MQ message, it is not sent immediately. The 0MQ IO thread + sends the message at some later time. Often you want to know when 0MQ has + actually sent the message though. This is complicated by the fact that + a single 0MQ message can be sent multiple times using different sockets. + This class allows you to track all of the 0MQ usages of a message. + + Parameters + ---------- + *towatch : tuple of Event, MessageTracker, Message instances. + This list of objects to track. This class can track the low-level + Events used by the Message class, other MessageTrackers or + actual Messages. + """ + events = None + peers = None + + def __init__(self, *towatch): + """MessageTracker(*towatch) + + Create a message tracker to track a set of mesages. + + Parameters + ---------- + *towatch : tuple of Event, MessageTracker, Message instances. + This list of objects to track. This class can track the low-level + Events used by the Message class, other MessageTrackers or + actual Messages. + """ + self.events = set() + self.peers = set() + for obj in towatch: + if isinstance(obj, Event): + self.events.add(obj) + elif isinstance(obj, MessageTracker): + self.peers.add(obj) + elif isinstance(obj, Frame): + if not obj.tracker: + raise ValueError("Not a tracked message") + self.peers.add(obj.tracker) + else: + raise TypeError("Require Events or Message Frames, not %s"%type(obj)) + + @property + def done(self): + """Is 0MQ completely done with the message(s) being tracked?""" + for evt in self.events: + if not evt.is_set(): + return False + for pm in self.peers: + if not pm.done: + return False + return True + + def wait(self, timeout=-1): + """mt.wait(timeout=-1) + + Wait for 0MQ to be done with the message or until `timeout`. + + Parameters + ---------- + timeout : float [default: -1, wait forever] + Maximum time in (s) to wait before raising NotDone. + + Returns + ------- + None + if done before `timeout` + + Raises + ------ + NotDone + if `timeout` reached before I am done. + """ + tic = time.time() + if timeout is False or timeout < 0: + remaining = 3600*24*7 # a week + else: + remaining = timeout + done = False + for evt in self.events: + if remaining < 0: + raise NotDone + evt.wait(timeout=remaining) + if not evt.is_set(): + raise NotDone + toc = time.time() + remaining -= (toc-tic) + tic = toc + + for peer in self.peers: + if remaining < 0: + raise NotDone + peer.wait(timeout=remaining) + toc = time.time() + remaining -= (toc-tic) + tic = toc + +__all__ = ['MessageTracker']
\ No newline at end of file diff --git a/src/console/zmq/sugar/version.py b/src/console/zmq/sugar/version.py new file mode 100755 index 00000000..ea8fbbc4 --- /dev/null +++ b/src/console/zmq/sugar/version.py @@ -0,0 +1,48 @@ +"""PyZMQ and 0MQ version functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from zmq.backend import zmq_version_info + + +VERSION_MAJOR = 14 +VERSION_MINOR = 5 +VERSION_PATCH = 0 +VERSION_EXTRA = "" +__version__ = '%i.%i.%i' % (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) + +if VERSION_EXTRA: + __version__ = "%s-%s" % (__version__, VERSION_EXTRA) + version_info = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, float('inf')) +else: + version_info = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) + +__revision__ = '' + +def pyzmq_version(): + """return the version of pyzmq as a string""" + if __revision__: + return '@'.join([__version__,__revision__[:6]]) + else: + return __version__ + +def pyzmq_version_info(): + """return the pyzmq version as a tuple of at least three numbers + + If pyzmq is a development version, `inf` will be appended after the third integer. + """ + return version_info + + +def zmq_version(): + """return the version of libzmq as a string""" + return "%i.%i.%i" % zmq_version_info() + + +__all__ = ['zmq_version', 'zmq_version_info', + 'pyzmq_version','pyzmq_version_info', + '__version__', '__revision__' +] + diff --git a/src/console/zmq/tests/__init__.py b/src/console/zmq/tests/__init__.py new file mode 100755 index 00000000..325a3f19 --- /dev/null +++ b/src/console/zmq/tests/__init__.py @@ -0,0 +1,211 @@ +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +import functools +import sys +import time +from threading import Thread + +from unittest import TestCase + +import zmq +from zmq.utils import jsonapi + +try: + import gevent + from zmq import green as gzmq + have_gevent = True +except ImportError: + have_gevent = False + +try: + from unittest import SkipTest +except ImportError: + try: + from nose import SkipTest + except ImportError: + class SkipTest(Exception): + pass + +PYPY = 'PyPy' in sys.version + +#----------------------------------------------------------------------------- +# skip decorators (directly from unittest) +#----------------------------------------------------------------------------- + +_id = lambda x: x + +def skip(reason): + """ + Unconditionally skip a test. + """ + def decorator(test_item): + if not (isinstance(test_item, type) and issubclass(test_item, TestCase)): + @functools.wraps(test_item) + def skip_wrapper(*args, **kwargs): + raise SkipTest(reason) + test_item = skip_wrapper + + test_item.__unittest_skip__ = True + test_item.__unittest_skip_why__ = reason + return test_item + return decorator + +def skip_if(condition, reason="Skipped"): + """ + Skip a test if the condition is true. + """ + if condition: + return skip(reason) + return _id + +skip_pypy = skip_if(PYPY, "Doesn't work on PyPy") + +#----------------------------------------------------------------------------- +# Base test class +#----------------------------------------------------------------------------- + +class BaseZMQTestCase(TestCase): + green = False + + @property + def Context(self): + if self.green: + return gzmq.Context + else: + return zmq.Context + + def socket(self, socket_type): + s = self.context.socket(socket_type) + self.sockets.append(s) + return s + + def setUp(self): + if self.green and not have_gevent: + raise SkipTest("requires gevent") + self.context = self.Context.instance() + self.sockets = [] + + def tearDown(self): + contexts = set([self.context]) + while self.sockets: + sock = self.sockets.pop() + contexts.add(sock.context) # in case additional contexts are created + sock.close(0) + for ctx in contexts: + t = Thread(target=ctx.term) + t.daemon = True + t.start() + t.join(timeout=2) + if t.is_alive(): + # reset Context.instance, so the failure to term doesn't corrupt subsequent tests + zmq.sugar.context.Context._instance = None + raise RuntimeError("context could not terminate, open sockets likely remain in test") + + def create_bound_pair(self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'): + """Create a bound socket pair using a random port.""" + s1 = self.context.socket(type1) + s1.setsockopt(zmq.LINGER, 0) + port = s1.bind_to_random_port(interface) + s2 = self.context.socket(type2) + s2.setsockopt(zmq.LINGER, 0) + s2.connect('%s:%s' % (interface, port)) + self.sockets.extend([s1,s2]) + return s1, s2 + + def ping_pong(self, s1, s2, msg): + s1.send(msg) + msg2 = s2.recv() + s2.send(msg2) + msg3 = s1.recv() + return msg3 + + def ping_pong_json(self, s1, s2, o): + if jsonapi.jsonmod is None: + raise SkipTest("No json library") + s1.send_json(o) + o2 = s2.recv_json() + s2.send_json(o2) + o3 = s1.recv_json() + return o3 + + def ping_pong_pyobj(self, s1, s2, o): + s1.send_pyobj(o) + o2 = s2.recv_pyobj() + s2.send_pyobj(o2) + o3 = s1.recv_pyobj() + return o3 + + def assertRaisesErrno(self, errno, func, *args, **kwargs): + try: + func(*args, **kwargs) + except zmq.ZMQError as e: + self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \ +got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) + else: + self.fail("Function did not raise any error") + + def _select_recv(self, multipart, socket, **kwargs): + """call recv[_multipart] in a way that raises if there is nothing to receive""" + if zmq.zmq_version_info() >= (3,1,0): + # zmq 3.1 has a bug, where poll can return false positives, + # so we wait a little bit just in case + # See LIBZMQ-280 on JIRA + time.sleep(0.1) + + r,w,x = zmq.select([socket], [], [], timeout=5) + assert len(r) > 0, "Should have received a message" + kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0) + + recv = socket.recv_multipart if multipart else socket.recv + return recv(**kwargs) + + def recv(self, socket, **kwargs): + """call recv in a way that raises if there is nothing to receive""" + return self._select_recv(False, socket, **kwargs) + + def recv_multipart(self, socket, **kwargs): + """call recv_multipart in a way that raises if there is nothing to receive""" + return self._select_recv(True, socket, **kwargs) + + +class PollZMQTestCase(BaseZMQTestCase): + pass + +class GreenTest: + """Mixin for making green versions of test classes""" + green = True + + def assertRaisesErrno(self, errno, func, *args, **kwargs): + if errno == zmq.EAGAIN: + raise SkipTest("Skipping because we're green.") + try: + func(*args, **kwargs) + except zmq.ZMQError: + e = sys.exc_info()[1] + self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \ +got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) + else: + self.fail("Function did not raise any error") + + def tearDown(self): + contexts = set([self.context]) + while self.sockets: + sock = self.sockets.pop() + contexts.add(sock.context) # in case additional contexts are created + sock.close() + try: + gevent.joinall([gevent.spawn(ctx.term) for ctx in contexts], timeout=2, raise_error=True) + except gevent.Timeout: + raise RuntimeError("context could not terminate, open sockets likely remain in test") + + def skip_green(self): + raise SkipTest("Skipping because we are green") + +def skip_green(f): + def skipping_test(self, *args, **kwargs): + if self.green: + raise SkipTest("Skipping because we are green") + else: + return f(self, *args, **kwargs) + return skipping_test diff --git a/src/console/zmq/tests/test_auth.py b/src/console/zmq/tests/test_auth.py new file mode 100755 index 00000000..d350f61f --- /dev/null +++ b/src/console/zmq/tests/test_auth.py @@ -0,0 +1,431 @@ +# -*- coding: utf8 -*- + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging +import os +import shutil +import sys +import tempfile + +import zmq.auth +from zmq.auth.ioloop import IOLoopAuthenticator +from zmq.auth.thread import ThreadAuthenticator + +from zmq.eventloop import ioloop, zmqstream +from zmq.tests import (BaseZMQTestCase, SkipTest) + +class BaseAuthTestCase(BaseZMQTestCase): + def setUp(self): + if zmq.zmq_version_info() < (4,0): + raise SkipTest("security is new in libzmq 4.0") + try: + zmq.curve_keypair() + except zmq.ZMQError: + raise SkipTest("security requires libzmq to be linked against libsodium") + super(BaseAuthTestCase, self).setUp() + # enable debug logging while we run tests + logging.getLogger('zmq.auth').setLevel(logging.DEBUG) + self.auth = self.make_auth() + self.auth.start() + self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs() + + def make_auth(self): + raise NotImplementedError() + + def tearDown(self): + if self.auth: + self.auth.stop() + self.auth = None + self.remove_certs(self.base_dir) + super(BaseAuthTestCase, self).tearDown() + + def create_certs(self): + """Create CURVE certificates for a test""" + + # Create temporary CURVE keypairs for this test run. We create all keys in a + # temp directory and then move them into the appropriate private or public + # directory. + + base_dir = tempfile.mkdtemp() + keys_dir = os.path.join(base_dir, 'certificates') + public_keys_dir = os.path.join(base_dir, 'public_keys') + secret_keys_dir = os.path.join(base_dir, 'private_keys') + + os.mkdir(keys_dir) + os.mkdir(public_keys_dir) + os.mkdir(secret_keys_dir) + + server_public_file, server_secret_file = zmq.auth.create_certificates(keys_dir, "server") + client_public_file, client_secret_file = zmq.auth.create_certificates(keys_dir, "client") + + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key"): + shutil.move(os.path.join(keys_dir, key_file), + os.path.join(public_keys_dir, '.')) + + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key_secret"): + shutil.move(os.path.join(keys_dir, key_file), + os.path.join(secret_keys_dir, '.')) + + return (base_dir, public_keys_dir, secret_keys_dir) + + def remove_certs(self, base_dir): + """Remove certificates for a test""" + shutil.rmtree(base_dir) + + def load_certs(self, secret_keys_dir): + """Return server and client certificate keys""" + server_secret_file = os.path.join(secret_keys_dir, "server.key_secret") + client_secret_file = os.path.join(secret_keys_dir, "client.key_secret") + + server_public, server_secret = zmq.auth.load_certificate(server_secret_file) + client_public, client_secret = zmq.auth.load_certificate(client_secret_file) + + return server_public, server_secret, client_public, client_secret + + +class TestThreadAuthentication(BaseAuthTestCase): + """Test authentication running in a thread""" + + def make_auth(self): + return ThreadAuthenticator(self.context) + + def can_connect(self, server, client): + """Check if client can connect to server using tcp transport""" + result = False + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + msg = [b"Hello World"] + server.send_multipart(msg) + if client.poll(1000): + rcvd_msg = client.recv_multipart() + self.assertEqual(rcvd_msg, msg) + result = True + return result + + def test_null(self): + """threaded auth - NULL""" + # A default NULL connection should always succeed, and not + # go through our authentication infrastructure at all. + self.auth.stop() + self.auth = None + + server = self.socket(zmq.PUSH) + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + + # By setting a domain we switch on authentication for NULL sockets, + # though no policies are configured yet. The client connection + # should still be allowed. + server = self.socket(zmq.PUSH) + server.zap_domain = b'global' + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + + def test_blacklist(self): + """threaded auth - Blacklist""" + # Blacklist 127.0.0.1, connection should fail + self.auth.deny('127.0.0.1') + server = self.socket(zmq.PUSH) + # By setting a domain we switch on authentication for NULL sockets, + # though no policies are configured yet. + server.zap_domain = b'global' + client = self.socket(zmq.PULL) + self.assertFalse(self.can_connect(server, client)) + + def test_whitelist(self): + """threaded auth - Whitelist""" + # Whitelist 127.0.0.1, connection should pass" + self.auth.allow('127.0.0.1') + server = self.socket(zmq.PUSH) + # By setting a domain we switch on authentication for NULL sockets, + # though no policies are configured yet. + server.zap_domain = b'global' + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + + def test_plain(self): + """threaded auth - PLAIN""" + + # Try PLAIN authentication - without configuring server, connection should fail + server = self.socket(zmq.PUSH) + server.plain_server = True + client = self.socket(zmq.PULL) + client.plain_username = b'admin' + client.plain_password = b'Password' + self.assertFalse(self.can_connect(server, client)) + + # Try PLAIN authentication - with server configured, connection should pass + server = self.socket(zmq.PUSH) + server.plain_server = True + client = self.socket(zmq.PULL) + client.plain_username = b'admin' + client.plain_password = b'Password' + self.auth.configure_plain(domain='*', passwords={'admin': 'Password'}) + self.assertTrue(self.can_connect(server, client)) + + # Try PLAIN authentication - with bogus credentials, connection should fail + server = self.socket(zmq.PUSH) + server.plain_server = True + client = self.socket(zmq.PULL) + client.plain_username = b'admin' + client.plain_password = b'Bogus' + self.assertFalse(self.can_connect(server, client)) + + # Remove authenticator and check that a normal connection works + self.auth.stop() + self.auth = None + + server = self.socket(zmq.PUSH) + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + client.close() + server.close() + + def test_curve(self): + """threaded auth - CURVE""" + self.auth.allow('127.0.0.1') + certs = self.load_certs(self.secret_keys_dir) + server_public, server_secret, client_public, client_secret = certs + + #Try CURVE authentication - without configuring server, connection should fail + server = self.socket(zmq.PUSH) + server.curve_publickey = server_public + server.curve_secretkey = server_secret + server.curve_server = True + client = self.socket(zmq.PULL) + client.curve_publickey = client_public + client.curve_secretkey = client_secret + client.curve_serverkey = server_public + self.assertFalse(self.can_connect(server, client)) + + #Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass + self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) + server = self.socket(zmq.PUSH) + server.curve_publickey = server_public + server.curve_secretkey = server_secret + server.curve_server = True + client = self.socket(zmq.PULL) + client.curve_publickey = client_public + client.curve_secretkey = client_secret + client.curve_serverkey = server_public + self.assertTrue(self.can_connect(server, client)) + + # Try CURVE authentication - with server configured, connection should pass + self.auth.configure_curve(domain='*', location=self.public_keys_dir) + server = self.socket(zmq.PUSH) + server.curve_publickey = server_public + server.curve_secretkey = server_secret + server.curve_server = True + client = self.socket(zmq.PULL) + client.curve_publickey = client_public + client.curve_secretkey = client_secret + client.curve_serverkey = server_public + self.assertTrue(self.can_connect(server, client)) + + # Remove authenticator and check that a normal connection works + self.auth.stop() + self.auth = None + + # Try connecting using NULL and no authentication enabled, connection should pass + server = self.socket(zmq.PUSH) + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + + +def with_ioloop(method, expect_success=True): + """decorator for running tests with an IOLoop""" + def test_method(self): + r = method(self) + + loop = self.io_loop + if expect_success: + self.pullstream.on_recv(self.on_message_succeed) + else: + self.pullstream.on_recv(self.on_message_fail) + + t = loop.time() + loop.add_callback(self.attempt_connection) + loop.add_callback(self.send_msg) + if expect_success: + loop.add_timeout(t + 1, self.on_test_timeout_fail) + else: + loop.add_timeout(t + 1, self.on_test_timeout_succeed) + + loop.start() + if self.fail_msg: + self.fail(self.fail_msg) + + return r + return test_method + +def should_auth(method): + return with_ioloop(method, True) + +def should_not_auth(method): + return with_ioloop(method, False) + +class TestIOLoopAuthentication(BaseAuthTestCase): + """Test authentication running in ioloop""" + + def setUp(self): + self.fail_msg = None + self.io_loop = ioloop.IOLoop() + super(TestIOLoopAuthentication, self).setUp() + self.server = self.socket(zmq.PUSH) + self.client = self.socket(zmq.PULL) + self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop) + self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop) + + def make_auth(self): + return IOLoopAuthenticator(self.context, io_loop=self.io_loop) + + def tearDown(self): + if self.auth: + self.auth.stop() + self.auth = None + self.io_loop.close(all_fds=True) + super(TestIOLoopAuthentication, self).tearDown() + + def attempt_connection(self): + """Check if client can connect to server using tcp transport""" + iface = 'tcp://127.0.0.1' + port = self.server.bind_to_random_port(iface) + self.client.connect("%s:%i" % (iface, port)) + + def send_msg(self): + """Send a message from server to a client""" + msg = [b"Hello World"] + self.pushstream.send_multipart(msg) + + def on_message_succeed(self, frames): + """A message was received, as expected.""" + if frames != [b"Hello World"]: + self.fail_msg = "Unexpected message received" + self.io_loop.stop() + + def on_message_fail(self, frames): + """A message was received, unexpectedly.""" + self.fail_msg = 'Received messaged unexpectedly, security failed' + self.io_loop.stop() + + def on_test_timeout_succeed(self): + """Test timer expired, indicates test success""" + self.io_loop.stop() + + def on_test_timeout_fail(self): + """Test timer expired, indicates test failure""" + self.fail_msg = 'Test timed out' + self.io_loop.stop() + + @should_auth + def test_none(self): + """ioloop auth - NONE""" + # A default NULL connection should always succeed, and not + # go through our authentication infrastructure at all. + # no auth should be running + self.auth.stop() + self.auth = None + + @should_auth + def test_null(self): + """ioloop auth - NULL""" + # By setting a domain we switch on authentication for NULL sockets, + # though no policies are configured yet. The client connection + # should still be allowed. + self.server.zap_domain = b'global' + + @should_not_auth + def test_blacklist(self): + """ioloop auth - Blacklist""" + # Blacklist 127.0.0.1, connection should fail + self.auth.deny('127.0.0.1') + self.server.zap_domain = b'global' + + @should_auth + def test_whitelist(self): + """ioloop auth - Whitelist""" + # Whitelist 127.0.0.1, which overrides the blacklist, connection should pass" + self.auth.allow('127.0.0.1') + + self.server.setsockopt(zmq.ZAP_DOMAIN, b'global') + + @should_not_auth + def test_plain_unconfigured_server(self): + """ioloop auth - PLAIN, unconfigured server""" + self.client.plain_username = b'admin' + self.client.plain_password = b'Password' + # Try PLAIN authentication - without configuring server, connection should fail + self.server.plain_server = True + + @should_auth + def test_plain_configured_server(self): + """ioloop auth - PLAIN, configured server""" + self.client.plain_username = b'admin' + self.client.plain_password = b'Password' + # Try PLAIN authentication - with server configured, connection should pass + self.server.plain_server = True + self.auth.configure_plain(domain='*', passwords={'admin': 'Password'}) + + @should_not_auth + def test_plain_bogus_credentials(self): + """ioloop auth - PLAIN, bogus credentials""" + self.client.plain_username = b'admin' + self.client.plain_password = b'Bogus' + self.server.plain_server = True + + self.auth.configure_plain(domain='*', passwords={'admin': 'Password'}) + + @should_not_auth + def test_curve_unconfigured_server(self): + """ioloop auth - CURVE, unconfigured server""" + certs = self.load_certs(self.secret_keys_dir) + server_public, server_secret, client_public, client_secret = certs + + self.auth.allow('127.0.0.1') + + self.server.curve_publickey = server_public + self.server.curve_secretkey = server_secret + self.server.curve_server = True + + self.client.curve_publickey = client_public + self.client.curve_secretkey = client_secret + self.client.curve_serverkey = server_public + + @should_auth + def test_curve_allow_any(self): + """ioloop auth - CURVE, CURVE_ALLOW_ANY""" + certs = self.load_certs(self.secret_keys_dir) + server_public, server_secret, client_public, client_secret = certs + + self.auth.allow('127.0.0.1') + self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) + + self.server.curve_publickey = server_public + self.server.curve_secretkey = server_secret + self.server.curve_server = True + + self.client.curve_publickey = client_public + self.client.curve_secretkey = client_secret + self.client.curve_serverkey = server_public + + @should_auth + def test_curve_configured_server(self): + """ioloop auth - CURVE, configured server""" + self.auth.allow('127.0.0.1') + certs = self.load_certs(self.secret_keys_dir) + server_public, server_secret, client_public, client_secret = certs + + self.auth.configure_curve(domain='*', location=self.public_keys_dir) + + self.server.curve_publickey = server_public + self.server.curve_secretkey = server_secret + self.server.curve_server = True + + self.client.curve_publickey = client_public + self.client.curve_secretkey = client_secret + self.client.curve_serverkey = server_public diff --git a/src/console/zmq/tests/test_cffi_backend.py b/src/console/zmq/tests/test_cffi_backend.py new file mode 100755 index 00000000..1f85eebf --- /dev/null +++ b/src/console/zmq/tests/test_cffi_backend.py @@ -0,0 +1,310 @@ +# -*- coding: utf8 -*- + +import sys +import time + +from unittest import TestCase + +from zmq.tests import BaseZMQTestCase, SkipTest + +try: + from zmq.backend.cffi import ( + zmq_version_info, + PUSH, PULL, IDENTITY, + REQ, REP, POLLIN, POLLOUT, + ) + from zmq.backend.cffi._cffi import ffi, C + have_ffi_backend = True +except ImportError: + have_ffi_backend = False + + +class TestCFFIBackend(TestCase): + + def setUp(self): + if not have_ffi_backend or not 'PyPy' in sys.version: + raise SkipTest('PyPy Tests Only') + + def test_zmq_version_info(self): + version = zmq_version_info() + + assert version[0] in range(2,11) + + def test_zmq_ctx_new_destroy(self): + ctx = C.zmq_ctx_new() + + assert ctx != ffi.NULL + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_socket_open_close(self): + ctx = C.zmq_ctx_new() + socket = C.zmq_socket(ctx, PUSH) + + assert ctx != ffi.NULL + assert ffi.NULL != socket + assert 0 == C.zmq_close(socket) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_setsockopt(self): + ctx = C.zmq_ctx_new() + socket = C.zmq_socket(ctx, PUSH) + + identity = ffi.new('char[3]', 'zmq') + ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3) + + assert ret == 0 + assert ctx != ffi.NULL + assert ffi.NULL != socket + assert 0 == C.zmq_close(socket) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_getsockopt(self): + ctx = C.zmq_ctx_new() + socket = C.zmq_socket(ctx, PUSH) + + identity = ffi.new('char[]', 'zmq') + ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3) + assert ret == 0 + + option_len = ffi.new('size_t*', 3) + option = ffi.new('char*') + ret = C.zmq_getsockopt(socket, + IDENTITY, + ffi.cast('void*', option), + option_len) + + assert ret == 0 + assert ffi.string(ffi.cast('char*', option))[0] == "z" + assert ffi.string(ffi.cast('char*', option))[1] == "m" + assert ffi.string(ffi.cast('char*', option))[2] == "q" + assert ctx != ffi.NULL + assert ffi.NULL != socket + assert 0 == C.zmq_close(socket) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_bind(self): + ctx = C.zmq_ctx_new() + socket = C.zmq_socket(ctx, 8) + + assert 0 == C.zmq_bind(socket, 'tcp://*:4444') + assert ctx != ffi.NULL + assert ffi.NULL != socket + assert 0 == C.zmq_close(socket) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_bind_connect(self): + ctx = C.zmq_ctx_new() + + socket1 = C.zmq_socket(ctx, PUSH) + socket2 = C.zmq_socket(ctx, PULL) + + assert 0 == C.zmq_bind(socket1, 'tcp://*:4444') + assert 0 == C.zmq_connect(socket2, 'tcp://127.0.0.1:4444') + assert ctx != ffi.NULL + assert ffi.NULL != socket1 + assert ffi.NULL != socket2 + assert 0 == C.zmq_close(socket1) + assert 0 == C.zmq_close(socket2) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_msg_init_close(self): + zmq_msg = ffi.new('zmq_msg_t*') + + assert ffi.NULL != zmq_msg + assert 0 == C.zmq_msg_init(zmq_msg) + assert 0 == C.zmq_msg_close(zmq_msg) + + def test_zmq_msg_init_size(self): + zmq_msg = ffi.new('zmq_msg_t*') + + assert ffi.NULL != zmq_msg + assert 0 == C.zmq_msg_init_size(zmq_msg, 10) + assert 0 == C.zmq_msg_close(zmq_msg) + + def test_zmq_msg_init_data(self): + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[5]', 'Hello') + + assert 0 == C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + 5, + ffi.NULL, + ffi.NULL) + + assert ffi.NULL != zmq_msg + assert 0 == C.zmq_msg_close(zmq_msg) + + def test_zmq_msg_data(self): + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[]', 'Hello') + assert 0 == C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + 5, + ffi.NULL, + ffi.NULL) + + data = C.zmq_msg_data(zmq_msg) + + assert ffi.NULL != zmq_msg + assert ffi.string(ffi.cast("char*", data)) == 'Hello' + assert 0 == C.zmq_msg_close(zmq_msg) + + + def test_zmq_send(self): + ctx = C.zmq_ctx_new() + + sender = C.zmq_socket(ctx, REQ) + receiver = C.zmq_socket(ctx, REP) + + assert 0 == C.zmq_bind(receiver, 'tcp://*:7777') + assert 0 == C.zmq_connect(sender, 'tcp://127.0.0.1:7777') + + time.sleep(0.1) + + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[5]', 'Hello') + + C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + ffi.cast('size_t', 5), + ffi.NULL, + ffi.NULL) + + assert 5 == C.zmq_msg_send(zmq_msg, sender, 0) + assert 0 == C.zmq_msg_close(zmq_msg) + assert C.zmq_close(sender) == 0 + assert C.zmq_close(receiver) == 0 + assert C.zmq_ctx_destroy(ctx) == 0 + + def test_zmq_recv(self): + ctx = C.zmq_ctx_new() + + sender = C.zmq_socket(ctx, REQ) + receiver = C.zmq_socket(ctx, REP) + + assert 0 == C.zmq_bind(receiver, 'tcp://*:2222') + assert 0 == C.zmq_connect(sender, 'tcp://127.0.0.1:2222') + + time.sleep(0.1) + + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[5]', 'Hello') + + C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + ffi.cast('size_t', 5), + ffi.NULL, + ffi.NULL) + + zmq_msg2 = ffi.new('zmq_msg_t*') + C.zmq_msg_init(zmq_msg2) + + assert 5 == C.zmq_msg_send(zmq_msg, sender, 0) + assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0) + assert 5 == C.zmq_msg_size(zmq_msg2) + assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2), + C.zmq_msg_size(zmq_msg2))[:] + assert C.zmq_close(sender) == 0 + assert C.zmq_close(receiver) == 0 + assert C.zmq_ctx_destroy(ctx) == 0 + + def test_zmq_poll(self): + ctx = C.zmq_ctx_new() + + sender = C.zmq_socket(ctx, REQ) + receiver = C.zmq_socket(ctx, REP) + + r1 = C.zmq_bind(receiver, 'tcp://*:3333') + r2 = C.zmq_connect(sender, 'tcp://127.0.0.1:3333') + + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[5]', 'Hello') + + C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + ffi.cast('size_t', 5), + ffi.NULL, + ffi.NULL) + + receiver_pollitem = ffi.new('zmq_pollitem_t*') + receiver_pollitem.socket = receiver + receiver_pollitem.fd = 0 + receiver_pollitem.events = POLLIN | POLLOUT + receiver_pollitem.revents = 0 + + ret = C.zmq_poll(ffi.NULL, 0, 0) + assert ret == 0 + + ret = C.zmq_poll(receiver_pollitem, 1, 0) + assert ret == 0 + + ret = C.zmq_msg_send(zmq_msg, sender, 0) + print(ffi.string(C.zmq_strerror(C.zmq_errno()))) + assert ret == 5 + + time.sleep(0.2) + + ret = C.zmq_poll(receiver_pollitem, 1, 0) + assert ret == 1 + + assert int(receiver_pollitem.revents) & POLLIN + assert not int(receiver_pollitem.revents) & POLLOUT + + zmq_msg2 = ffi.new('zmq_msg_t*') + C.zmq_msg_init(zmq_msg2) + + ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0) + assert ret_recv == 5 + + assert 5 == C.zmq_msg_size(zmq_msg2) + assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2), + C.zmq_msg_size(zmq_msg2))[:] + + sender_pollitem = ffi.new('zmq_pollitem_t*') + sender_pollitem.socket = sender + sender_pollitem.fd = 0 + sender_pollitem.events = POLLIN | POLLOUT + sender_pollitem.revents = 0 + + ret = C.zmq_poll(sender_pollitem, 1, 0) + assert ret == 0 + + zmq_msg_again = ffi.new('zmq_msg_t*') + message_again = ffi.new('char[11]', 'Hello Again') + + C.zmq_msg_init_data(zmq_msg_again, + ffi.cast('void*', message_again), + ffi.cast('size_t', 11), + ffi.NULL, + ffi.NULL) + + assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0) + + time.sleep(0.2) + + assert 0 <= C.zmq_poll(sender_pollitem, 1, 0) + assert int(sender_pollitem.revents) & POLLIN + assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0) + assert 11 == C.zmq_msg_size(zmq_msg2) + assert b"Hello Again" == ffi.buffer(C.zmq_msg_data(zmq_msg2), + int(C.zmq_msg_size(zmq_msg2)))[:] + assert 0 == C.zmq_close(sender) + assert 0 == C.zmq_close(receiver) + assert 0 == C.zmq_ctx_destroy(ctx) + assert 0 == C.zmq_msg_close(zmq_msg) + assert 0 == C.zmq_msg_close(zmq_msg2) + assert 0 == C.zmq_msg_close(zmq_msg_again) + + def test_zmq_stopwatch_functions(self): + stopwatch = C.zmq_stopwatch_start() + ret = C.zmq_stopwatch_stop(stopwatch) + + assert ffi.NULL != stopwatch + assert 0 < int(ret) + + def test_zmq_sleep(self): + try: + C.zmq_sleep(1) + except Exception as e: + raise AssertionError("Error executing zmq_sleep(int)") + diff --git a/src/console/zmq/tests/test_constants.py b/src/console/zmq/tests/test_constants.py new file mode 100755 index 00000000..d32b2b48 --- /dev/null +++ b/src/console/zmq/tests/test_constants.py @@ -0,0 +1,104 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import json +from unittest import TestCase + +import zmq + +from zmq.utils import constant_names +from zmq.sugar import constants as sugar_constants +from zmq.backend import constants as backend_constants + +all_set = set(constant_names.all_names) + +class TestConstants(TestCase): + + def _duplicate_test(self, namelist, listname): + """test that a given list has no duplicates""" + dupes = {} + for name in set(namelist): + cnt = namelist.count(name) + if cnt > 1: + dupes[name] = cnt + if dupes: + self.fail("The following names occur more than once in %s: %s" % (listname, json.dumps(dupes, indent=2))) + + def test_duplicate_all(self): + return self._duplicate_test(constant_names.all_names, "all_names") + + def _change_key(self, change, version): + """return changed-in key""" + return "%s-in %d.%d.%d" % tuple([change] + list(version)) + + def test_duplicate_changed(self): + all_changed = [] + for change in ("new", "removed"): + d = getattr(constant_names, change + "_in") + for version, namelist in d.items(): + all_changed.extend(namelist) + self._duplicate_test(namelist, self._change_key(change, version)) + + self._duplicate_test(all_changed, "all-changed") + + def test_changed_in_all(self): + missing = {} + for change in ("new", "removed"): + d = getattr(constant_names, change + "_in") + for version, namelist in d.items(): + key = self._change_key(change, version) + for name in namelist: + if name not in all_set: + if key not in missing: + missing[key] = [] + missing[key].append(name) + + if missing: + self.fail( + "The following names are missing in `all_names`: %s" % json.dumps(missing, indent=2) + ) + + def test_no_negative_constants(self): + for name in sugar_constants.__all__: + self.assertNotEqual(getattr(zmq, name), sugar_constants._UNDEFINED) + + def test_undefined_constants(self): + all_aliases = [] + for alias_group in sugar_constants.aliases: + all_aliases.extend(alias_group) + + for name in all_set.difference(all_aliases): + raw = getattr(backend_constants, name) + if raw == sugar_constants._UNDEFINED: + self.assertRaises(AttributeError, getattr, zmq, name) + else: + self.assertEqual(getattr(zmq, name), raw) + + def test_new(self): + zmq_version = zmq.zmq_version_info() + for version, new_names in constant_names.new_in.items(): + should_have = zmq_version >= version + for name in new_names: + try: + value = getattr(zmq, name) + except AttributeError: + if should_have: + self.fail("AttributeError: zmq.%s" % name) + else: + if not should_have: + self.fail("Shouldn't have: zmq.%s=%s" % (name, value)) + + def test_removed(self): + zmq_version = zmq.zmq_version_info() + for version, new_names in constant_names.removed_in.items(): + should_have = zmq_version < version + for name in new_names: + try: + value = getattr(zmq, name) + except AttributeError: + if should_have: + self.fail("AttributeError: zmq.%s" % name) + else: + if not should_have: + self.fail("Shouldn't have: zmq.%s=%s" % (name, value)) + diff --git a/src/console/zmq/tests/test_context.py b/src/console/zmq/tests/test_context.py new file mode 100755 index 00000000..e3280778 --- /dev/null +++ b/src/console/zmq/tests/test_context.py @@ -0,0 +1,257 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import gc +import sys +import time +from threading import Thread, Event + +import zmq +from zmq.tests import ( + BaseZMQTestCase, have_gevent, GreenTest, skip_green, PYPY, SkipTest, +) + + +class TestContext(BaseZMQTestCase): + + def test_init(self): + c1 = self.Context() + self.assert_(isinstance(c1, self.Context)) + del c1 + c2 = self.Context() + self.assert_(isinstance(c2, self.Context)) + del c2 + c3 = self.Context() + self.assert_(isinstance(c3, self.Context)) + del c3 + + def test_dir(self): + ctx = self.Context() + self.assertTrue('socket' in dir(ctx)) + if zmq.zmq_version_info() > (3,): + self.assertTrue('IO_THREADS' in dir(ctx)) + ctx.term() + + def test_term(self): + c = self.Context() + c.term() + self.assert_(c.closed) + + def test_context_manager(self): + with self.Context() as c: + pass + self.assert_(c.closed) + + def test_fail_init(self): + self.assertRaisesErrno(zmq.EINVAL, self.Context, -1) + + def test_term_hang(self): + rep,req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER) + req.setsockopt(zmq.LINGER, 0) + req.send(b'hello', copy=False) + req.close() + rep.close() + self.context.term() + + def test_instance(self): + ctx = self.Context.instance() + c2 = self.Context.instance(io_threads=2) + self.assertTrue(c2 is ctx) + c2.term() + c3 = self.Context.instance() + c4 = self.Context.instance() + self.assertFalse(c3 is c2) + self.assertFalse(c3.closed) + self.assertTrue(c3 is c4) + + def test_many_sockets(self): + """opening and closing many sockets shouldn't cause problems""" + ctx = self.Context() + for i in range(16): + sockets = [ ctx.socket(zmq.REP) for i in range(65) ] + [ s.close() for s in sockets ] + # give the reaper a chance + time.sleep(1e-2) + ctx.term() + + def test_sockopts(self): + """setting socket options with ctx attributes""" + ctx = self.Context() + ctx.linger = 5 + self.assertEqual(ctx.linger, 5) + s = ctx.socket(zmq.REQ) + self.assertEqual(s.linger, 5) + self.assertEqual(s.getsockopt(zmq.LINGER), 5) + s.close() + # check that subscribe doesn't get set on sockets that don't subscribe: + ctx.subscribe = b'' + s = ctx.socket(zmq.REQ) + s.close() + + ctx.term() + + + def test_destroy(self): + """Context.destroy should close sockets""" + ctx = self.Context() + sockets = [ ctx.socket(zmq.REP) for i in range(65) ] + + # close half of the sockets + [ s.close() for s in sockets[::2] ] + + ctx.destroy() + # reaper is not instantaneous + time.sleep(1e-2) + for s in sockets: + self.assertTrue(s.closed) + + def test_destroy_linger(self): + """Context.destroy should set linger on closing sockets""" + req,rep = self.create_bound_pair(zmq.REQ, zmq.REP) + req.send(b'hi') + time.sleep(1e-2) + self.context.destroy(linger=0) + # reaper is not instantaneous + time.sleep(1e-2) + for s in (req,rep): + self.assertTrue(s.closed) + + def test_term_noclose(self): + """Context.term won't close sockets""" + ctx = self.Context() + s = ctx.socket(zmq.REQ) + self.assertFalse(s.closed) + t = Thread(target=ctx.term) + t.start() + t.join(timeout=0.1) + self.assertTrue(t.is_alive(), "Context should be waiting") + s.close() + t.join(timeout=0.1) + self.assertFalse(t.is_alive(), "Context should have closed") + + def test_gc(self): + """test close&term by garbage collection alone""" + if PYPY: + raise SkipTest("GC doesn't work ") + + # test credit @dln (GH #137): + def gcf(): + def inner(): + ctx = self.Context() + s = ctx.socket(zmq.PUSH) + inner() + gc.collect() + t = Thread(target=gcf) + t.start() + t.join(timeout=1) + self.assertFalse(t.is_alive(), "Garbage collection should have cleaned up context") + + def test_cyclic_destroy(self): + """ctx.destroy should succeed when cyclic ref prevents gc""" + # test credit @dln (GH #137): + class CyclicReference(object): + def __init__(self, parent=None): + self.parent = parent + + def crash(self, sock): + self.sock = sock + self.child = CyclicReference(self) + + def crash_zmq(): + ctx = self.Context() + sock = ctx.socket(zmq.PULL) + c = CyclicReference() + c.crash(sock) + ctx.destroy() + + crash_zmq() + + def test_term_thread(self): + """ctx.term should not crash active threads (#139)""" + ctx = self.Context() + evt = Event() + evt.clear() + + def block(): + s = ctx.socket(zmq.REP) + s.bind_to_random_port('tcp://127.0.0.1') + evt.set() + try: + s.recv() + except zmq.ZMQError as e: + self.assertEqual(e.errno, zmq.ETERM) + return + finally: + s.close() + self.fail("recv should have been interrupted with ETERM") + t = Thread(target=block) + t.start() + + evt.wait(1) + self.assertTrue(evt.is_set(), "sync event never fired") + time.sleep(0.01) + ctx.term() + t.join(timeout=1) + self.assertFalse(t.is_alive(), "term should have interrupted s.recv()") + + def test_destroy_no_sockets(self): + ctx = self.Context() + s = ctx.socket(zmq.PUB) + s.bind_to_random_port('tcp://127.0.0.1') + s.close() + ctx.destroy() + assert s.closed + assert ctx.closed + + def test_ctx_opts(self): + if zmq.zmq_version_info() < (3,): + raise SkipTest("context options require libzmq 3") + ctx = self.Context() + ctx.set(zmq.MAX_SOCKETS, 2) + self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2) + ctx.max_sockets = 100 + self.assertEqual(ctx.max_sockets, 100) + self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100) + + def test_shadow(self): + ctx = self.Context() + ctx2 = self.Context.shadow(ctx.underlying) + self.assertEqual(ctx.underlying, ctx2.underlying) + s = ctx.socket(zmq.PUB) + s.close() + del ctx2 + self.assertFalse(ctx.closed) + s = ctx.socket(zmq.PUB) + ctx2 = self.Context.shadow(ctx.underlying) + s2 = ctx2.socket(zmq.PUB) + s.close() + s2.close() + ctx.term() + self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB) + del ctx2 + + def test_shadow_pyczmq(self): + try: + from pyczmq import zctx, zsocket, zstr + except Exception: + raise SkipTest("Requires pyczmq") + + ctx = zctx.new() + a = zsocket.new(ctx, zmq.PUSH) + zsocket.bind(a, "inproc://a") + ctx2 = self.Context.shadow_pyczmq(ctx) + b = ctx2.socket(zmq.PULL) + b.connect("inproc://a") + zstr.send(a, b'hi') + rcvd = self.recv(b) + self.assertEqual(rcvd, b'hi') + b.close() + + +if False: # disable green context tests + class TestContextGreen(GreenTest, TestContext): + """gevent subclass of context tests""" + # skip tests that use real threads: + test_gc = GreenTest.skip_green + test_term_thread = GreenTest.skip_green + test_destroy_linger = GreenTest.skip_green diff --git a/src/console/zmq/tests/test_device.py b/src/console/zmq/tests/test_device.py new file mode 100755 index 00000000..f8305074 --- /dev/null +++ b/src/console/zmq/tests/test_device.py @@ -0,0 +1,146 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time + +import zmq +from zmq import devices +from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest, PYPY +from zmq.utils.strtypes import (bytes,unicode,basestring) + +if PYPY: + # cleanup of shared Context doesn't work on PyPy + devices.Device.context_factory = zmq.Context + +class TestDevice(BaseZMQTestCase): + + def test_device_types(self): + for devtype in (zmq.STREAMER, zmq.FORWARDER, zmq.QUEUE): + dev = devices.Device(devtype, zmq.PAIR, zmq.PAIR) + self.assertEqual(dev.device_type, devtype) + del dev + + def test_device_attributes(self): + dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB) + self.assertEqual(dev.in_type, zmq.SUB) + self.assertEqual(dev.out_type, zmq.PUB) + self.assertEqual(dev.device_type, zmq.QUEUE) + self.assertEqual(dev.daemon, True) + del dev + + def test_tsdevice_attributes(self): + dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB) + self.assertEqual(dev.in_type, zmq.SUB) + self.assertEqual(dev.out_type, zmq.PUB) + self.assertEqual(dev.device_type, zmq.QUEUE) + self.assertEqual(dev.daemon, True) + del dev + + + def test_single_socket_forwarder_connect(self): + dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) + req = self.context.socket(zmq.REQ) + port = req.bind_to_random_port('tcp://127.0.0.1') + dev.connect_in('tcp://127.0.0.1:%i'%port) + dev.start() + time.sleep(.25) + msg = b'hello' + req.send(msg) + self.assertEqual(msg, self.recv(req)) + del dev + req.close() + dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) + req = self.context.socket(zmq.REQ) + port = req.bind_to_random_port('tcp://127.0.0.1') + dev.connect_out('tcp://127.0.0.1:%i'%port) + dev.start() + time.sleep(.25) + msg = b'hello again' + req.send(msg) + self.assertEqual(msg, self.recv(req)) + del dev + req.close() + + def test_single_socket_forwarder_bind(self): + dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) + # select random port: + binder = self.context.socket(zmq.REQ) + port = binder.bind_to_random_port('tcp://127.0.0.1') + binder.close() + time.sleep(0.1) + req = self.context.socket(zmq.REQ) + req.connect('tcp://127.0.0.1:%i'%port) + dev.bind_in('tcp://127.0.0.1:%i'%port) + dev.start() + time.sleep(.25) + msg = b'hello' + req.send(msg) + self.assertEqual(msg, self.recv(req)) + del dev + req.close() + dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) + # select random port: + binder = self.context.socket(zmq.REQ) + port = binder.bind_to_random_port('tcp://127.0.0.1') + binder.close() + time.sleep(0.1) + req = self.context.socket(zmq.REQ) + req.connect('tcp://127.0.0.1:%i'%port) + dev.bind_in('tcp://127.0.0.1:%i'%port) + dev.start() + time.sleep(.25) + msg = b'hello again' + req.send(msg) + self.assertEqual(msg, self.recv(req)) + del dev + req.close() + + def test_proxy(self): + if zmq.zmq_version_info() < (3,2): + raise SkipTest("Proxies only in libzmq >= 3") + dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH) + binder = self.context.socket(zmq.REQ) + iface = 'tcp://127.0.0.1' + port = binder.bind_to_random_port(iface) + port2 = binder.bind_to_random_port(iface) + port3 = binder.bind_to_random_port(iface) + binder.close() + time.sleep(0.1) + dev.bind_in("%s:%i" % (iface, port)) + dev.bind_out("%s:%i" % (iface, port2)) + dev.bind_mon("%s:%i" % (iface, port3)) + dev.start() + time.sleep(0.25) + msg = b'hello' + push = self.context.socket(zmq.PUSH) + push.connect("%s:%i" % (iface, port)) + pull = self.context.socket(zmq.PULL) + pull.connect("%s:%i" % (iface, port2)) + mon = self.context.socket(zmq.PULL) + mon.connect("%s:%i" % (iface, port3)) + push.send(msg) + self.sockets.extend([push, pull, mon]) + self.assertEqual(msg, self.recv(pull)) + self.assertEqual(msg, self.recv(mon)) + +if have_gevent: + import gevent + import zmq.green + + class TestDeviceGreen(GreenTest, BaseZMQTestCase): + + def test_green_device(self): + rep = self.context.socket(zmq.REP) + req = self.context.socket(zmq.REQ) + self.sockets.extend([req, rep]) + port = rep.bind_to_random_port('tcp://127.0.0.1') + g = gevent.spawn(zmq.green.device, zmq.QUEUE, rep, rep) + req.connect('tcp://127.0.0.1:%i' % port) + req.send(b'hi') + timeout = gevent.Timeout(3) + timeout.start() + receiver = gevent.spawn(req.recv) + self.assertEqual(receiver.get(2), b'hi') + timeout.cancel() + g.kill(block=True) + diff --git a/src/console/zmq/tests/test_error.py b/src/console/zmq/tests/test_error.py new file mode 100755 index 00000000..a2eee14a --- /dev/null +++ b/src/console/zmq/tests/test_error.py @@ -0,0 +1,43 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys +import time + +import zmq +from zmq import ZMQError, strerror, Again, ContextTerminated +from zmq.tests import BaseZMQTestCase + +if sys.version_info[0] >= 3: + long = int + +class TestZMQError(BaseZMQTestCase): + + def test_strerror(self): + """test that strerror gets the right type.""" + for i in range(10): + e = strerror(i) + self.assertTrue(isinstance(e, str)) + + def test_zmqerror(self): + for errno in range(10): + e = ZMQError(errno) + self.assertEqual(e.errno, errno) + self.assertEqual(str(e), strerror(errno)) + + def test_again(self): + s = self.context.socket(zmq.REP) + self.assertRaises(Again, s.recv, zmq.NOBLOCK) + self.assertRaisesErrno(zmq.EAGAIN, s.recv, zmq.NOBLOCK) + s.close() + + def atest_ctxterm(self): + s = self.context.socket(zmq.REP) + t = Thread(target=self.context.term) + t.start() + self.assertRaises(ContextTerminated, s.recv, zmq.NOBLOCK) + self.assertRaisesErrno(zmq.TERM, s.recv, zmq.NOBLOCK) + s.close() + t.join() + diff --git a/src/console/zmq/tests/test_etc.py b/src/console/zmq/tests/test_etc.py new file mode 100755 index 00000000..ad224064 --- /dev/null +++ b/src/console/zmq/tests/test_etc.py @@ -0,0 +1,15 @@ +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +import sys + +import zmq + +from . import skip_if + +@skip_if(zmq.zmq_version_info() < (4,1), "libzmq < 4.1") +def test_has(): + assert not zmq.has('something weird') + has_ipc = zmq.has('ipc') + not_windows = not sys.platform.startswith('win') + assert has_ipc == not_windows diff --git a/src/console/zmq/tests/test_imports.py b/src/console/zmq/tests/test_imports.py new file mode 100755 index 00000000..c0ddfaac --- /dev/null +++ b/src/console/zmq/tests/test_imports.py @@ -0,0 +1,62 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys +from unittest import TestCase + +class TestImports(TestCase): + """Test Imports - the quickest test to ensure that we haven't + introduced version-incompatible syntax errors.""" + + def test_toplevel(self): + """test toplevel import""" + import zmq + + def test_core(self): + """test core imports""" + from zmq import Context + from zmq import Socket + from zmq import Poller + from zmq import Frame + from zmq import constants + from zmq import device, proxy + from zmq import Stopwatch + from zmq import ( + zmq_version, + zmq_version_info, + pyzmq_version, + pyzmq_version_info, + ) + + def test_devices(self): + """test device imports""" + import zmq.devices + from zmq.devices import basedevice + from zmq.devices import monitoredqueue + from zmq.devices import monitoredqueuedevice + + def test_log(self): + """test log imports""" + import zmq.log + from zmq.log import handlers + + def test_eventloop(self): + """test eventloop imports""" + import zmq.eventloop + from zmq.eventloop import ioloop + from zmq.eventloop import zmqstream + from zmq.eventloop.minitornado.platform import auto + from zmq.eventloop.minitornado import ioloop + + def test_utils(self): + """test util imports""" + import zmq.utils + from zmq.utils import strtypes + from zmq.utils import jsonapi + + def test_ssh(self): + """test ssh imports""" + from zmq.ssh import tunnel + + + diff --git a/src/console/zmq/tests/test_ioloop.py b/src/console/zmq/tests/test_ioloop.py new file mode 100755 index 00000000..2a8b1153 --- /dev/null +++ b/src/console/zmq/tests/test_ioloop.py @@ -0,0 +1,113 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +import os +import threading + +import zmq +from zmq.tests import BaseZMQTestCase +from zmq.eventloop import ioloop +from zmq.eventloop.minitornado.ioloop import _Timeout +try: + from tornado.ioloop import PollIOLoop, IOLoop as BaseIOLoop +except ImportError: + from zmq.eventloop.minitornado.ioloop import IOLoop as BaseIOLoop + + +def printer(): + os.system("say hello") + raise Exception + print (time.time()) + + +class Delay(threading.Thread): + def __init__(self, f, delay=1): + self.f=f + self.delay=delay + self.aborted=False + self.cond=threading.Condition() + super(Delay, self).__init__() + + def run(self): + self.cond.acquire() + self.cond.wait(self.delay) + self.cond.release() + if not self.aborted: + self.f() + + def abort(self): + self.aborted=True + self.cond.acquire() + self.cond.notify() + self.cond.release() + + +class TestIOLoop(BaseZMQTestCase): + + def test_simple(self): + """simple IOLoop creation test""" + loop = ioloop.IOLoop() + dc = ioloop.PeriodicCallback(loop.stop, 200, loop) + pc = ioloop.PeriodicCallback(lambda : None, 10, loop) + pc.start() + dc.start() + t = Delay(loop.stop,1) + t.start() + loop.start() + if t.isAlive(): + t.abort() + else: + self.fail("IOLoop failed to exit") + + def test_timeout_compare(self): + """test timeout comparisons""" + loop = ioloop.IOLoop() + t = _Timeout(1, 2, loop) + t2 = _Timeout(1, 3, loop) + self.assertEqual(t < t2, id(t) < id(t2)) + t2 = _Timeout(2,1, loop) + self.assertTrue(t < t2) + + def test_poller_events(self): + """Tornado poller implementation maps events correctly""" + req,rep = self.create_bound_pair(zmq.REQ, zmq.REP) + poller = ioloop.ZMQPoller() + poller.register(req, ioloop.IOLoop.READ) + poller.register(rep, ioloop.IOLoop.READ) + events = dict(poller.poll(0)) + self.assertEqual(events.get(rep), None) + self.assertEqual(events.get(req), None) + + poller.register(req, ioloop.IOLoop.WRITE) + poller.register(rep, ioloop.IOLoop.WRITE) + events = dict(poller.poll(1)) + self.assertEqual(events.get(req), ioloop.IOLoop.WRITE) + self.assertEqual(events.get(rep), None) + + poller.register(rep, ioloop.IOLoop.READ) + req.send(b'hi') + events = dict(poller.poll(1)) + self.assertEqual(events.get(rep), ioloop.IOLoop.READ) + self.assertEqual(events.get(req), None) + + def test_instance(self): + """Test IOLoop.instance returns the right object""" + loop = ioloop.IOLoop.instance() + self.assertEqual(loop.__class__, ioloop.IOLoop) + loop = BaseIOLoop.instance() + self.assertEqual(loop.__class__, ioloop.IOLoop) + + def test_close_all(self): + """Test close(all_fds=True)""" + loop = ioloop.IOLoop.instance() + req,rep = self.create_bound_pair(zmq.REQ, zmq.REP) + loop.add_handler(req, lambda msg: msg, ioloop.IOLoop.READ) + loop.add_handler(rep, lambda msg: msg, ioloop.IOLoop.READ) + self.assertEqual(req.closed, False) + self.assertEqual(rep.closed, False) + loop.close(all_fds=True) + self.assertEqual(req.closed, True) + self.assertEqual(rep.closed, True) + + diff --git a/src/console/zmq/tests/test_log.py b/src/console/zmq/tests/test_log.py new file mode 100755 index 00000000..9206f095 --- /dev/null +++ b/src/console/zmq/tests/test_log.py @@ -0,0 +1,116 @@ +# encoding: utf-8 + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import logging +import time +from unittest import TestCase + +import zmq +from zmq.log import handlers +from zmq.utils.strtypes import b, u +from zmq.tests import BaseZMQTestCase + + +class TestPubLog(BaseZMQTestCase): + + iface = 'inproc://zmqlog' + topic= 'zmq' + + @property + def logger(self): + # print dir(self) + logger = logging.getLogger('zmqtest') + logger.setLevel(logging.DEBUG) + return logger + + def connect_handler(self, topic=None): + topic = self.topic if topic is None else topic + logger = self.logger + pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB) + handler = handlers.PUBHandler(pub) + handler.setLevel(logging.DEBUG) + handler.root_topic = topic + logger.addHandler(handler) + sub.setsockopt(zmq.SUBSCRIBE, b(topic)) + time.sleep(0.1) + return logger, handler, sub + + def test_init_iface(self): + logger = self.logger + ctx = self.context + handler = handlers.PUBHandler(self.iface) + self.assertFalse(handler.ctx is ctx) + self.sockets.append(handler.socket) + # handler.ctx.term() + handler = handlers.PUBHandler(self.iface, self.context) + self.sockets.append(handler.socket) + self.assertTrue(handler.ctx is ctx) + handler.setLevel(logging.DEBUG) + handler.root_topic = self.topic + logger.addHandler(handler) + sub = ctx.socket(zmq.SUB) + self.sockets.append(sub) + sub.setsockopt(zmq.SUBSCRIBE, b(self.topic)) + sub.connect(self.iface) + import time; time.sleep(0.25) + msg1 = 'message' + logger.info(msg1) + + (topic, msg2) = sub.recv_multipart() + self.assertEqual(topic, b'zmq.INFO') + self.assertEqual(msg2, b(msg1)+b'\n') + logger.removeHandler(handler) + + def test_init_socket(self): + pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB) + logger = self.logger + handler = handlers.PUBHandler(pub) + handler.setLevel(logging.DEBUG) + handler.root_topic = self.topic + logger.addHandler(handler) + + self.assertTrue(handler.socket is pub) + self.assertTrue(handler.ctx is pub.context) + self.assertTrue(handler.ctx is self.context) + sub.setsockopt(zmq.SUBSCRIBE, b(self.topic)) + import time; time.sleep(0.1) + msg1 = 'message' + logger.info(msg1) + + (topic, msg2) = sub.recv_multipart() + self.assertEqual(topic, b'zmq.INFO') + self.assertEqual(msg2, b(msg1)+b'\n') + logger.removeHandler(handler) + + def test_root_topic(self): + logger, handler, sub = self.connect_handler() + handler.socket.bind(self.iface) + sub2 = sub.context.socket(zmq.SUB) + self.sockets.append(sub2) + sub2.connect(self.iface) + sub2.setsockopt(zmq.SUBSCRIBE, b'') + handler.root_topic = b'twoonly' + msg1 = 'ignored' + logger.info(msg1) + self.assertRaisesErrno(zmq.EAGAIN, sub.recv, zmq.NOBLOCK) + topic,msg2 = sub2.recv_multipart() + self.assertEqual(topic, b'twoonly.INFO') + self.assertEqual(msg2, b(msg1)+b'\n') + + logger.removeHandler(handler) + + def test_unicode_message(self): + logger, handler, sub = self.connect_handler() + base_topic = b(self.topic + '.INFO') + for msg, expected in [ + (u('hello'), [base_topic, b('hello\n')]), + (u('héllo'), [base_topic, b('héllo\n')]), + (u('tøpic::héllo'), [base_topic + b('.tøpic'), b('héllo\n')]), + ]: + logger.info(msg) + received = sub.recv_multipart() + self.assertEqual(received, expected) + diff --git a/src/console/zmq/tests/test_message.py b/src/console/zmq/tests/test_message.py new file mode 100755 index 00000000..d8770bdf --- /dev/null +++ b/src/console/zmq/tests/test_message.py @@ -0,0 +1,362 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import copy +import sys +try: + from sys import getrefcount as grc +except ImportError: + grc = None + +import time +from pprint import pprint +from unittest import TestCase + +import zmq +from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy, PYPY +from zmq.utils.strtypes import unicode, bytes, b, u + + +# some useful constants: + +x = b'x' + +try: + view = memoryview +except NameError: + view = buffer + +if grc: + rc0 = grc(x) + v = view(x) + view_rc = grc(x) - rc0 + +def await_gc(obj, rc): + """wait for refcount on an object to drop to an expected value + + Necessary because of the zero-copy gc thread, + which can take some time to receive its DECREF message. + """ + for i in range(50): + # rc + 2 because of the refs in this function + if grc(obj) <= rc + 2: + return + time.sleep(0.05) + +class TestFrame(BaseZMQTestCase): + + @skip_pypy + def test_above_30(self): + """Message above 30 bytes are never copied by 0MQ.""" + for i in range(5, 16): # 32, 64,..., 65536 + s = (2**i)*x + self.assertEqual(grc(s), 2) + m = zmq.Frame(s) + self.assertEqual(grc(s), 4) + del m + await_gc(s, 2) + self.assertEqual(grc(s), 2) + del s + + def test_str(self): + """Test the str representations of the Frames.""" + for i in range(16): + s = (2**i)*x + m = zmq.Frame(s) + m_str = str(m) + m_str_b = b(m_str) # py3compat + self.assertEqual(s, m_str_b) + + def test_bytes(self): + """Test the Frame.bytes property.""" + for i in range(1,16): + s = (2**i)*x + m = zmq.Frame(s) + b = m.bytes + self.assertEqual(s, m.bytes) + if not PYPY: + # check that it copies + self.assert_(b is not s) + # check that it copies only once + self.assert_(b is m.bytes) + + def test_unicode(self): + """Test the unicode representations of the Frames.""" + s = u('asdf') + self.assertRaises(TypeError, zmq.Frame, s) + for i in range(16): + s = (2**i)*u('§') + m = zmq.Frame(s.encode('utf8')) + self.assertEqual(s, unicode(m.bytes,'utf8')) + + def test_len(self): + """Test the len of the Frames.""" + for i in range(16): + s = (2**i)*x + m = zmq.Frame(s) + self.assertEqual(len(s), len(m)) + + @skip_pypy + def test_lifecycle1(self): + """Run through a ref counting cycle with a copy.""" + for i in range(5, 16): # 32, 64,..., 65536 + s = (2**i)*x + rc = 2 + self.assertEqual(grc(s), rc) + m = zmq.Frame(s) + rc += 2 + self.assertEqual(grc(s), rc) + m2 = copy.copy(m) + rc += 1 + self.assertEqual(grc(s), rc) + buf = m2.buffer + + rc += view_rc + self.assertEqual(grc(s), rc) + + self.assertEqual(s, b(str(m))) + self.assertEqual(s, bytes(m2)) + self.assertEqual(s, m.bytes) + # self.assert_(s is str(m)) + # self.assert_(s is str(m2)) + del m2 + rc -= 1 + self.assertEqual(grc(s), rc) + rc -= view_rc + del buf + self.assertEqual(grc(s), rc) + del m + rc -= 2 + await_gc(s, rc) + self.assertEqual(grc(s), rc) + self.assertEqual(rc, 2) + del s + + @skip_pypy + def test_lifecycle2(self): + """Run through a different ref counting cycle with a copy.""" + for i in range(5, 16): # 32, 64,..., 65536 + s = (2**i)*x + rc = 2 + self.assertEqual(grc(s), rc) + m = zmq.Frame(s) + rc += 2 + self.assertEqual(grc(s), rc) + m2 = copy.copy(m) + rc += 1 + self.assertEqual(grc(s), rc) + buf = m.buffer + rc += view_rc + self.assertEqual(grc(s), rc) + self.assertEqual(s, b(str(m))) + self.assertEqual(s, bytes(m2)) + self.assertEqual(s, m2.bytes) + self.assertEqual(s, m.bytes) + # self.assert_(s is str(m)) + # self.assert_(s is str(m2)) + del buf + self.assertEqual(grc(s), rc) + del m + # m.buffer is kept until m is del'd + rc -= view_rc + rc -= 1 + self.assertEqual(grc(s), rc) + del m2 + rc -= 2 + await_gc(s, rc) + self.assertEqual(grc(s), rc) + self.assertEqual(rc, 2) + del s + + @skip_pypy + def test_tracker(self): + m = zmq.Frame(b'asdf', track=True) + self.assertFalse(m.tracker.done) + pm = zmq.MessageTracker(m) + self.assertFalse(pm.done) + del m + for i in range(10): + if pm.done: + break + time.sleep(0.1) + self.assertTrue(pm.done) + + def test_no_tracker(self): + m = zmq.Frame(b'asdf', track=False) + self.assertEqual(m.tracker, None) + m2 = copy.copy(m) + self.assertEqual(m2.tracker, None) + self.assertRaises(ValueError, zmq.MessageTracker, m) + + @skip_pypy + def test_multi_tracker(self): + m = zmq.Frame(b'asdf', track=True) + m2 = zmq.Frame(b'whoda', track=True) + mt = zmq.MessageTracker(m,m2) + self.assertFalse(m.tracker.done) + self.assertFalse(mt.done) + self.assertRaises(zmq.NotDone, mt.wait, 0.1) + del m + time.sleep(0.1) + self.assertRaises(zmq.NotDone, mt.wait, 0.1) + self.assertFalse(mt.done) + del m2 + self.assertTrue(mt.wait() is None) + self.assertTrue(mt.done) + + + def test_buffer_in(self): + """test using a buffer as input""" + ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√") + m = zmq.Frame(view(ins)) + + def test_bad_buffer_in(self): + """test using a bad object""" + self.assertRaises(TypeError, zmq.Frame, 5) + self.assertRaises(TypeError, zmq.Frame, object()) + + def test_buffer_out(self): + """receiving buffered output""" + ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√") + m = zmq.Frame(ins) + outb = m.buffer + self.assertTrue(isinstance(outb, view)) + self.assert_(outb is m.buffer) + self.assert_(m.buffer is m.buffer) + + def test_multisend(self): + """ensure that a message remains intact after multiple sends""" + a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + s = b"message" + m = zmq.Frame(s) + self.assertEqual(s, m.bytes) + + a.send(m, copy=False) + time.sleep(0.1) + self.assertEqual(s, m.bytes) + a.send(m, copy=False) + time.sleep(0.1) + self.assertEqual(s, m.bytes) + a.send(m, copy=True) + time.sleep(0.1) + self.assertEqual(s, m.bytes) + a.send(m, copy=True) + time.sleep(0.1) + self.assertEqual(s, m.bytes) + for i in range(4): + r = b.recv() + self.assertEqual(s,r) + self.assertEqual(s, m.bytes) + + def test_buffer_numpy(self): + """test non-copying numpy array messages""" + try: + import numpy + except ImportError: + raise SkipTest("numpy required") + rand = numpy.random.randint + shapes = [ rand(2,16) for i in range(5) ] + for i in range(1,len(shapes)+1): + shape = shapes[:i] + A = numpy.random.random(shape) + m = zmq.Frame(A) + if view.__name__ == 'buffer': + self.assertEqual(A.data, m.buffer) + B = numpy.frombuffer(m.buffer,dtype=A.dtype).reshape(A.shape) + else: + self.assertEqual(memoryview(A), m.buffer) + B = numpy.array(m.buffer,dtype=A.dtype).reshape(A.shape) + self.assertEqual((A==B).all(), True) + + def test_memoryview(self): + """test messages from memoryview""" + major,minor = sys.version_info[:2] + if not (major >= 3 or (major == 2 and minor >= 7)): + raise SkipTest("memoryviews only in python >= 2.7") + + s = b'carrotjuice' + v = memoryview(s) + m = zmq.Frame(s) + buf = m.buffer + s2 = buf.tobytes() + self.assertEqual(s2,s) + self.assertEqual(m.bytes,s) + + def test_noncopying_recv(self): + """check for clobbering message buffers""" + null = b'\0'*64 + sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + for i in range(32): + # try a few times + sb.send(null, copy=False) + m = sa.recv(copy=False) + mb = m.bytes + # buf = view(m) + buf = m.buffer + del m + for i in range(5): + ff=b'\xff'*(40 + i*10) + sb.send(ff, copy=False) + m2 = sa.recv(copy=False) + if view.__name__ == 'buffer': + b = bytes(buf) + else: + b = buf.tobytes() + self.assertEqual(b, null) + self.assertEqual(mb, null) + self.assertEqual(m2.bytes, ff) + + @skip_pypy + def test_buffer_numpy(self): + """test non-copying numpy array messages""" + try: + import numpy + except ImportError: + raise SkipTest("requires numpy") + if sys.version_info < (2,7): + raise SkipTest("requires new-style buffer interface (py >= 2.7)") + rand = numpy.random.randint + shapes = [ rand(2,5) for i in range(5) ] + a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + dtypes = [int, float, '>i4', 'B'] + for i in range(1,len(shapes)+1): + shape = shapes[:i] + for dt in dtypes: + A = numpy.empty(shape, dtype=dt) + while numpy.isnan(A).any(): + # don't let nan sneak in + A = numpy.ndarray(shape, dtype=dt) + a.send(A, copy=False) + msg = b.recv(copy=False) + + B = numpy.frombuffer(msg, A.dtype).reshape(A.shape) + self.assertEqual(A.shape, B.shape) + self.assertTrue((A==B).all()) + A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')]) + A['a'] = 1024 + A['b'] = 1e9 + A['c'] = 'hello there' + a.send(A, copy=False) + msg = b.recv(copy=False) + + B = numpy.frombuffer(msg, A.dtype).reshape(A.shape) + self.assertEqual(A.shape, B.shape) + self.assertTrue((A==B).all()) + + def test_frame_more(self): + """test Frame.more attribute""" + frame = zmq.Frame(b"hello") + self.assertFalse(frame.more) + sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + sa.send_multipart([b'hi', b'there']) + frame = self.recv(sb, copy=False) + self.assertTrue(frame.more) + if zmq.zmq_version_info()[0] >= 3 and not PYPY: + self.assertTrue(frame.get(zmq.MORE)) + frame = self.recv(sb, copy=False) + self.assertFalse(frame.more) + if zmq.zmq_version_info()[0] >= 3 and not PYPY: + self.assertFalse(frame.get(zmq.MORE)) + diff --git a/src/console/zmq/tests/test_monitor.py b/src/console/zmq/tests/test_monitor.py new file mode 100755 index 00000000..4f035388 --- /dev/null +++ b/src/console/zmq/tests/test_monitor.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import sys +import time +import struct + +from unittest import TestCase + +import zmq +from zmq.tests import BaseZMQTestCase, skip_if, skip_pypy +from zmq.utils.monitor import recv_monitor_message + +skip_lt_4 = skip_if(zmq.zmq_version_info() < (4,), "requires zmq >= 4") + +class TestSocketMonitor(BaseZMQTestCase): + + @skip_lt_4 + def test_monitor(self): + """Test monitoring interface for sockets.""" + s_rep = self.context.socket(zmq.REP) + s_req = self.context.socket(zmq.REQ) + self.sockets.extend([s_rep, s_req]) + s_req.bind("tcp://127.0.0.1:6666") + # try monitoring the REP socket + + s_rep.monitor("inproc://monitor.rep", zmq.EVENT_ALL) + # create listening socket for monitor + s_event = self.context.socket(zmq.PAIR) + self.sockets.append(s_event) + s_event.connect("inproc://monitor.rep") + s_event.linger = 0 + # test receive event for connect event + s_rep.connect("tcp://127.0.0.1:6666") + m = recv_monitor_message(s_event) + if m['event'] == zmq.EVENT_CONNECT_DELAYED: + self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666") + # test receive event for connected event + m = recv_monitor_message(s_event) + self.assertEqual(m['event'], zmq.EVENT_CONNECTED) + self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666") + + # test monitor can be disabled. + s_rep.disable_monitor() + m = recv_monitor_message(s_event) + self.assertEqual(m['event'], zmq.EVENT_MONITOR_STOPPED) + + + @skip_lt_4 + def test_monitor_connected(self): + """Test connected monitoring socket.""" + s_rep = self.context.socket(zmq.REP) + s_req = self.context.socket(zmq.REQ) + self.sockets.extend([s_rep, s_req]) + s_req.bind("tcp://127.0.0.1:6667") + # try monitoring the REP socket + # create listening socket for monitor + s_event = s_rep.get_monitor_socket() + s_event.linger = 0 + self.sockets.append(s_event) + # test receive event for connect event + s_rep.connect("tcp://127.0.0.1:6667") + m = recv_monitor_message(s_event) + if m['event'] == zmq.EVENT_CONNECT_DELAYED: + self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667") + # test receive event for connected event + m = recv_monitor_message(s_event) + self.assertEqual(m['event'], zmq.EVENT_CONNECTED) + self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667") diff --git a/src/console/zmq/tests/test_monqueue.py b/src/console/zmq/tests/test_monqueue.py new file mode 100755 index 00000000..e855602e --- /dev/null +++ b/src/console/zmq/tests/test_monqueue.py @@ -0,0 +1,227 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +from unittest import TestCase + +import zmq +from zmq import devices + +from zmq.tests import BaseZMQTestCase, SkipTest, PYPY +from zmq.utils.strtypes import unicode + + +if PYPY or zmq.zmq_version_info() >= (4,1): + # cleanup of shared Context doesn't work on PyPy + # there also seems to be a bug in cleanup in libzmq-4.1 (zeromq/libzmq#1052) + devices.Device.context_factory = zmq.Context + + +class TestMonitoredQueue(BaseZMQTestCase): + + sockets = [] + + def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'): + self.device = devices.ThreadMonitoredQueue(zmq.PAIR, zmq.PAIR, zmq.PUB, + in_prefix, out_prefix) + alice = self.context.socket(zmq.PAIR) + bob = self.context.socket(zmq.PAIR) + mon = self.context.socket(zmq.SUB) + + aport = alice.bind_to_random_port('tcp://127.0.0.1') + bport = bob.bind_to_random_port('tcp://127.0.0.1') + mport = mon.bind_to_random_port('tcp://127.0.0.1') + mon.setsockopt(zmq.SUBSCRIBE, mon_sub) + + self.device.connect_in("tcp://127.0.0.1:%i"%aport) + self.device.connect_out("tcp://127.0.0.1:%i"%bport) + self.device.connect_mon("tcp://127.0.0.1:%i"%mport) + self.device.start() + time.sleep(.2) + try: + # this is currenlty necessary to ensure no dropped monitor messages + # see LIBZMQ-248 for more info + mon.recv_multipart(zmq.NOBLOCK) + except zmq.ZMQError: + pass + self.sockets.extend([alice, bob, mon]) + return alice, bob, mon + + + def teardown_device(self): + for socket in self.sockets: + socket.close() + del socket + del self.device + + def test_reply(self): + alice, bob, mon = self.build_device() + alices = b"hello bob".split() + alice.send_multipart(alices) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + self.teardown_device() + + def test_queue(self): + alice, bob, mon = self.build_device() + alices = b"hello bob".split() + alice.send_multipart(alices) + alices2 = b"hello again".split() + alice.send_multipart(alices2) + alices3 = b"hello again and again".split() + alice.send_multipart(alices3) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices2, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices3, bobs) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + self.teardown_device() + + def test_monitor(self): + alice, bob, mon = self.build_device() + alices = b"hello bob".split() + alice.send_multipart(alices) + alices2 = b"hello again".split() + alice.send_multipart(alices2) + alices3 = b"hello again and again".split() + alice.send_multipart(alices3) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'in']+bobs, mons) + bobs = self.recv_multipart(bob) + self.assertEqual(alices2, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices3, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'in']+alices2, mons) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'in']+alices3, mons) + mons = self.recv_multipart(mon) + self.assertEqual([b'out']+bobs, mons) + self.teardown_device() + + def test_prefix(self): + alice, bob, mon = self.build_device(b"", b'foo', b'bar') + alices = b"hello bob".split() + alice.send_multipart(alices) + alices2 = b"hello again".split() + alice.send_multipart(alices2) + alices3 = b"hello again and again".split() + alice.send_multipart(alices3) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'foo']+bobs, mons) + bobs = self.recv_multipart(bob) + self.assertEqual(alices2, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices3, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'foo']+alices2, mons) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'foo']+alices3, mons) + mons = self.recv_multipart(mon) + self.assertEqual([b'bar']+bobs, mons) + self.teardown_device() + + def test_monitor_subscribe(self): + alice, bob, mon = self.build_device(b"out") + alices = b"hello bob".split() + alice.send_multipart(alices) + alices2 = b"hello again".split() + alice.send_multipart(alices2) + alices3 = b"hello again and again".split() + alice.send_multipart(alices3) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices2, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices3, bobs) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'out']+bobs, mons) + self.teardown_device() + + def test_router_router(self): + """test router-router MQ devices""" + dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out') + self.device = dev + dev.setsockopt_in(zmq.LINGER, 0) + dev.setsockopt_out(zmq.LINGER, 0) + dev.setsockopt_mon(zmq.LINGER, 0) + + binder = self.context.socket(zmq.DEALER) + porta = binder.bind_to_random_port('tcp://127.0.0.1') + portb = binder.bind_to_random_port('tcp://127.0.0.1') + binder.close() + time.sleep(0.1) + a = self.context.socket(zmq.DEALER) + a.identity = b'a' + b = self.context.socket(zmq.DEALER) + b.identity = b'b' + self.sockets.extend([a, b]) + + a.connect('tcp://127.0.0.1:%i'%porta) + dev.bind_in('tcp://127.0.0.1:%i'%porta) + b.connect('tcp://127.0.0.1:%i'%portb) + dev.bind_out('tcp://127.0.0.1:%i'%portb) + dev.start() + time.sleep(0.2) + if zmq.zmq_version_info() >= (3,1,0): + # flush erroneous poll state, due to LIBZMQ-280 + ping_msg = [ b'ping', b'pong' ] + for s in (a,b): + s.send_multipart(ping_msg) + try: + s.recv(zmq.NOBLOCK) + except zmq.ZMQError: + pass + msg = [ b'hello', b'there' ] + a.send_multipart([b'b']+msg) + bmsg = self.recv_multipart(b) + self.assertEqual(bmsg, [b'a']+msg) + b.send_multipart(bmsg) + amsg = self.recv_multipart(a) + self.assertEqual(amsg, [b'b']+msg) + self.teardown_device() + + def test_default_mq_args(self): + self.device = dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.DEALER, zmq.PUB) + dev.setsockopt_in(zmq.LINGER, 0) + dev.setsockopt_out(zmq.LINGER, 0) + dev.setsockopt_mon(zmq.LINGER, 0) + # this will raise if default args are wrong + dev.start() + self.teardown_device() + + def test_mq_check_prefix(self): + ins = self.context.socket(zmq.ROUTER) + outs = self.context.socket(zmq.DEALER) + mons = self.context.socket(zmq.PUB) + self.sockets.extend([ins, outs, mons]) + + ins = unicode('in') + outs = unicode('out') + self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons) diff --git a/src/console/zmq/tests/test_multipart.py b/src/console/zmq/tests/test_multipart.py new file mode 100755 index 00000000..24d41be0 --- /dev/null +++ b/src/console/zmq/tests/test_multipart.py @@ -0,0 +1,35 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import zmq + + +from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest + + +class TestMultipart(BaseZMQTestCase): + + def test_router_dealer(self): + router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER) + + msg1 = b'message1' + dealer.send(msg1) + ident = self.recv(router) + more = router.rcvmore + self.assertEqual(more, True) + msg2 = self.recv(router) + self.assertEqual(msg1, msg2) + more = router.rcvmore + self.assertEqual(more, False) + + def test_basic_multipart(self): + a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + msg = [ b'hi', b'there', b'b'] + a.send_multipart(msg) + recvd = b.recv_multipart() + self.assertEqual(msg, recvd) + +if have_gevent: + class TestMultipartGreen(GreenTest, TestMultipart): + pass diff --git a/src/console/zmq/tests/test_pair.py b/src/console/zmq/tests/test_pair.py new file mode 100755 index 00000000..e88c1e8b --- /dev/null +++ b/src/console/zmq/tests/test_pair.py @@ -0,0 +1,53 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import zmq + + +from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest + + +x = b' ' +class TestPair(BaseZMQTestCase): + + def test_basic(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + msg1 = b'message1' + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + + def test_multiple(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + for i in range(10): + msg = i*x + s1.send(msg) + + for i in range(10): + msg = i*x + s2.send(msg) + + for i in range(10): + msg = s1.recv() + self.assertEqual(msg, i*x) + + for i in range(10): + msg = s2.recv() + self.assertEqual(msg, i*x) + + def test_json(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + o = dict(a=10,b=list(range(10))) + o2 = self.ping_pong_json(s1, s2, o) + + def test_pyobj(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + o = dict(a=10,b=range(10)) + o2 = self.ping_pong_pyobj(s1, s2, o) + +if have_gevent: + class TestReqRepGreen(GreenTest, TestPair): + pass + diff --git a/src/console/zmq/tests/test_poll.py b/src/console/zmq/tests/test_poll.py new file mode 100755 index 00000000..57346c89 --- /dev/null +++ b/src/console/zmq/tests/test_poll.py @@ -0,0 +1,229 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import time +from unittest import TestCase + +import zmq + +from zmq.tests import PollZMQTestCase, have_gevent, GreenTest + +def wait(): + time.sleep(.25) + + +class TestPoll(PollZMQTestCase): + + Poller = zmq.Poller + + # This test is failing due to this issue: + # http://github.com/sustrik/zeromq2/issues#issue/26 + def test_pair(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + # Sleep to allow sockets to connect. + wait() + + poller = self.Poller() + poller.register(s1, zmq.POLLIN|zmq.POLLOUT) + poller.register(s2, zmq.POLLIN|zmq.POLLOUT) + # Poll result should contain both sockets + socks = dict(poller.poll()) + # Now make sure that both are send ready. + self.assertEqual(socks[s1], zmq.POLLOUT) + self.assertEqual(socks[s2], zmq.POLLOUT) + # Now do a send on both, wait and test for zmq.POLLOUT|zmq.POLLIN + s1.send(b'msg1') + s2.send(b'msg2') + wait() + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT|zmq.POLLIN) + self.assertEqual(socks[s2], zmq.POLLOUT|zmq.POLLIN) + # Make sure that both are in POLLOUT after recv. + s1.recv() + s2.recv() + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT) + self.assertEqual(socks[s2], zmq.POLLOUT) + + poller.unregister(s1) + poller.unregister(s2) + + # Wait for everything to finish. + wait() + + def test_reqrep(self): + s1, s2 = self.create_bound_pair(zmq.REP, zmq.REQ) + + # Sleep to allow sockets to connect. + wait() + + poller = self.Poller() + poller.register(s1, zmq.POLLIN|zmq.POLLOUT) + poller.register(s2, zmq.POLLIN|zmq.POLLOUT) + + # Make sure that s1 is in state 0 and s2 is in POLLOUT + socks = dict(poller.poll()) + self.assertEqual(s1 in socks, 0) + self.assertEqual(socks[s2], zmq.POLLOUT) + + # Make sure that s2 goes immediately into state 0 after send. + s2.send(b'msg1') + socks = dict(poller.poll()) + self.assertEqual(s2 in socks, 0) + + # Make sure that s1 goes into POLLIN state after a time.sleep(). + time.sleep(0.5) + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLIN) + + # Make sure that s1 goes into POLLOUT after recv. + s1.recv() + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT) + + # Make sure s1 goes into state 0 after send. + s1.send(b'msg2') + socks = dict(poller.poll()) + self.assertEqual(s1 in socks, 0) + + # Wait and then see that s2 is in POLLIN. + time.sleep(0.5) + socks = dict(poller.poll()) + self.assertEqual(socks[s2], zmq.POLLIN) + + # Make sure that s2 is in POLLOUT after recv. + s2.recv() + socks = dict(poller.poll()) + self.assertEqual(socks[s2], zmq.POLLOUT) + + poller.unregister(s1) + poller.unregister(s2) + + # Wait for everything to finish. + wait() + + def test_no_events(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + poller = self.Poller() + poller.register(s1, zmq.POLLIN|zmq.POLLOUT) + poller.register(s2, 0) + self.assertTrue(s1 in poller) + self.assertFalse(s2 in poller) + poller.register(s1, 0) + self.assertFalse(s1 in poller) + + def test_pubsub(self): + s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB) + s2.setsockopt(zmq.SUBSCRIBE, b'') + + # Sleep to allow sockets to connect. + wait() + + poller = self.Poller() + poller.register(s1, zmq.POLLIN|zmq.POLLOUT) + poller.register(s2, zmq.POLLIN) + + # Now make sure that both are send ready. + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT) + self.assertEqual(s2 in socks, 0) + # Make sure that s1 stays in POLLOUT after a send. + s1.send(b'msg1') + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT) + + # Make sure that s2 is POLLIN after waiting. + wait() + socks = dict(poller.poll()) + self.assertEqual(socks[s2], zmq.POLLIN) + + # Make sure that s2 goes into 0 after recv. + s2.recv() + socks = dict(poller.poll()) + self.assertEqual(s2 in socks, 0) + + poller.unregister(s1) + poller.unregister(s2) + + # Wait for everything to finish. + wait() + def test_timeout(self): + """make sure Poller.poll timeout has the right units (milliseconds).""" + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + poller = self.Poller() + poller.register(s1, zmq.POLLIN) + tic = time.time() + evt = poller.poll(.005) + toc = time.time() + self.assertTrue(toc-tic < 0.1) + tic = time.time() + evt = poller.poll(5) + toc = time.time() + self.assertTrue(toc-tic < 0.1) + self.assertTrue(toc-tic > .001) + tic = time.time() + evt = poller.poll(500) + toc = time.time() + self.assertTrue(toc-tic < 1) + self.assertTrue(toc-tic > 0.1) + +class TestSelect(PollZMQTestCase): + + def test_pair(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + # Sleep to allow sockets to connect. + wait() + + rlist, wlist, xlist = zmq.select([s1, s2], [s1, s2], [s1, s2]) + self.assert_(s1 in wlist) + self.assert_(s2 in wlist) + self.assert_(s1 not in rlist) + self.assert_(s2 not in rlist) + + def test_timeout(self): + """make sure select timeout has the right units (seconds).""" + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + tic = time.time() + r,w,x = zmq.select([s1,s2],[],[],.005) + toc = time.time() + self.assertTrue(toc-tic < 1) + self.assertTrue(toc-tic > 0.001) + tic = time.time() + r,w,x = zmq.select([s1,s2],[],[],.25) + toc = time.time() + self.assertTrue(toc-tic < 1) + self.assertTrue(toc-tic > 0.1) + + +if have_gevent: + import gevent + from zmq import green as gzmq + + class TestPollGreen(GreenTest, TestPoll): + Poller = gzmq.Poller + + def test_wakeup(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + poller = self.Poller() + poller.register(s2, zmq.POLLIN) + + tic = time.time() + r = gevent.spawn(lambda: poller.poll(10000)) + s = gevent.spawn(lambda: s1.send(b'msg1')) + r.join() + toc = time.time() + self.assertTrue(toc-tic < 1) + + def test_socket_poll(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + tic = time.time() + r = gevent.spawn(lambda: s2.poll(10000)) + s = gevent.spawn(lambda: s1.send(b'msg1')) + r.join() + toc = time.time() + self.assertTrue(toc-tic < 1) + diff --git a/src/console/zmq/tests/test_pubsub.py b/src/console/zmq/tests/test_pubsub.py new file mode 100755 index 00000000..a3ee22aa --- /dev/null +++ b/src/console/zmq/tests/test_pubsub.py @@ -0,0 +1,41 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import time +from unittest import TestCase + +import zmq + +from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest + + +class TestPubSub(BaseZMQTestCase): + + pass + + # We are disabling this test while an issue is being resolved. + def test_basic(self): + s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB) + s2.setsockopt(zmq.SUBSCRIBE,b'') + time.sleep(0.1) + msg1 = b'message' + s1.send(msg1) + msg2 = s2.recv() # This is blocking! + self.assertEqual(msg1, msg2) + + def test_topic(self): + s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB) + s2.setsockopt(zmq.SUBSCRIBE, b'x') + time.sleep(0.1) + msg1 = b'message' + s1.send(msg1) + self.assertRaisesErrno(zmq.EAGAIN, s2.recv, zmq.NOBLOCK) + msg1 = b'xmessage' + s1.send(msg1) + msg2 = s2.recv() + self.assertEqual(msg1, msg2) + +if have_gevent: + class TestPubSubGreen(GreenTest, TestPubSub): + pass diff --git a/src/console/zmq/tests/test_reqrep.py b/src/console/zmq/tests/test_reqrep.py new file mode 100755 index 00000000..de17f2b3 --- /dev/null +++ b/src/console/zmq/tests/test_reqrep.py @@ -0,0 +1,62 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from unittest import TestCase + +import zmq +from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest + + +class TestReqRep(BaseZMQTestCase): + + def test_basic(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + + msg1 = b'message 1' + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + + def test_multiple(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + + for i in range(10): + msg1 = i*b' ' + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + + def test_bad_send_recv(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + + if zmq.zmq_version() != '2.1.8': + # this doesn't work on 2.1.8 + for copy in (True,False): + self.assertRaisesErrno(zmq.EFSM, s1.recv, copy=copy) + self.assertRaisesErrno(zmq.EFSM, s2.send, b'asdf', copy=copy) + + # I have to have this or we die on an Abort trap. + msg1 = b'asdf' + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + + def test_json(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + o = dict(a=10,b=list(range(10))) + o2 = self.ping_pong_json(s1, s2, o) + + def test_pyobj(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + o = dict(a=10,b=range(10)) + o2 = self.ping_pong_pyobj(s1, s2, o) + + def test_large_msg(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + msg1 = 10000*b'X' + + for i in range(10): + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + +if have_gevent: + class TestReqRepGreen(GreenTest, TestReqRep): + pass diff --git a/src/console/zmq/tests/test_security.py b/src/console/zmq/tests/test_security.py new file mode 100755 index 00000000..687b7e0f --- /dev/null +++ b/src/console/zmq/tests/test_security.py @@ -0,0 +1,212 @@ +"""Test libzmq security (libzmq >= 3.3.0)""" +# -*- coding: utf8 -*- + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import os +from threading import Thread + +import zmq +from zmq.tests import ( + BaseZMQTestCase, SkipTest, PYPY +) +from zmq.utils import z85 + + +USER = b"admin" +PASS = b"password" + +class TestSecurity(BaseZMQTestCase): + + def setUp(self): + if zmq.zmq_version_info() < (4,0): + raise SkipTest("security is new in libzmq 4.0") + try: + zmq.curve_keypair() + except zmq.ZMQError: + raise SkipTest("security requires libzmq to be linked against libsodium") + super(TestSecurity, self).setUp() + + + def zap_handler(self): + socket = self.context.socket(zmq.REP) + socket.bind("inproc://zeromq.zap.01") + try: + msg = self.recv_multipart(socket) + + version, sequence, domain, address, identity, mechanism = msg[:6] + if mechanism == b'PLAIN': + username, password = msg[6:] + elif mechanism == b'CURVE': + key = msg[6] + + self.assertEqual(version, b"1.0") + self.assertEqual(identity, b"IDENT") + reply = [version, sequence] + if mechanism == b'CURVE' or \ + (mechanism == b'PLAIN' and username == USER and password == PASS) or \ + (mechanism == b'NULL'): + reply.extend([ + b"200", + b"OK", + b"anonymous", + b"\5Hello\0\0\0\5World", + ]) + else: + reply.extend([ + b"400", + b"Invalid username or password", + b"", + b"", + ]) + socket.send_multipart(reply) + finally: + socket.close() + + def start_zap(self): + self.zap_thread = Thread(target=self.zap_handler) + self.zap_thread.start() + + def stop_zap(self): + self.zap_thread.join() + + def bounce(self, server, client, test_metadata=True): + msg = [os.urandom(64), os.urandom(64)] + client.send_multipart(msg) + frames = self.recv_multipart(server, copy=False) + recvd = list(map(lambda x: x.bytes, frames)) + + try: + if test_metadata and not PYPY: + for frame in frames: + self.assertEqual(frame.get('User-Id'), 'anonymous') + self.assertEqual(frame.get('Hello'), 'World') + self.assertEqual(frame['Socket-Type'], 'DEALER') + except zmq.ZMQVersionError: + pass + + self.assertEqual(recvd, msg) + server.send_multipart(recvd) + msg2 = self.recv_multipart(client) + self.assertEqual(msg2, msg) + + def test_null(self): + """test NULL (default) security""" + server = self.socket(zmq.DEALER) + client = self.socket(zmq.DEALER) + self.assertEqual(client.MECHANISM, zmq.NULL) + self.assertEqual(server.mechanism, zmq.NULL) + self.assertEqual(client.plain_server, 0) + self.assertEqual(server.plain_server, 0) + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + self.bounce(server, client, False) + + def test_plain(self): + """test PLAIN authentication""" + server = self.socket(zmq.DEALER) + server.identity = b'IDENT' + client = self.socket(zmq.DEALER) + self.assertEqual(client.plain_username, b'') + self.assertEqual(client.plain_password, b'') + client.plain_username = USER + client.plain_password = PASS + self.assertEqual(client.getsockopt(zmq.PLAIN_USERNAME), USER) + self.assertEqual(client.getsockopt(zmq.PLAIN_PASSWORD), PASS) + self.assertEqual(client.plain_server, 0) + self.assertEqual(server.plain_server, 0) + server.plain_server = True + self.assertEqual(server.mechanism, zmq.PLAIN) + self.assertEqual(client.mechanism, zmq.PLAIN) + + assert not client.plain_server + assert server.plain_server + + self.start_zap() + + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + self.bounce(server, client) + self.stop_zap() + + def skip_plain_inauth(self): + """test PLAIN failed authentication""" + server = self.socket(zmq.DEALER) + server.identity = b'IDENT' + client = self.socket(zmq.DEALER) + self.sockets.extend([server, client]) + client.plain_username = USER + client.plain_password = b'incorrect' + server.plain_server = True + self.assertEqual(server.mechanism, zmq.PLAIN) + self.assertEqual(client.mechanism, zmq.PLAIN) + + self.start_zap() + + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + client.send(b'ping') + server.rcvtimeo = 250 + self.assertRaisesErrno(zmq.EAGAIN, server.recv) + self.stop_zap() + + def test_keypair(self): + """test curve_keypair""" + try: + public, secret = zmq.curve_keypair() + except zmq.ZMQError: + raise SkipTest("CURVE unsupported") + + self.assertEqual(type(secret), bytes) + self.assertEqual(type(public), bytes) + self.assertEqual(len(secret), 40) + self.assertEqual(len(public), 40) + + # verify that it is indeed Z85 + bsecret, bpublic = [ z85.decode(key) for key in (public, secret) ] + self.assertEqual(type(bsecret), bytes) + self.assertEqual(type(bpublic), bytes) + self.assertEqual(len(bsecret), 32) + self.assertEqual(len(bpublic), 32) + + + def test_curve(self): + """test CURVE encryption""" + server = self.socket(zmq.DEALER) + server.identity = b'IDENT' + client = self.socket(zmq.DEALER) + self.sockets.extend([server, client]) + try: + server.curve_server = True + except zmq.ZMQError as e: + # will raise EINVAL if not linked against libsodium + if e.errno == zmq.EINVAL: + raise SkipTest("CURVE unsupported") + + server_public, server_secret = zmq.curve_keypair() + client_public, client_secret = zmq.curve_keypair() + + server.curve_secretkey = server_secret + server.curve_publickey = server_public + client.curve_serverkey = server_public + client.curve_publickey = client_public + client.curve_secretkey = client_secret + + self.assertEqual(server.mechanism, zmq.CURVE) + self.assertEqual(client.mechanism, zmq.CURVE) + + self.assertEqual(server.get(zmq.CURVE_SERVER), True) + self.assertEqual(client.get(zmq.CURVE_SERVER), False) + + self.start_zap() + + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + self.bounce(server, client) + self.stop_zap() + diff --git a/src/console/zmq/tests/test_socket.py b/src/console/zmq/tests/test_socket.py new file mode 100755 index 00000000..5c842edc --- /dev/null +++ b/src/console/zmq/tests/test_socket.py @@ -0,0 +1,450 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +import warnings + +import zmq +from zmq.tests import ( + BaseZMQTestCase, SkipTest, have_gevent, GreenTest, skip_pypy, skip_if +) +from zmq.utils.strtypes import bytes, unicode + + +class TestSocket(BaseZMQTestCase): + + def test_create(self): + ctx = self.Context() + s = ctx.socket(zmq.PUB) + # Superluminal protocol not yet implemented + self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a') + self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a') + self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://') + s.close() + del ctx + + def test_context_manager(self): + url = 'inproc://a' + with self.Context() as ctx: + with ctx.socket(zmq.PUSH) as a: + a.bind(url) + with ctx.socket(zmq.PULL) as b: + b.connect(url) + msg = b'hi' + a.send(msg) + rcvd = self.recv(b) + self.assertEqual(rcvd, msg) + self.assertEqual(b.closed, True) + self.assertEqual(a.closed, True) + self.assertEqual(ctx.closed, True) + + def test_dir(self): + ctx = self.Context() + s = ctx.socket(zmq.PUB) + self.assertTrue('send' in dir(s)) + self.assertTrue('IDENTITY' in dir(s)) + self.assertTrue('AFFINITY' in dir(s)) + self.assertTrue('FD' in dir(s)) + s.close() + ctx.term() + + def test_bind_unicode(self): + s = self.socket(zmq.PUB) + p = s.bind_to_random_port(unicode("tcp://*")) + + def test_connect_unicode(self): + s = self.socket(zmq.PUB) + s.connect(unicode("tcp://127.0.0.1:5555")) + + def test_bind_to_random_port(self): + # Check that bind_to_random_port do not hide usefull exception + ctx = self.Context() + c = ctx.socket(zmq.PUB) + # Invalid format + try: + c.bind_to_random_port('tcp:*') + except zmq.ZMQError as e: + self.assertEqual(e.errno, zmq.EINVAL) + # Invalid protocol + try: + c.bind_to_random_port('rand://*') + except zmq.ZMQError as e: + self.assertEqual(e.errno, zmq.EPROTONOSUPPORT) + + def test_identity(self): + s = self.context.socket(zmq.PULL) + self.sockets.append(s) + ident = b'identity\0\0' + s.identity = ident + self.assertEqual(s.get(zmq.IDENTITY), ident) + + def test_unicode_sockopts(self): + """test setting/getting sockopts with unicode strings""" + topic = "tést" + if str is not unicode: + topic = topic.decode('utf8') + p,s = self.create_bound_pair(zmq.PUB, zmq.SUB) + self.assertEqual(s.send_unicode, s.send_unicode) + self.assertEqual(p.recv_unicode, p.recv_unicode) + self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic) + self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic) + s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16') + self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic) + s.setsockopt_unicode(zmq.SUBSCRIBE, topic) + self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY) + self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE) + + identb = s.getsockopt(zmq.IDENTITY) + identu = identb.decode('utf16') + identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16') + self.assertEqual(identu, identu2) + time.sleep(0.1) # wait for connection/subscription + p.send_unicode(topic,zmq.SNDMORE) + p.send_unicode(topic*2, encoding='latin-1') + self.assertEqual(topic, s.recv_unicode()) + self.assertEqual(topic*2, s.recv_unicode(encoding='latin-1')) + + def test_int_sockopts(self): + "test integer sockopts" + v = zmq.zmq_version_info() + if v < (3,0): + default_hwm = 0 + else: + default_hwm = 1000 + p,s = self.create_bound_pair(zmq.PUB, zmq.SUB) + p.setsockopt(zmq.LINGER, 0) + self.assertEqual(p.getsockopt(zmq.LINGER), 0) + p.setsockopt(zmq.LINGER, -1) + self.assertEqual(p.getsockopt(zmq.LINGER), -1) + self.assertEqual(p.hwm, default_hwm) + p.hwm = 11 + self.assertEqual(p.hwm, 11) + # p.setsockopt(zmq.EVENTS, zmq.POLLIN) + self.assertEqual(p.getsockopt(zmq.EVENTS), zmq.POLLOUT) + self.assertRaisesErrno(zmq.EINVAL, p.setsockopt,zmq.EVENTS, 2**7-1) + self.assertEqual(p.getsockopt(zmq.TYPE), p.socket_type) + self.assertEqual(p.getsockopt(zmq.TYPE), zmq.PUB) + self.assertEqual(s.getsockopt(zmq.TYPE), s.socket_type) + self.assertEqual(s.getsockopt(zmq.TYPE), zmq.SUB) + + # check for overflow / wrong type: + errors = [] + backref = {} + constants = zmq.constants + for name in constants.__all__: + value = getattr(constants, name) + if isinstance(value, int): + backref[value] = name + for opt in zmq.constants.int_sockopts.union(zmq.constants.int64_sockopts): + sopt = backref[opt] + if sopt.startswith(( + 'ROUTER', 'XPUB', 'TCP', 'FAIL', + 'REQ_', 'CURVE_', 'PROBE_ROUTER', + 'IPC_FILTER', 'GSSAPI', + )): + # some sockopts are write-only + continue + try: + n = p.getsockopt(opt) + except zmq.ZMQError as e: + errors.append("getsockopt(zmq.%s) raised '%s'."%(sopt, e)) + else: + if n > 2**31: + errors.append("getsockopt(zmq.%s) returned a ridiculous value." + " It is probably the wrong type."%sopt) + if errors: + self.fail('\n'.join([''] + errors)) + + def test_bad_sockopts(self): + """Test that appropriate errors are raised on bad socket options""" + s = self.context.socket(zmq.PUB) + self.sockets.append(s) + s.setsockopt(zmq.LINGER, 0) + # unrecognized int sockopts pass through to libzmq, and should raise EINVAL + self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5) + self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999) + # but only int sockopts are allowed through this way, otherwise raise a TypeError + self.assertRaises(TypeError, s.setsockopt, 9999, b"5") + # some sockopts are valid in general, but not on every socket: + self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi') + + def test_sockopt_roundtrip(self): + "test set/getsockopt roundtrip." + p = self.context.socket(zmq.PUB) + self.sockets.append(p) + p.setsockopt(zmq.LINGER, 11) + self.assertEqual(p.getsockopt(zmq.LINGER), 11) + + def test_send_unicode(self): + "test sending unicode objects" + a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + self.sockets.extend([a,b]) + u = "çπ§" + if str is not unicode: + u = u.decode('utf8') + self.assertRaises(TypeError, a.send, u,copy=False) + self.assertRaises(TypeError, a.send, u,copy=True) + a.send_unicode(u) + s = b.recv() + self.assertEqual(s,u.encode('utf8')) + self.assertEqual(s.decode('utf8'),u) + a.send_unicode(u,encoding='utf16') + s = b.recv_unicode(encoding='utf16') + self.assertEqual(s,u) + + @skip_pypy + def test_tracker(self): + "test the MessageTracker object for tracking when zmq is done with a buffer" + addr = 'tcp://127.0.0.1' + a = self.context.socket(zmq.PUB) + port = a.bind_to_random_port(addr) + a.close() + iface = "%s:%i"%(addr,port) + a = self.context.socket(zmq.PAIR) + # a.setsockopt(zmq.IDENTITY, b"a") + b = self.context.socket(zmq.PAIR) + self.sockets.extend([a,b]) + a.connect(iface) + time.sleep(0.1) + p1 = a.send(b'something', copy=False, track=True) + self.assertTrue(isinstance(p1, zmq.MessageTracker)) + self.assertFalse(p1.done) + p2 = a.send_multipart([b'something', b'else'], copy=False, track=True) + self.assert_(isinstance(p2, zmq.MessageTracker)) + self.assertEqual(p2.done, False) + self.assertEqual(p1.done, False) + + b.bind(iface) + msg = b.recv_multipart() + for i in range(10): + if p1.done: + break + time.sleep(0.1) + self.assertEqual(p1.done, True) + self.assertEqual(msg, [b'something']) + msg = b.recv_multipart() + for i in range(10): + if p2.done: + break + time.sleep(0.1) + self.assertEqual(p2.done, True) + self.assertEqual(msg, [b'something', b'else']) + m = zmq.Frame(b"again", track=True) + self.assertEqual(m.tracker.done, False) + p1 = a.send(m, copy=False) + p2 = a.send(m, copy=False) + self.assertEqual(m.tracker.done, False) + self.assertEqual(p1.done, False) + self.assertEqual(p2.done, False) + msg = b.recv_multipart() + self.assertEqual(m.tracker.done, False) + self.assertEqual(msg, [b'again']) + msg = b.recv_multipart() + self.assertEqual(m.tracker.done, False) + self.assertEqual(msg, [b'again']) + self.assertEqual(p1.done, False) + self.assertEqual(p2.done, False) + pm = m.tracker + del m + for i in range(10): + if p1.done: + break + time.sleep(0.1) + self.assertEqual(p1.done, True) + self.assertEqual(p2.done, True) + m = zmq.Frame(b'something', track=False) + self.assertRaises(ValueError, a.send, m, copy=False, track=True) + + + def test_close(self): + ctx = self.Context() + s = ctx.socket(zmq.PUB) + s.close() + self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'') + self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'') + self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'') + self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf') + self.assertRaisesErrno(zmq.ENOTSOCK, s.recv) + del ctx + + def test_attr(self): + """set setting/getting sockopts as attributes""" + s = self.context.socket(zmq.DEALER) + self.sockets.append(s) + linger = 10 + s.linger = linger + self.assertEqual(linger, s.linger) + self.assertEqual(linger, s.getsockopt(zmq.LINGER)) + self.assertEqual(s.fd, s.getsockopt(zmq.FD)) + + def test_bad_attr(self): + s = self.context.socket(zmq.DEALER) + self.sockets.append(s) + try: + s.apple='foo' + except AttributeError: + pass + else: + self.fail("bad setattr should have raised AttributeError") + try: + s.apple + except AttributeError: + pass + else: + self.fail("bad getattr should have raised AttributeError") + + def test_subclass(self): + """subclasses can assign attributes""" + class S(zmq.Socket): + a = None + def __init__(self, *a, **kw): + self.a=-1 + super(S, self).__init__(*a, **kw) + + s = S(self.context, zmq.REP) + self.sockets.append(s) + self.assertEqual(s.a, -1) + s.a=1 + self.assertEqual(s.a, 1) + a=s.a + self.assertEqual(a, 1) + + def test_recv_multipart(self): + a,b = self.create_bound_pair() + msg = b'hi' + for i in range(3): + a.send(msg) + time.sleep(0.1) + for i in range(3): + self.assertEqual(b.recv_multipart(), [msg]) + + def test_close_after_destroy(self): + """s.close() after ctx.destroy() should be fine""" + ctx = self.Context() + s = ctx.socket(zmq.REP) + ctx.destroy() + # reaper is not instantaneous + time.sleep(1e-2) + s.close() + self.assertTrue(s.closed) + + def test_poll(self): + a,b = self.create_bound_pair() + tic = time.time() + evt = a.poll(50) + self.assertEqual(evt, 0) + evt = a.poll(50, zmq.POLLOUT) + self.assertEqual(evt, zmq.POLLOUT) + msg = b'hi' + a.send(msg) + evt = b.poll(50) + self.assertEqual(evt, zmq.POLLIN) + msg2 = self.recv(b) + evt = b.poll(50) + self.assertEqual(evt, 0) + self.assertEqual(msg2, msg) + + def test_ipc_path_max_length(self): + """IPC_PATH_MAX_LEN is a sensible value""" + if zmq.IPC_PATH_MAX_LEN == 0: + raise SkipTest("IPC_PATH_MAX_LEN undefined") + + msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN + self.assertTrue(zmq.IPC_PATH_MAX_LEN > 30, msg) + self.assertTrue(zmq.IPC_PATH_MAX_LEN < 1025, msg) + + def test_ipc_path_max_length_msg(self): + if zmq.IPC_PATH_MAX_LEN == 0: + raise SkipTest("IPC_PATH_MAX_LEN undefined") + + s = self.context.socket(zmq.PUB) + self.sockets.append(s) + try: + s.bind('ipc://{0}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1))) + except zmq.ZMQError as e: + self.assertTrue(str(zmq.IPC_PATH_MAX_LEN) in e.strerror) + + def test_hwm(self): + zmq3 = zmq.zmq_version_info()[0] >= 3 + for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER): + s = self.context.socket(stype) + s.hwm = 100 + self.assertEqual(s.hwm, 100) + if zmq3: + try: + self.assertEqual(s.sndhwm, 100) + except AttributeError: + pass + try: + self.assertEqual(s.rcvhwm, 100) + except AttributeError: + pass + s.close() + + def test_shadow(self): + p = self.socket(zmq.PUSH) + p.bind("tcp://127.0.0.1:5555") + p2 = zmq.Socket.shadow(p.underlying) + self.assertEqual(p.underlying, p2.underlying) + s = self.socket(zmq.PULL) + s2 = zmq.Socket.shadow(s.underlying) + self.assertNotEqual(s.underlying, p.underlying) + self.assertEqual(s.underlying, s2.underlying) + s2.connect("tcp://127.0.0.1:5555") + sent = b'hi' + p2.send(sent) + rcvd = self.recv(s2) + self.assertEqual(rcvd, sent) + + def test_shadow_pyczmq(self): + try: + from pyczmq import zctx, zsocket + except Exception: + raise SkipTest("Requires pyczmq") + + ctx = zctx.new() + ca = zsocket.new(ctx, zmq.PUSH) + cb = zsocket.new(ctx, zmq.PULL) + a = zmq.Socket.shadow(ca) + b = zmq.Socket.shadow(cb) + a.bind("inproc://a") + b.connect("inproc://a") + a.send(b'hi') + rcvd = self.recv(b) + self.assertEqual(rcvd, b'hi') + + +if have_gevent: + import gevent + + class TestSocketGreen(GreenTest, TestSocket): + test_bad_attr = GreenTest.skip_green + test_close_after_destroy = GreenTest.skip_green + + def test_timeout(self): + a,b = self.create_bound_pair() + g = gevent.spawn_later(0.5, lambda: a.send(b'hi')) + timeout = gevent.Timeout(0.1) + timeout.start() + self.assertRaises(gevent.Timeout, b.recv) + g.kill() + + @skip_if(not hasattr(zmq, 'RCVTIMEO')) + def test_warn_set_timeo(self): + s = self.context.socket(zmq.REQ) + with warnings.catch_warnings(record=True) as w: + s.rcvtimeo = 5 + s.close() + self.assertEqual(len(w), 1) + self.assertEqual(w[0].category, UserWarning) + + + @skip_if(not hasattr(zmq, 'SNDTIMEO')) + def test_warn_get_timeo(self): + s = self.context.socket(zmq.REQ) + with warnings.catch_warnings(record=True) as w: + s.sndtimeo + s.close() + self.assertEqual(len(w), 1) + self.assertEqual(w[0].category, UserWarning) diff --git a/src/console/zmq/tests/test_stopwatch.py b/src/console/zmq/tests/test_stopwatch.py new file mode 100755 index 00000000..49fb79f2 --- /dev/null +++ b/src/console/zmq/tests/test_stopwatch.py @@ -0,0 +1,42 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import sys +import time + +from unittest import TestCase + +from zmq import Stopwatch, ZMQError + +if sys.version_info[0] >= 3: + long = int + +class TestStopWatch(TestCase): + + def test_stop_long(self): + """Ensure stop returns a long int.""" + watch = Stopwatch() + watch.start() + us = watch.stop() + self.assertTrue(isinstance(us, long)) + + def test_stop_microseconds(self): + """Test that stop/sleep have right units.""" + watch = Stopwatch() + watch.start() + tic = time.time() + watch.sleep(1) + us = watch.stop() + toc = time.time() + self.assertAlmostEqual(us/1e6,(toc-tic),places=0) + + def test_double_stop(self): + """Test error raised on multiple calls to stop.""" + watch = Stopwatch() + watch.start() + watch.stop() + self.assertRaises(ZMQError, watch.stop) + self.assertRaises(ZMQError, watch.stop) + diff --git a/src/console/zmq/tests/test_version.py b/src/console/zmq/tests/test_version.py new file mode 100755 index 00000000..6ebebf30 --- /dev/null +++ b/src/console/zmq/tests/test_version.py @@ -0,0 +1,44 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from unittest import TestCase +import zmq +from zmq.sugar import version + + +class TestVersion(TestCase): + + def test_pyzmq_version(self): + vs = zmq.pyzmq_version() + vs2 = zmq.__version__ + self.assertTrue(isinstance(vs, str)) + if zmq.__revision__: + self.assertEqual(vs, '@'.join(vs2, zmq.__revision__)) + else: + self.assertEqual(vs, vs2) + if version.VERSION_EXTRA: + self.assertTrue(version.VERSION_EXTRA in vs) + self.assertTrue(version.VERSION_EXTRA in vs2) + + def test_pyzmq_version_info(self): + info = zmq.pyzmq_version_info() + self.assertTrue(isinstance(info, tuple)) + for n in info[:3]: + self.assertTrue(isinstance(n, int)) + if version.VERSION_EXTRA: + self.assertEqual(len(info), 4) + self.assertEqual(info[-1], float('inf')) + else: + self.assertEqual(len(info), 3) + + def test_zmq_version_info(self): + info = zmq.zmq_version_info() + self.assertTrue(isinstance(info, tuple)) + for n in info[:3]: + self.assertTrue(isinstance(n, int)) + + def test_zmq_version(self): + v = zmq.zmq_version() + self.assertTrue(isinstance(v, str)) + diff --git a/src/console/zmq/tests/test_win32_shim.py b/src/console/zmq/tests/test_win32_shim.py new file mode 100755 index 00000000..55657bda --- /dev/null +++ b/src/console/zmq/tests/test_win32_shim.py @@ -0,0 +1,56 @@ +from __future__ import print_function + +import os + +from functools import wraps +from zmq.tests import BaseZMQTestCase +from zmq.utils.win32 import allow_interrupt + + +def count_calls(f): + @wraps(f) + def _(*args, **kwds): + try: + return f(*args, **kwds) + finally: + _.__calls__ += 1 + _.__calls__ = 0 + return _ + + +class TestWindowsConsoleControlHandler(BaseZMQTestCase): + + def test_handler(self): + @count_calls + def interrupt_polling(): + print('Caught CTRL-C!') + + if os.name == 'nt': + from ctypes import windll + from ctypes.wintypes import BOOL, DWORD + + kernel32 = windll.LoadLibrary('kernel32') + + # <http://msdn.microsoft.com/en-us/library/ms683155.aspx> + GenerateConsoleCtrlEvent = kernel32.GenerateConsoleCtrlEvent + GenerateConsoleCtrlEvent.argtypes = (DWORD, DWORD) + GenerateConsoleCtrlEvent.restype = BOOL + + try: + # Simulate CTRL-C event while handler is active. + with allow_interrupt(interrupt_polling): + result = GenerateConsoleCtrlEvent(0, 0) + if result == 0: + raise WindowsError + except KeyboardInterrupt: + pass + else: + self.fail('Expecting `KeyboardInterrupt` exception!') + + # Make sure our handler was called. + self.assertEqual(interrupt_polling.__calls__, 1) + else: + # On non-Windows systems, this utility is just a no-op! + with allow_interrupt(interrupt_polling): + pass + self.assertEqual(interrupt_polling.__calls__, 0) diff --git a/src/console/zmq/tests/test_z85.py b/src/console/zmq/tests/test_z85.py new file mode 100755 index 00000000..8a73cb4d --- /dev/null +++ b/src/console/zmq/tests/test_z85.py @@ -0,0 +1,63 @@ +# -*- coding: utf8 -*- +"""Test Z85 encoding + +confirm values and roundtrip with test values from the reference implementation. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from unittest import TestCase +from zmq.utils import z85 + + +class TestZ85(TestCase): + + def test_client_public(self): + client_public = \ + b"\xBB\x88\x47\x1D\x65\xE2\x65\x9B" \ + b"\x30\xC5\x5A\x53\x21\xCE\xBB\x5A" \ + b"\xAB\x2B\x70\xA3\x98\x64\x5C\x26" \ + b"\xDC\xA2\xB2\xFC\xB4\x3F\xC5\x18" + encoded = z85.encode(client_public) + + self.assertEqual(encoded, b"Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID") + decoded = z85.decode(encoded) + self.assertEqual(decoded, client_public) + + def test_client_secret(self): + client_secret = \ + b"\x7B\xB8\x64\xB4\x89\xAF\xA3\x67" \ + b"\x1F\xBE\x69\x10\x1F\x94\xB3\x89" \ + b"\x72\xF2\x48\x16\xDF\xB0\x1B\x51" \ + b"\x65\x6B\x3F\xEC\x8D\xFD\x08\x88" + encoded = z85.encode(client_secret) + + self.assertEqual(encoded, b"D:)Q[IlAW!ahhC2ac:9*A}h:p?([4%wOTJ%JR%cs") + decoded = z85.decode(encoded) + self.assertEqual(decoded, client_secret) + + def test_server_public(self): + server_public = \ + b"\x54\xFC\xBA\x24\xE9\x32\x49\x96" \ + b"\x93\x16\xFB\x61\x7C\x87\x2B\xB0" \ + b"\xC1\xD1\xFF\x14\x80\x04\x27\xC5" \ + b"\x94\xCB\xFA\xCF\x1B\xC2\xD6\x52" + encoded = z85.encode(server_public) + + self.assertEqual(encoded, b"rq:rM>}U?@Lns47E1%kR.o@n%FcmmsL/@{H8]yf7") + decoded = z85.decode(encoded) + self.assertEqual(decoded, server_public) + + def test_server_secret(self): + server_secret = \ + b"\x8E\x0B\xDD\x69\x76\x28\xB9\x1D" \ + b"\x8F\x24\x55\x87\xEE\x95\xC5\xB0" \ + b"\x4D\x48\x96\x3F\x79\x25\x98\x77" \ + b"\xB4\x9C\xD9\x06\x3A\xEA\xD3\xB7" + encoded = z85.encode(server_secret) + + self.assertEqual(encoded, b"JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6") + decoded = z85.decode(encoded) + self.assertEqual(decoded, server_secret) + diff --git a/src/console/zmq/tests/test_zmqstream.py b/src/console/zmq/tests/test_zmqstream.py new file mode 100755 index 00000000..cdb3a171 --- /dev/null +++ b/src/console/zmq/tests/test_zmqstream.py @@ -0,0 +1,34 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import sys +import time + +from unittest import TestCase + +import zmq +from zmq.eventloop import ioloop, zmqstream + +class TestZMQStream(TestCase): + + def setUp(self): + self.context = zmq.Context() + self.socket = self.context.socket(zmq.REP) + self.loop = ioloop.IOLoop.instance() + self.stream = zmqstream.ZMQStream(self.socket) + + def tearDown(self): + self.socket.close() + self.context.term() + + def test_callable_check(self): + """Ensure callable check works (py3k).""" + + self.stream.on_send(lambda *args: None) + self.stream.on_recv(lambda *args: None) + self.assertRaises(AssertionError, self.stream.on_recv, 1) + self.assertRaises(AssertionError, self.stream.on_send, 1) + self.assertRaises(AssertionError, self.stream.on_recv, zmq) + diff --git a/src/console/zmq/utils/__init__.py b/src/console/zmq/utils/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/console/zmq/utils/__init__.py diff --git a/src/console/zmq/utils/buffers.pxd b/src/console/zmq/utils/buffers.pxd new file mode 100644 index 00000000..998aa551 --- /dev/null +++ b/src/console/zmq/utils/buffers.pxd @@ -0,0 +1,313 @@ +"""Python version-independent methods for C/Python buffers. + +This file was copied and adapted from mpi4py. + +Authors +------- +* MinRK +""" + +#----------------------------------------------------------------------------- +# Copyright (c) 2010 Lisandro Dalcin +# All rights reserved. +# Used under BSD License: http://www.opensource.org/licenses/bsd-license.php +# +# Retrieval: +# Jul 23, 2010 18:00 PST (r539) +# http://code.google.com/p/mpi4py/source/browse/trunk/src/MPI/asbuffer.pxi +# +# Modifications from original: +# Copyright (c) 2010-2012 Brian Granger, Min Ragan-Kelley +# +# Distributed under the terms of the New BSD License. The full license is in +# the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + + +#----------------------------------------------------------------------------- +# Python includes. +#----------------------------------------------------------------------------- + +# get version-independent aliases: +cdef extern from "pyversion_compat.h": + pass + +# Python 3 buffer interface (PEP 3118) +cdef extern from "Python.h": + int PY_MAJOR_VERSION + int PY_MINOR_VERSION + ctypedef int Py_ssize_t + ctypedef struct PyMemoryViewObject: + pass + ctypedef struct Py_buffer: + void *buf + Py_ssize_t len + int readonly + char *format + int ndim + Py_ssize_t *shape + Py_ssize_t *strides + Py_ssize_t *suboffsets + Py_ssize_t itemsize + void *internal + cdef enum: + PyBUF_SIMPLE + PyBUF_WRITABLE + PyBUF_FORMAT + PyBUF_ANY_CONTIGUOUS + int PyObject_CheckBuffer(object) + int PyObject_GetBuffer(object, Py_buffer *, int) except -1 + void PyBuffer_Release(Py_buffer *) + + int PyBuffer_FillInfo(Py_buffer *view, object obj, void *buf, + Py_ssize_t len, int readonly, int infoflags) except -1 + object PyMemoryView_FromBuffer(Py_buffer *info) + + object PyMemoryView_FromObject(object) + +# Python 2 buffer interface (legacy) +cdef extern from "Python.h": + ctypedef void const_void "const void" + Py_ssize_t Py_END_OF_BUFFER + int PyObject_CheckReadBuffer(object) + int PyObject_AsReadBuffer (object, const_void **, Py_ssize_t *) except -1 + int PyObject_AsWriteBuffer(object, void **, Py_ssize_t *) except -1 + + object PyBuffer_FromMemory(void *ptr, Py_ssize_t s) + object PyBuffer_FromReadWriteMemory(void *ptr, Py_ssize_t s) + + object PyBuffer_FromObject(object, Py_ssize_t offset, Py_ssize_t size) + object PyBuffer_FromReadWriteObject(object, Py_ssize_t offset, Py_ssize_t size) + + +#----------------------------------------------------------------------------- +# asbuffer: C buffer from python object +#----------------------------------------------------------------------------- + + +cdef inline int memoryview_available(): + return PY_MAJOR_VERSION >= 3 or (PY_MAJOR_VERSION >=2 and PY_MINOR_VERSION >= 7) + +cdef inline int oldstyle_available(): + return PY_MAJOR_VERSION < 3 + + +cdef inline int check_buffer(object ob): + """Version independent check for whether an object is a buffer. + + Parameters + ---------- + object : object + Any Python object + + Returns + ------- + int : 0 if no buffer interface, 3 if newstyle buffer interface, 2 if oldstyle. + """ + if PyObject_CheckBuffer(ob): + return 3 + if oldstyle_available(): + return PyObject_CheckReadBuffer(ob) and 2 + return 0 + + +cdef inline object asbuffer(object ob, int writable, int format, + void **base, Py_ssize_t *size, + Py_ssize_t *itemsize): + """Turn an object into a C buffer in a Python version-independent way. + + Parameters + ---------- + ob : object + The object to be turned into a buffer. + Must provide a Python Buffer interface + writable : int + Whether the resulting buffer should be allowed to write + to the object. + format : int + The format of the buffer. See Python buffer docs. + base : void ** + The pointer that will be used to store the resulting C buffer. + size : Py_ssize_t * + The size of the buffer(s). + itemsize : Py_ssize_t * + The size of an item, if the buffer is non-contiguous. + + Returns + ------- + An object describing the buffer format. Generally a str, such as 'B'. + """ + + cdef void *bptr = NULL + cdef Py_ssize_t blen = 0, bitemlen = 0 + cdef Py_buffer view + cdef int flags = PyBUF_SIMPLE + cdef int mode = 0 + + bfmt = None + + mode = check_buffer(ob) + if mode == 0: + raise TypeError("%r does not provide a buffer interface."%ob) + + if mode == 3: + flags = PyBUF_ANY_CONTIGUOUS + if writable: + flags |= PyBUF_WRITABLE + if format: + flags |= PyBUF_FORMAT + PyObject_GetBuffer(ob, &view, flags) + bptr = view.buf + blen = view.len + if format: + if view.format != NULL: + bfmt = view.format + bitemlen = view.itemsize + PyBuffer_Release(&view) + else: # oldstyle + if writable: + PyObject_AsWriteBuffer(ob, &bptr, &blen) + else: + PyObject_AsReadBuffer(ob, <const_void **>&bptr, &blen) + if format: + try: # numpy.ndarray + dtype = ob.dtype + bfmt = dtype.char + bitemlen = dtype.itemsize + except AttributeError: + try: # array.array + bfmt = ob.typecode + bitemlen = ob.itemsize + except AttributeError: + if isinstance(ob, bytes): + bfmt = b"B" + bitemlen = 1 + else: + # nothing found + bfmt = None + bitemlen = 0 + if base: base[0] = <void *>bptr + if size: size[0] = <Py_ssize_t>blen + if itemsize: itemsize[0] = <Py_ssize_t>bitemlen + + if PY_MAJOR_VERSION >= 3 and bfmt is not None: + return bfmt.decode('ascii') + return bfmt + + +cdef inline object asbuffer_r(object ob, void **base, Py_ssize_t *size): + """Wrapper for standard calls to asbuffer with a readonly buffer.""" + asbuffer(ob, 0, 0, base, size, NULL) + return ob + + +cdef inline object asbuffer_w(object ob, void **base, Py_ssize_t *size): + """Wrapper for standard calls to asbuffer with a writable buffer.""" + asbuffer(ob, 1, 0, base, size, NULL) + return ob + +#------------------------------------------------------------------------------ +# frombuffer: python buffer/view from C buffer +#------------------------------------------------------------------------------ + + +cdef inline object frombuffer_3(void *ptr, Py_ssize_t s, int readonly): + """Python 3 version of frombuffer. + + This is the Python 3 model, but will work on Python >= 2.6. Currently, + we use it only on >= 3.0. + """ + cdef Py_buffer pybuf + cdef Py_ssize_t *shape = [s] + cdef str astr="" + PyBuffer_FillInfo(&pybuf, astr, ptr, s, readonly, PyBUF_SIMPLE) + pybuf.format = "B" + pybuf.shape = shape + return PyMemoryView_FromBuffer(&pybuf) + + +cdef inline object frombuffer_2(void *ptr, Py_ssize_t s, int readonly): + """Python 2 version of frombuffer. + + This must be used for Python <= 2.6, but we use it for all Python < 3. + """ + + if oldstyle_available(): + if readonly: + return PyBuffer_FromMemory(ptr, s) + else: + return PyBuffer_FromReadWriteMemory(ptr, s) + else: + raise NotImplementedError("Old style buffers not available.") + + +cdef inline object frombuffer(void *ptr, Py_ssize_t s, int readonly): + """Create a Python Buffer/View of a C array. + + Parameters + ---------- + ptr : void * + Pointer to the array to be copied. + s : size_t + Length of the buffer. + readonly : int + whether the resulting object should be allowed to write to the buffer. + + Returns + ------- + Python Buffer/View of the C buffer. + """ + # oldstyle first priority for now + if oldstyle_available(): + return frombuffer_2(ptr, s, readonly) + else: + return frombuffer_3(ptr, s, readonly) + + +cdef inline object frombuffer_r(void *ptr, Py_ssize_t s): + """Wrapper for readonly view frombuffer.""" + return frombuffer(ptr, s, 1) + + +cdef inline object frombuffer_w(void *ptr, Py_ssize_t s): + """Wrapper for writable view frombuffer.""" + return frombuffer(ptr, s, 0) + +#------------------------------------------------------------------------------ +# viewfromobject: python buffer/view from python object, refcounts intact +# frombuffer(asbuffer(obj)) would lose track of refs +#------------------------------------------------------------------------------ + +cdef inline object viewfromobject(object obj, int readonly): + """Construct a Python Buffer/View object from another Python object. + + This work in a Python version independent manner. + + Parameters + ---------- + obj : object + The input object to be cast as a buffer + readonly : int + Whether the result should be prevented from overwriting the original. + + Returns + ------- + Buffer/View of the original object. + """ + if not memoryview_available(): + if readonly: + return PyBuffer_FromObject(obj, 0, Py_END_OF_BUFFER) + else: + return PyBuffer_FromReadWriteObject(obj, 0, Py_END_OF_BUFFER) + else: + return PyMemoryView_FromObject(obj) + + +cdef inline object viewfromobject_r(object obj): + """Wrapper for readonly viewfromobject.""" + return viewfromobject(obj, 1) + + +cdef inline object viewfromobject_w(object obj): + """Wrapper for writable viewfromobject.""" + return viewfromobject(obj, 0) diff --git a/src/console/zmq/utils/compiler.json b/src/console/zmq/utils/compiler.json new file mode 100644 index 00000000..e58fc130 --- /dev/null +++ b/src/console/zmq/utils/compiler.json @@ -0,0 +1,24 @@ +{ + "extra_link_args": [], + "define_macros": [ + [ + "HAVE_SYS_UN_H", + 1 + ] + ], + "runtime_library_dirs": [ + "$ORIGIN/.." + ], + "libraries": [ + "zmq" + ], + "library_dirs": [ + "zmq" + ], + "include_dirs": [ + "/auto/srg-sce-swinfra-usr/emb/users/hhaim/work/depot/asr1k/emb/private/bpsim/main/src/zmq/include", + "zmq/utils", + "zmq/backend/cython", + "zmq/devices" + ] +}
\ No newline at end of file diff --git a/src/console/zmq/utils/config.json b/src/console/zmq/utils/config.json new file mode 100644 index 00000000..1e4611f9 --- /dev/null +++ b/src/console/zmq/utils/config.json @@ -0,0 +1,10 @@ +{ + "have_sys_un_h": true, + "zmq_prefix": "/auto/srg-sce-swinfra-usr/emb/users/hhaim/work/depot/asr1k/emb/private/bpsim/main/src/zmq", + "no_libzmq_extension": true, + "libzmq_extension": false, + "easy_install": {}, + "bdist_egg": {}, + "skip_check_zmq": false, + "build_ext": {} +}
\ No newline at end of file diff --git a/src/console/zmq/utils/constant_names.py b/src/console/zmq/utils/constant_names.py new file mode 100755 index 00000000..47da9dc2 --- /dev/null +++ b/src/console/zmq/utils/constant_names.py @@ -0,0 +1,365 @@ +"""0MQ Constant names""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +# dictionaries of constants new or removed in particular versions + +new_in = { + (2,2,0) : [ + 'RCVTIMEO', + 'SNDTIMEO', + ], + (3,2,2) : [ + # errnos + 'EMSGSIZE', + 'EAFNOSUPPORT', + 'ENETUNREACH', + 'ECONNABORTED', + 'ECONNRESET', + 'ENOTCONN', + 'ETIMEDOUT', + 'EHOSTUNREACH', + 'ENETRESET', + + # ctx opts + 'IO_THREADS', + 'MAX_SOCKETS', + 'IO_THREADS_DFLT', + 'MAX_SOCKETS_DFLT', + + # socket opts + 'ROUTER_BEHAVIOR', + 'ROUTER_MANDATORY', + 'FAIL_UNROUTABLE', + 'TCP_KEEPALIVE', + 'TCP_KEEPALIVE_CNT', + 'TCP_KEEPALIVE_IDLE', + 'TCP_KEEPALIVE_INTVL', + 'DELAY_ATTACH_ON_CONNECT', + 'XPUB_VERBOSE', + + # msg opts + 'MORE', + + 'EVENT_CONNECTED', + 'EVENT_CONNECT_DELAYED', + 'EVENT_CONNECT_RETRIED', + 'EVENT_LISTENING', + 'EVENT_BIND_FAILED', + 'EVENT_ACCEPTED', + 'EVENT_ACCEPT_FAILED', + 'EVENT_CLOSED', + 'EVENT_CLOSE_FAILED', + 'EVENT_DISCONNECTED', + 'EVENT_ALL', + ], + (4,0,0) : [ + # socket types + 'STREAM', + + # socket opts + 'IMMEDIATE', + 'ROUTER_RAW', + 'IPV6', + 'MECHANISM', + 'PLAIN_SERVER', + 'PLAIN_USERNAME', + 'PLAIN_PASSWORD', + 'CURVE_SERVER', + 'CURVE_PUBLICKEY', + 'CURVE_SECRETKEY', + 'CURVE_SERVERKEY', + 'PROBE_ROUTER', + 'REQ_RELAXED', + 'REQ_CORRELATE', + 'CONFLATE', + 'ZAP_DOMAIN', + + # security + 'NULL', + 'PLAIN', + 'CURVE', + + # events + 'EVENT_MONITOR_STOPPED', + ], + (4,1,0) : [ + # ctx opts + 'SOCKET_LIMIT', + 'THREAD_PRIORITY', + 'THREAD_PRIORITY_DFLT', + 'THREAD_SCHED_POLICY', + 'THREAD_SCHED_POLICY_DFLT', + + # socket opts + 'ROUTER_HANDOVER', + 'TOS', + 'IPC_FILTER_PID', + 'IPC_FILTER_UID', + 'IPC_FILTER_GID', + 'CONNECT_RID', + 'GSSAPI_SERVER', + 'GSSAPI_PRINCIPAL', + 'GSSAPI_SERVICE_PRINCIPAL', + 'GSSAPI_PLAINTEXT', + 'HANDSHAKE_IVL', + 'IDENTITY_FD', + 'XPUB_NODROP', + 'SOCKS_PROXY', + + # msg opts + 'SRCFD', + 'SHARED', + + # security + 'GSSAPI', + + ], +} + + +removed_in = { + (3,2,2) : [ + 'UPSTREAM', + 'DOWNSTREAM', + + 'HWM', + 'SWAP', + 'MCAST_LOOP', + 'RECOVERY_IVL_MSEC', + ] +} + +# collections of zmq constant names based on their role +# base names have no specific use +# opt names are validated in get/set methods of various objects + +base_names = [ + # base + 'VERSION', + 'VERSION_MAJOR', + 'VERSION_MINOR', + 'VERSION_PATCH', + 'NOBLOCK', + 'DONTWAIT', + + 'POLLIN', + 'POLLOUT', + 'POLLERR', + + 'SNDMORE', + + 'STREAMER', + 'FORWARDER', + 'QUEUE', + + 'IO_THREADS_DFLT', + 'MAX_SOCKETS_DFLT', + 'POLLITEMS_DFLT', + 'THREAD_PRIORITY_DFLT', + 'THREAD_SCHED_POLICY_DFLT', + + # socktypes + 'PAIR', + 'PUB', + 'SUB', + 'REQ', + 'REP', + 'DEALER', + 'ROUTER', + 'XREQ', + 'XREP', + 'PULL', + 'PUSH', + 'XPUB', + 'XSUB', + 'UPSTREAM', + 'DOWNSTREAM', + 'STREAM', + + # events + 'EVENT_CONNECTED', + 'EVENT_CONNECT_DELAYED', + 'EVENT_CONNECT_RETRIED', + 'EVENT_LISTENING', + 'EVENT_BIND_FAILED', + 'EVENT_ACCEPTED', + 'EVENT_ACCEPT_FAILED', + 'EVENT_CLOSED', + 'EVENT_CLOSE_FAILED', + 'EVENT_DISCONNECTED', + 'EVENT_ALL', + 'EVENT_MONITOR_STOPPED', + + # security + 'NULL', + 'PLAIN', + 'CURVE', + 'GSSAPI', + + ## ERRNO + # Often used (these are alse in errno.) + 'EAGAIN', + 'EINVAL', + 'EFAULT', + 'ENOMEM', + 'ENODEV', + 'EMSGSIZE', + 'EAFNOSUPPORT', + 'ENETUNREACH', + 'ECONNABORTED', + 'ECONNRESET', + 'ENOTCONN', + 'ETIMEDOUT', + 'EHOSTUNREACH', + 'ENETRESET', + + # For Windows compatability + 'HAUSNUMERO', + 'ENOTSUP', + 'EPROTONOSUPPORT', + 'ENOBUFS', + 'ENETDOWN', + 'EADDRINUSE', + 'EADDRNOTAVAIL', + 'ECONNREFUSED', + 'EINPROGRESS', + 'ENOTSOCK', + + # 0MQ Native + 'EFSM', + 'ENOCOMPATPROTO', + 'ETERM', + 'EMTHREAD', +] + +int64_sockopt_names = [ + 'AFFINITY', + 'MAXMSGSIZE', + + # sockopts removed in 3.0.0 + 'HWM', + 'SWAP', + 'MCAST_LOOP', + 'RECOVERY_IVL_MSEC', +] + +bytes_sockopt_names = [ + 'IDENTITY', + 'SUBSCRIBE', + 'UNSUBSCRIBE', + 'LAST_ENDPOINT', + 'TCP_ACCEPT_FILTER', + + 'PLAIN_USERNAME', + 'PLAIN_PASSWORD', + + 'CURVE_PUBLICKEY', + 'CURVE_SECRETKEY', + 'CURVE_SERVERKEY', + 'ZAP_DOMAIN', + 'CONNECT_RID', + 'GSSAPI_PRINCIPAL', + 'GSSAPI_SERVICE_PRINCIPAL', + 'SOCKS_PROXY', +] + +fd_sockopt_names = [ + 'FD', + 'IDENTITY_FD', +] + +int_sockopt_names = [ + # sockopts + 'RECONNECT_IVL_MAX', + + # sockopts new in 2.2.0 + 'SNDTIMEO', + 'RCVTIMEO', + + # new in 3.x + 'SNDHWM', + 'RCVHWM', + 'MULTICAST_HOPS', + 'IPV4ONLY', + + 'ROUTER_BEHAVIOR', + 'TCP_KEEPALIVE', + 'TCP_KEEPALIVE_CNT', + 'TCP_KEEPALIVE_IDLE', + 'TCP_KEEPALIVE_INTVL', + 'DELAY_ATTACH_ON_CONNECT', + 'XPUB_VERBOSE', + + 'EVENTS', + 'TYPE', + 'LINGER', + 'RECONNECT_IVL', + 'BACKLOG', + + 'ROUTER_MANDATORY', + 'FAIL_UNROUTABLE', + + 'ROUTER_RAW', + 'IMMEDIATE', + 'IPV6', + 'MECHANISM', + 'PLAIN_SERVER', + 'CURVE_SERVER', + 'PROBE_ROUTER', + 'REQ_RELAXED', + 'REQ_CORRELATE', + 'CONFLATE', + 'ROUTER_HANDOVER', + 'TOS', + 'IPC_FILTER_PID', + 'IPC_FILTER_UID', + 'IPC_FILTER_GID', + 'GSSAPI_SERVER', + 'GSSAPI_PLAINTEXT', + 'HANDSHAKE_IVL', + 'XPUB_NODROP', +] + +switched_sockopt_names = [ + 'RATE', + 'RECOVERY_IVL', + 'SNDBUF', + 'RCVBUF', + 'RCVMORE', +] + +ctx_opt_names = [ + 'IO_THREADS', + 'MAX_SOCKETS', + 'SOCKET_LIMIT', + 'THREAD_PRIORITY', + 'THREAD_SCHED_POLICY', +] + +msg_opt_names = [ + 'MORE', + 'SRCFD', + 'SHARED', +] + +from itertools import chain + +all_names = list(chain( + base_names, + ctx_opt_names, + bytes_sockopt_names, + fd_sockopt_names, + int_sockopt_names, + int64_sockopt_names, + switched_sockopt_names, + msg_opt_names, +)) + +del chain + +def no_prefix(name): + """does the given constant have a ZMQ_ prefix?""" + return name.startswith('E') and not name.startswith('EVENT') + diff --git a/src/console/zmq/utils/garbage.py b/src/console/zmq/utils/garbage.py new file mode 100755 index 00000000..80a8725a --- /dev/null +++ b/src/console/zmq/utils/garbage.py @@ -0,0 +1,180 @@ +"""Garbage collection thread for representing zmq refcount of Python objects +used in zero-copy sends. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import atexit +import struct + +from os import getpid +from collections import namedtuple +from threading import Thread, Event, Lock +import warnings + +import zmq + + +gcref = namedtuple('gcref', ['obj', 'event']) + +class GarbageCollectorThread(Thread): + """Thread in which garbage collection actually happens.""" + def __init__(self, gc): + super(GarbageCollectorThread, self).__init__() + self.gc = gc + self.daemon = True + self.pid = getpid() + self.ready = Event() + + def run(self): + # detect fork at begining of the thread + if getpid is None or getpid() != self.pid: + self.ready.set() + return + try: + s = self.gc.context.socket(zmq.PULL) + s.linger = 0 + s.bind(self.gc.url) + finally: + self.ready.set() + + while True: + # detect fork + if getpid is None or getpid() != self.pid: + return + msg = s.recv() + if msg == b'DIE': + break + fmt = 'L' if len(msg) == 4 else 'Q' + key = struct.unpack(fmt, msg)[0] + tup = self.gc.refs.pop(key, None) + if tup and tup.event: + tup.event.set() + del tup + s.close() + + +class GarbageCollector(object): + """PyZMQ Garbage Collector + + Used for representing the reference held by libzmq during zero-copy sends. + This object holds a dictionary, keyed by Python id, + of the Python objects whose memory are currently in use by zeromq. + + When zeromq is done with the memory, it sends a message on an inproc PUSH socket + containing the packed size_t (32 or 64-bit unsigned int), + which is the key in the dict. + When the PULL socket in the gc thread receives that message, + the reference is popped from the dict, + and any tracker events that should be signaled fire. + """ + + refs = None + _context = None + _lock = None + url = "inproc://pyzmq.gc.01" + + def __init__(self, context=None): + super(GarbageCollector, self).__init__() + self.refs = {} + self.pid = None + self.thread = None + self._context = context + self._lock = Lock() + self._stay_down = False + atexit.register(self._atexit) + + @property + def context(self): + if self._context is None: + self._context = zmq.Context() + return self._context + + @context.setter + def context(self, ctx): + if self.is_alive(): + if self.refs: + warnings.warn("Replacing gc context while gc is running", RuntimeWarning) + self.stop() + self._context = ctx + + def _atexit(self): + """atexit callback + + sets _stay_down flag so that gc doesn't try to start up again in other atexit handlers + """ + self._stay_down = True + self.stop() + + def stop(self): + """stop the garbage-collection thread""" + if not self.is_alive(): + return + self._stop() + + def _stop(self): + push = self.context.socket(zmq.PUSH) + push.connect(self.url) + push.send(b'DIE') + push.close() + self.thread.join() + self.context.term() + self.refs.clear() + self.context = None + + def start(self): + """Start a new garbage collection thread. + + Creates a new zmq Context used for garbage collection. + Under most circumstances, this will only be called once per process. + """ + if self.thread is not None and self.pid != getpid(): + # It's re-starting, must free earlier thread's context + # since a fork probably broke it + self._stop() + self.pid = getpid() + self.refs = {} + self.thread = GarbageCollectorThread(self) + self.thread.start() + self.thread.ready.wait() + + def is_alive(self): + """Is the garbage collection thread currently running? + + Includes checks for process shutdown or fork. + """ + if (getpid is None or + getpid() != self.pid or + self.thread is None or + not self.thread.is_alive() + ): + return False + return True + + def store(self, obj, event=None): + """store an object and (optionally) event for zero-copy""" + if not self.is_alive(): + if self._stay_down: + return 0 + # safely start the gc thread + # use lock and double check, + # so we don't start multiple threads + with self._lock: + if not self.is_alive(): + self.start() + tup = gcref(obj, event) + theid = id(tup) + self.refs[theid] = tup + return theid + + def __del__(self): + if not self.is_alive(): + return + try: + self.stop() + except Exception as e: + raise (e) + +gc = GarbageCollector() diff --git a/src/console/zmq/utils/getpid_compat.h b/src/console/zmq/utils/getpid_compat.h new file mode 100644 index 00000000..47ce90fa --- /dev/null +++ b/src/console/zmq/utils/getpid_compat.h @@ -0,0 +1,6 @@ +#ifdef _WIN32 + #include <process.h> + #define getpid _getpid +#else + #include <unistd.h> +#endif diff --git a/src/console/zmq/utils/interop.py b/src/console/zmq/utils/interop.py new file mode 100755 index 00000000..26c01969 --- /dev/null +++ b/src/console/zmq/utils/interop.py @@ -0,0 +1,33 @@ +"""Utils for interoperability with other libraries. + +Just CFFI pointer casting for now. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +try: + long +except NameError: + long = int # Python 3 + + +def cast_int_addr(n): + """Cast an address to a Python int + + This could be a Python integer or a CFFI pointer + """ + if isinstance(n, (int, long)): + return n + try: + import cffi + except ImportError: + pass + else: + # from pyzmq, this is an FFI void * + ffi = cffi.FFI() + if isinstance(n, ffi.CData): + return int(ffi.cast("size_t", n)) + + raise ValueError("Cannot cast %r to int" % n) diff --git a/src/console/zmq/utils/ipcmaxlen.h b/src/console/zmq/utils/ipcmaxlen.h new file mode 100644 index 00000000..7218db78 --- /dev/null +++ b/src/console/zmq/utils/ipcmaxlen.h @@ -0,0 +1,21 @@ +/* + +Platform-independant detection of IPC path max length + +Copyright (c) 2012 Godefroid Chapelle + +Distributed under the terms of the New BSD License. The full license is in +the file COPYING.BSD, distributed as part of this software. + */ + +#if defined(HAVE_SYS_UN_H) +#include "sys/un.h" +int get_ipc_path_max_len(void) { + struct sockaddr_un *dummy; + return sizeof(dummy->sun_path) - 1; +} +#else +int get_ipc_path_max_len(void) { + return 0; +} +#endif diff --git a/src/console/zmq/utils/jsonapi.py b/src/console/zmq/utils/jsonapi.py new file mode 100755 index 00000000..865ca6d5 --- /dev/null +++ b/src/console/zmq/utils/jsonapi.py @@ -0,0 +1,59 @@ +"""Priority based json library imports. + +Always serializes to bytes instead of unicode for zeromq compatibility +on Python 2 and 3. + +Use ``jsonapi.loads()`` and ``jsonapi.dumps()`` for guaranteed symmetry. + +Priority: ``simplejson`` > ``jsonlib2`` > stdlib ``json`` + +``jsonapi.loads/dumps`` provide kwarg-compatibility with stdlib json. + +``jsonapi.jsonmod`` will be the module of the actual underlying implementation. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.utils.strtypes import bytes, unicode + +jsonmod = None + +priority = ['simplejson', 'jsonlib2', 'json'] +for mod in priority: + try: + jsonmod = __import__(mod) + except ImportError: + pass + else: + break + +def dumps(o, **kwargs): + """Serialize object to JSON bytes (utf-8). + + See jsonapi.jsonmod.dumps for details on kwargs. + """ + + if 'separators' not in kwargs: + kwargs['separators'] = (',', ':') + + s = jsonmod.dumps(o, **kwargs) + + if isinstance(s, unicode): + s = s.encode('utf8') + + return s + +def loads(s, **kwargs): + """Load object from JSON bytes (utf-8). + + See jsonapi.jsonmod.loads for details on kwargs. + """ + + if str is unicode and isinstance(s, bytes): + s = s.decode('utf8') + + return jsonmod.loads(s, **kwargs) + +__all__ = ['jsonmod', 'dumps', 'loads'] + diff --git a/src/console/zmq/utils/monitor.py b/src/console/zmq/utils/monitor.py new file mode 100755 index 00000000..734d54b1 --- /dev/null +++ b/src/console/zmq/utils/monitor.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +"""Module holding utility and convenience functions for zmq event monitoring.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import struct +import zmq +from zmq.error import _check_version + +def parse_monitor_message(msg): + """decode zmq_monitor event messages. + + Parameters + ---------- + msg : list(bytes) + zmq multipart message that has arrived on a monitor PAIR socket. + + First frame is:: + + 16 bit event id + 32 bit event value + no padding + + Second frame is the endpoint as a bytestring + + Returns + ------- + event : dict + event description as dict with the keys `event`, `value`, and `endpoint`. + """ + + if len(msg) != 2 or len(msg[0]) != 6: + raise RuntimeError("Invalid event message format: %s" % msg) + event = {} + event['event'], event['value'] = struct.unpack("=hi", msg[0]) + event['endpoint'] = msg[1] + return event + +def recv_monitor_message(socket, flags=0): + """Receive and decode the given raw message from the monitoring socket and return a dict. + + Requires libzmq ≥ 4.0 + + The returned dict will have the following entries: + event : int, the event id as described in libzmq.zmq_socket_monitor + value : int, the event value associated with the event, see libzmq.zmq_socket_monitor + endpoint : string, the affected endpoint + + Parameters + ---------- + socket : zmq PAIR socket + The PAIR socket (created by other.get_monitor_socket()) on which to recv the message + flags : bitfield (int) + standard zmq recv flags + + Returns + ------- + event : dict + event description as dict with the keys `event`, `value`, and `endpoint`. + """ + _check_version((4,0), 'libzmq event API') + # will always return a list + msg = socket.recv_multipart(flags) + # 4.0-style event API + return parse_monitor_message(msg) + +__all__ = ['parse_monitor_message', 'recv_monitor_message'] diff --git a/src/console/zmq/utils/pyversion_compat.h b/src/console/zmq/utils/pyversion_compat.h new file mode 100644 index 00000000..fac09046 --- /dev/null +++ b/src/console/zmq/utils/pyversion_compat.h @@ -0,0 +1,25 @@ +#include "Python.h" + +#if PY_VERSION_HEX < 0x02070000 + #define PyMemoryView_FromBuffer(info) (PyErr_SetString(PyExc_NotImplementedError, \ + "new buffer interface is not available"), (PyObject *)NULL) + #define PyMemoryView_FromObject(object) (PyErr_SetString(PyExc_NotImplementedError, \ + "new buffer interface is not available"), (PyObject *)NULL) +#endif + +#if PY_VERSION_HEX >= 0x03000000 + // for buffers + #define Py_END_OF_BUFFER ((Py_ssize_t) 0) + + #define PyObject_CheckReadBuffer(object) (0) + + #define PyBuffer_FromMemory(ptr, s) (PyErr_SetString(PyExc_NotImplementedError, \ + "old buffer interface is not available"), (PyObject *)NULL) + #define PyBuffer_FromReadWriteMemory(ptr, s) (PyErr_SetString(PyExc_NotImplementedError, \ + "old buffer interface is not available"), (PyObject *)NULL) + #define PyBuffer_FromObject(object, offset, size) (PyErr_SetString(PyExc_NotImplementedError, \ + "old buffer interface is not available"), (PyObject *)NULL) + #define PyBuffer_FromReadWriteObject(object, offset, size) (PyErr_SetString(PyExc_NotImplementedError, \ + "old buffer interface is not available"), (PyObject *)NULL) + +#endif diff --git a/src/console/zmq/utils/sixcerpt.py b/src/console/zmq/utils/sixcerpt.py new file mode 100755 index 00000000..5492fd59 --- /dev/null +++ b/src/console/zmq/utils/sixcerpt.py @@ -0,0 +1,52 @@ +"""Excerpts of six.py""" + +# Copyright (C) 2010-2014 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import sys + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 + +if PY3: + + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + +else: + def exec_(_code_, _globs_=None, _locs_=None): + """Execute code in a namespace.""" + if _globs_ is None: + frame = sys._getframe(1) + _globs_ = frame.f_globals + if _locs_ is None: + _locs_ = frame.f_locals + del frame + elif _locs_ is None: + _locs_ = _globs_ + exec("""exec _code_ in _globs_, _locs_""") + + + exec_("""def reraise(tp, value, tb=None): + raise tp, value, tb +""") diff --git a/src/console/zmq/utils/strtypes.py b/src/console/zmq/utils/strtypes.py new file mode 100755 index 00000000..548410dc --- /dev/null +++ b/src/console/zmq/utils/strtypes.py @@ -0,0 +1,45 @@ +"""Declare basic string types unambiguously for various Python versions. + +Authors +------- +* MinRK +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys + +if sys.version_info[0] >= 3: + bytes = bytes + unicode = str + basestring = (bytes, unicode) +else: + unicode = unicode + bytes = str + basestring = basestring + +def cast_bytes(s, encoding='utf8', errors='strict'): + """cast unicode or bytes to bytes""" + if isinstance(s, bytes): + return s + elif isinstance(s, unicode): + return s.encode(encoding, errors) + else: + raise TypeError("Expected unicode or bytes, got %r" % s) + +def cast_unicode(s, encoding='utf8', errors='strict'): + """cast bytes or unicode to unicode""" + if isinstance(s, bytes): + return s.decode(encoding, errors) + elif isinstance(s, unicode): + return s + else: + raise TypeError("Expected unicode or bytes, got %r" % s) + +# give short 'b' alias for cast_bytes, so that we can use fake b('stuff') +# to simulate b'stuff' +b = asbytes = cast_bytes +u = cast_unicode + +__all__ = ['asbytes', 'bytes', 'unicode', 'basestring', 'b', 'u', 'cast_bytes', 'cast_unicode'] diff --git a/src/console/zmq/utils/win32.py b/src/console/zmq/utils/win32.py new file mode 100755 index 00000000..ea758299 --- /dev/null +++ b/src/console/zmq/utils/win32.py @@ -0,0 +1,132 @@ +"""Win32 compatibility utilities.""" + +#----------------------------------------------------------------------------- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +#----------------------------------------------------------------------------- + +import os + +# No-op implementation for other platforms. +class _allow_interrupt(object): + """Utility for fixing CTRL-C events on Windows. + + On Windows, the Python interpreter intercepts CTRL-C events in order to + translate them into ``KeyboardInterrupt`` exceptions. It (presumably) + does this by setting a flag in its "control control handler" and + checking it later at a convenient location in the interpreter. + + However, when the Python interpreter is blocked waiting for the ZMQ + poll operation to complete, it must wait for ZMQ's ``select()`` + operation to complete before translating the CTRL-C event into the + ``KeyboardInterrupt`` exception. + + The only way to fix this seems to be to add our own "console control + handler" and perform some application-defined operation that will + unblock the ZMQ polling operation in order to force ZMQ to pass control + back to the Python interpreter. + + This context manager performs all that Windows-y stuff, providing you + with a hook that is called when a CTRL-C event is intercepted. This + hook allows you to unblock your ZMQ poll operation immediately, which + will then result in the expected ``KeyboardInterrupt`` exception. + + Without this context manager, your ZMQ-based application will not + respond normally to CTRL-C events on Windows. If a CTRL-C event occurs + while blocked on ZMQ socket polling, the translation to a + ``KeyboardInterrupt`` exception will be delayed until the I/O completes + and control returns to the Python interpreter (this may never happen if + you use an infinite timeout). + + A no-op implementation is provided on non-Win32 systems to avoid the + application from having to conditionally use it. + + Example usage: + + .. sourcecode:: python + + def stop_my_application(): + # ... + + with allow_interrupt(stop_my_application): + # main polling loop. + + In a typical ZMQ application, you would use the "self pipe trick" to + send message to a ``PAIR`` socket in order to interrupt your blocking + socket polling operation. + + In a Tornado event loop, you can use the ``IOLoop.stop`` method to + unblock your I/O loop. + """ + + def __init__(self, action=None): + """Translate ``action`` into a CTRL-C handler. + + ``action`` is a callable that takes no arguments and returns no + value (returned value is ignored). It must *NEVER* raise an + exception. + + If unspecified, a no-op will be used. + """ + self._init_action(action) + + def _init_action(self, action): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + return + +if os.name == 'nt': + from ctypes import WINFUNCTYPE, windll + from ctypes.wintypes import BOOL, DWORD + + kernel32 = windll.LoadLibrary('kernel32') + + # <http://msdn.microsoft.com/en-us/library/ms686016.aspx> + PHANDLER_ROUTINE = WINFUNCTYPE(BOOL, DWORD) + SetConsoleCtrlHandler = kernel32.SetConsoleCtrlHandler + SetConsoleCtrlHandler.argtypes = (PHANDLER_ROUTINE, BOOL) + SetConsoleCtrlHandler.restype = BOOL + + class allow_interrupt(_allow_interrupt): + __doc__ = _allow_interrupt.__doc__ + + def _init_action(self, action): + if action is None: + action = lambda: None + self.action = action + @PHANDLER_ROUTINE + def handle(event): + if event == 0: # CTRL_C_EVENT + action() + # Typical C implementations would return 1 to indicate that + # the event was processed and other control handlers in the + # stack should not be executed. However, that would + # prevent the Python interpreter's handler from translating + # CTRL-C to a `KeyboardInterrupt` exception, so we pretend + # that we didn't handle it. + return 0 + self.handle = handle + + def __enter__(self): + """Install the custom CTRL-C handler.""" + result = SetConsoleCtrlHandler(self.handle, 1) + if result == 0: + # Have standard library automatically call `GetLastError()` and + # `FormatMessage()` into a nice exception object :-) + raise WindowsError() + + def __exit__(self, *args): + """Remove the custom CTRL-C handler.""" + result = SetConsoleCtrlHandler(self.handle, 0) + if result == 0: + # Have standard library automatically call `GetLastError()` and + # `FormatMessage()` into a nice exception object :-) + raise WindowsError() +else: + class allow_interrupt(_allow_interrupt): + __doc__ = _allow_interrupt.__doc__ + pass diff --git a/src/console/zmq/utils/z85.py b/src/console/zmq/utils/z85.py new file mode 100755 index 00000000..1bb1784e --- /dev/null +++ b/src/console/zmq/utils/z85.py @@ -0,0 +1,56 @@ +"""Python implementation of Z85 85-bit encoding + +Z85 encoding is a plaintext encoding for a bytestring interpreted as 32bit integers. +Since the chunks are 32bit, a bytestring must be a multiple of 4 bytes. +See ZMQ RFC 32 for details. + + +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys +import struct + +PY3 = sys.version_info[0] >= 3 +# Z85CHARS is the base 85 symbol table +Z85CHARS = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.-:+=^!/*?&<>()[]{}@%$#" +# Z85MAP maps integers in [0,84] to the appropriate character in Z85CHARS +Z85MAP = dict([(c, idx) for idx, c in enumerate(Z85CHARS)]) + +_85s = [ 85**i for i in range(5) ][::-1] + +def encode(rawbytes): + """encode raw bytes into Z85""" + # Accepts only byte arrays bounded to 4 bytes + if len(rawbytes) % 4: + raise ValueError("length must be multiple of 4, not %i" % len(rawbytes)) + + nvalues = len(rawbytes) / 4 + + values = struct.unpack('>%dI' % nvalues, rawbytes) + encoded = [] + for v in values: + for offset in _85s: + encoded.append(Z85CHARS[(v // offset) % 85]) + + # In Python 3, encoded is a list of integers (obviously?!) + if PY3: + return bytes(encoded) + else: + return b''.join(encoded) + +def decode(z85bytes): + """decode Z85 bytes to raw bytes""" + if len(z85bytes) % 5: + raise ValueError("Z85 length must be multiple of 5, not %i" % len(z85bytes)) + + nvalues = len(z85bytes) / 5 + values = [] + for i in range(0, len(z85bytes), 5): + value = 0 + for j, offset in enumerate(_85s): + value += Z85MAP[z85bytes[i+j]] * offset + values.append(value) + return struct.pack('>%dI' % nvalues, *values) diff --git a/src/console/zmq/utils/zmq_compat.h b/src/console/zmq/utils/zmq_compat.h new file mode 100644 index 00000000..81c57b69 --- /dev/null +++ b/src/console/zmq/utils/zmq_compat.h @@ -0,0 +1,80 @@ +//----------------------------------------------------------------------------- +// Copyright (c) 2010 Brian Granger, Min Ragan-Kelley +// +// Distributed under the terms of the New BSD License. The full license is in +// the file COPYING.BSD, distributed as part of this software. +//----------------------------------------------------------------------------- + +#if defined(_MSC_VER) +#define pyzmq_int64_t __int64 +#else +#include <stdint.h> +#define pyzmq_int64_t int64_t +#endif + + +#include "zmq.h" +// version compatibility for constants: +#include "zmq_constants.h" + +#define _missing (-1) + + +// define fd type (from libzmq's fd.hpp) +#ifdef _WIN32 + #ifdef _MSC_VER && _MSC_VER <= 1400 + #define ZMQ_FD_T UINT_PTR + #else + #define ZMQ_FD_T SOCKET + #endif +#else + #define ZMQ_FD_T int +#endif + +// use unambiguous aliases for zmq_send/recv functions + +#if ZMQ_VERSION_MAJOR >= 4 +// nothing to remove +#else + #define zmq_curve_keypair(z85_public_key, z85_secret_key) _missing +#endif + +#if ZMQ_VERSION_MAJOR >= 4 && ZMQ_VERSION_MINOR >= 1 +// nothing to remove +#else + #define zmq_msg_gets(msg, prop) _missing + #define zmq_has(capability) _missing +#endif + +#if ZMQ_VERSION_MAJOR >= 3 + #define zmq_sendbuf zmq_send + #define zmq_recvbuf zmq_recv + + // 3.x deprecations - these symbols haven't been removed, + // but let's protect against their planned removal + #define zmq_device(device_type, isocket, osocket) _missing + #define zmq_init(io_threads) ((void*)NULL) + #define zmq_term zmq_ctx_destroy +#else + #define zmq_ctx_set(ctx, opt, val) _missing + #define zmq_ctx_get(ctx, opt) _missing + #define zmq_ctx_destroy zmq_term + #define zmq_ctx_new() ((void*)NULL) + + #define zmq_proxy(a,b,c) _missing + + #define zmq_disconnect(s, addr) _missing + #define zmq_unbind(s, addr) _missing + + #define zmq_msg_more(msg) _missing + #define zmq_msg_get(msg, opt) _missing + #define zmq_msg_set(msg, opt, val) _missing + #define zmq_msg_send(msg, s, flags) zmq_send(s, msg, flags) + #define zmq_msg_recv(msg, s, flags) zmq_recv(s, msg, flags) + + #define zmq_sendbuf(s, buf, len, flags) _missing + #define zmq_recvbuf(s, buf, len, flags) _missing + + #define zmq_socket_monitor(s, addr, flags) _missing + +#endif diff --git a/src/console/zmq/utils/zmq_constants.h b/src/console/zmq/utils/zmq_constants.h new file mode 100644 index 00000000..97683022 --- /dev/null +++ b/src/console/zmq/utils/zmq_constants.h @@ -0,0 +1,622 @@ +#ifndef _PYZMQ_CONSTANT_DEFS +#define _PYZMQ_CONSTANT_DEFS + +#define _PYZMQ_UNDEFINED (-9999) +#ifndef ZMQ_VERSION + #define ZMQ_VERSION (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_VERSION_MAJOR + #define ZMQ_VERSION_MAJOR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_VERSION_MINOR + #define ZMQ_VERSION_MINOR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_VERSION_PATCH + #define ZMQ_VERSION_PATCH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_NOBLOCK + #define ZMQ_NOBLOCK (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DONTWAIT + #define ZMQ_DONTWAIT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLIN + #define ZMQ_POLLIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLOUT + #define ZMQ_POLLOUT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLERR + #define ZMQ_POLLERR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDMORE + #define ZMQ_SNDMORE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_STREAMER + #define ZMQ_STREAMER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_FORWARDER + #define ZMQ_FORWARDER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_QUEUE + #define ZMQ_QUEUE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IO_THREADS_DFLT + #define ZMQ_IO_THREADS_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MAX_SOCKETS_DFLT + #define ZMQ_MAX_SOCKETS_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLITEMS_DFLT + #define ZMQ_POLLITEMS_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_PRIORITY_DFLT + #define ZMQ_THREAD_PRIORITY_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_SCHED_POLICY_DFLT + #define ZMQ_THREAD_SCHED_POLICY_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PAIR + #define ZMQ_PAIR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PUB + #define ZMQ_PUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SUB + #define ZMQ_SUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REQ + #define ZMQ_REQ (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REP + #define ZMQ_REP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DEALER + #define ZMQ_DEALER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER + #define ZMQ_ROUTER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XREQ + #define ZMQ_XREQ (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XREP + #define ZMQ_XREP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PULL + #define ZMQ_PULL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PUSH + #define ZMQ_PUSH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XPUB + #define ZMQ_XPUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XSUB + #define ZMQ_XSUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_UPSTREAM + #define ZMQ_UPSTREAM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DOWNSTREAM + #define ZMQ_DOWNSTREAM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_STREAM + #define ZMQ_STREAM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CONNECTED + #define ZMQ_EVENT_CONNECTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CONNECT_DELAYED + #define ZMQ_EVENT_CONNECT_DELAYED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CONNECT_RETRIED + #define ZMQ_EVENT_CONNECT_RETRIED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_LISTENING + #define ZMQ_EVENT_LISTENING (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_BIND_FAILED + #define ZMQ_EVENT_BIND_FAILED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_ACCEPTED + #define ZMQ_EVENT_ACCEPTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_ACCEPT_FAILED + #define ZMQ_EVENT_ACCEPT_FAILED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CLOSED + #define ZMQ_EVENT_CLOSED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CLOSE_FAILED + #define ZMQ_EVENT_CLOSE_FAILED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_DISCONNECTED + #define ZMQ_EVENT_DISCONNECTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_ALL + #define ZMQ_EVENT_ALL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_MONITOR_STOPPED + #define ZMQ_EVENT_MONITOR_STOPPED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_NULL + #define ZMQ_NULL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN + #define ZMQ_PLAIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE + #define ZMQ_CURVE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI + #define ZMQ_GSSAPI (_PYZMQ_UNDEFINED) +#endif + +#ifndef EAGAIN + #define EAGAIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef EINVAL + #define EINVAL (_PYZMQ_UNDEFINED) +#endif + +#ifndef EFAULT + #define EFAULT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOMEM + #define ENOMEM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENODEV + #define ENODEV (_PYZMQ_UNDEFINED) +#endif + +#ifndef EMSGSIZE + #define EMSGSIZE (_PYZMQ_UNDEFINED) +#endif + +#ifndef EAFNOSUPPORT + #define EAFNOSUPPORT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENETUNREACH + #define ENETUNREACH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ECONNABORTED + #define ECONNABORTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ECONNRESET + #define ECONNRESET (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOTCONN + #define ENOTCONN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ETIMEDOUT + #define ETIMEDOUT (_PYZMQ_UNDEFINED) +#endif + +#ifndef EHOSTUNREACH + #define EHOSTUNREACH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENETRESET + #define ENETRESET (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_HAUSNUMERO + #define ZMQ_HAUSNUMERO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOTSUP + #define ENOTSUP (_PYZMQ_UNDEFINED) +#endif + +#ifndef EPROTONOSUPPORT + #define EPROTONOSUPPORT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOBUFS + #define ENOBUFS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENETDOWN + #define ENETDOWN (_PYZMQ_UNDEFINED) +#endif + +#ifndef EADDRINUSE + #define EADDRINUSE (_PYZMQ_UNDEFINED) +#endif + +#ifndef EADDRNOTAVAIL + #define EADDRNOTAVAIL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ECONNREFUSED + #define ECONNREFUSED (_PYZMQ_UNDEFINED) +#endif + +#ifndef EINPROGRESS + #define EINPROGRESS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOTSOCK + #define ENOTSOCK (_PYZMQ_UNDEFINED) +#endif + +#ifndef EFSM + #define EFSM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOCOMPATPROTO + #define ENOCOMPATPROTO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ETERM + #define ETERM (_PYZMQ_UNDEFINED) +#endif + +#ifndef EMTHREAD + #define EMTHREAD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IO_THREADS + #define ZMQ_IO_THREADS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MAX_SOCKETS + #define ZMQ_MAX_SOCKETS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SOCKET_LIMIT + #define ZMQ_SOCKET_LIMIT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_PRIORITY + #define ZMQ_THREAD_PRIORITY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_SCHED_POLICY + #define ZMQ_THREAD_SCHED_POLICY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IDENTITY + #define ZMQ_IDENTITY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SUBSCRIBE + #define ZMQ_SUBSCRIBE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_UNSUBSCRIBE + #define ZMQ_UNSUBSCRIBE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_LAST_ENDPOINT + #define ZMQ_LAST_ENDPOINT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_ACCEPT_FILTER + #define ZMQ_TCP_ACCEPT_FILTER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN_USERNAME + #define ZMQ_PLAIN_USERNAME (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN_PASSWORD + #define ZMQ_PLAIN_PASSWORD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_PUBLICKEY + #define ZMQ_CURVE_PUBLICKEY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_SECRETKEY + #define ZMQ_CURVE_SECRETKEY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_SERVERKEY + #define ZMQ_CURVE_SERVERKEY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ZAP_DOMAIN + #define ZMQ_ZAP_DOMAIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CONNECT_RID + #define ZMQ_CONNECT_RID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_PRINCIPAL + #define ZMQ_GSSAPI_PRINCIPAL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_SERVICE_PRINCIPAL + #define ZMQ_GSSAPI_SERVICE_PRINCIPAL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SOCKS_PROXY + #define ZMQ_SOCKS_PROXY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_FD + #define ZMQ_FD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IDENTITY_FD + #define ZMQ_IDENTITY_FD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECONNECT_IVL_MAX + #define ZMQ_RECONNECT_IVL_MAX (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDTIMEO + #define ZMQ_SNDTIMEO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVTIMEO + #define ZMQ_RCVTIMEO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDHWM + #define ZMQ_SNDHWM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVHWM + #define ZMQ_RCVHWM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MULTICAST_HOPS + #define ZMQ_MULTICAST_HOPS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPV4ONLY + #define ZMQ_IPV4ONLY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_BEHAVIOR + #define ZMQ_ROUTER_BEHAVIOR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE + #define ZMQ_TCP_KEEPALIVE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE_CNT + #define ZMQ_TCP_KEEPALIVE_CNT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE_IDLE + #define ZMQ_TCP_KEEPALIVE_IDLE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE_INTVL + #define ZMQ_TCP_KEEPALIVE_INTVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DELAY_ATTACH_ON_CONNECT + #define ZMQ_DELAY_ATTACH_ON_CONNECT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XPUB_VERBOSE + #define ZMQ_XPUB_VERBOSE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENTS + #define ZMQ_EVENTS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TYPE + #define ZMQ_TYPE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_LINGER + #define ZMQ_LINGER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECONNECT_IVL + #define ZMQ_RECONNECT_IVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_BACKLOG + #define ZMQ_BACKLOG (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_MANDATORY + #define ZMQ_ROUTER_MANDATORY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_FAIL_UNROUTABLE + #define ZMQ_FAIL_UNROUTABLE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_RAW + #define ZMQ_ROUTER_RAW (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IMMEDIATE + #define ZMQ_IMMEDIATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPV6 + #define ZMQ_IPV6 (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MECHANISM + #define ZMQ_MECHANISM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN_SERVER + #define ZMQ_PLAIN_SERVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_SERVER + #define ZMQ_CURVE_SERVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PROBE_ROUTER + #define ZMQ_PROBE_ROUTER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REQ_RELAXED + #define ZMQ_REQ_RELAXED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REQ_CORRELATE + #define ZMQ_REQ_CORRELATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CONFLATE + #define ZMQ_CONFLATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_HANDOVER + #define ZMQ_ROUTER_HANDOVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TOS + #define ZMQ_TOS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPC_FILTER_PID + #define ZMQ_IPC_FILTER_PID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPC_FILTER_UID + #define ZMQ_IPC_FILTER_UID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPC_FILTER_GID + #define ZMQ_IPC_FILTER_GID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_SERVER + #define ZMQ_GSSAPI_SERVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_PLAINTEXT + #define ZMQ_GSSAPI_PLAINTEXT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_HANDSHAKE_IVL + #define ZMQ_HANDSHAKE_IVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XPUB_NODROP + #define ZMQ_XPUB_NODROP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_AFFINITY + #define ZMQ_AFFINITY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MAXMSGSIZE + #define ZMQ_MAXMSGSIZE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_HWM + #define ZMQ_HWM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SWAP + #define ZMQ_SWAP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MCAST_LOOP + #define ZMQ_MCAST_LOOP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECOVERY_IVL_MSEC + #define ZMQ_RECOVERY_IVL_MSEC (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RATE + #define ZMQ_RATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECOVERY_IVL + #define ZMQ_RECOVERY_IVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDBUF + #define ZMQ_SNDBUF (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVBUF + #define ZMQ_RCVBUF (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVMORE + #define ZMQ_RCVMORE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MORE + #define ZMQ_MORE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SRCFD + #define ZMQ_SRCFD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SHARED + #define ZMQ_SHARED (_PYZMQ_UNDEFINED) +#endif + + +#endif // ifndef _PYZMQ_CONSTANT_DEFS |