# coding: utf-8
"""0MQ Socket pure Python methods."""

# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.


import codecs
import random
import warnings

import zmq
from zmq.backend import Socket as SocketBase
from .poll import Poller
from . import constants
from .attrsettr import AttributeSetter
from zmq.error import ZMQError, ZMQBindError
from zmq.utils import jsonapi
from zmq.utils.strtypes import bytes, unicode, basestring
from zmq.utils.interop import cast_int_addr

from .constants import (
    SNDMORE, ENOTSUP, POLLIN,
    int64_sockopt_names,
    int_sockopt_names,
    bytes_sockopt_names,
    fd_sockopt_names,
)

# Prefer the C-accelerated pickle when it exists (Python 2); only a failed
# import should trigger the fallback, so catch ImportError specifically.
try:
    import cPickle
    pickle = cPickle
except ImportError:
    cPickle = None
    import pickle

# pickle.DEFAULT_PROTOCOL is not defined by every pickle implementation;
# fall back to HIGHEST_PROTOCOL where it is missing.
try:
    DEFAULT_PROTOCOL = pickle.DEFAULT_PROTOCOL
except AttributeError:
    DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL


class Socket(SocketBase, AttributeSetter):
    """The ZMQ socket object

    To create a Socket, first create a Context::

        ctx = zmq.Context.instance()

    then call ``ctx.socket(socket_type)``::

        s = ctx.socket(zmq.ROUTER)

    """
    # True when this object wraps a socket it should not close on deletion
    # (see __del__ and shadow()).
    _shadow = False

    def __del__(self):
        # A shadowing object must not close the underlying socket.
        if not self._shadow:
            self.close()

    # socket as context manager:
    def __enter__(self):
        """Sockets are context managers

        .. versionadded:: 14.4
        """
        return self

    def __exit__(self, *args, **kwargs):
        """Close the socket when leaving the ``with`` block."""
        self.close()

    #-------------------------------------------------------------------------
    # Socket creation
    #-------------------------------------------------------------------------

    @classmethod
    def shadow(cls, address):
        """Shadow an existing libzmq socket

        address is the integer address of the libzmq socket
        or an FFI pointer to it.

        .. versionadded:: 14.1
        """
        address = cast_int_addr(address)
        return cls(shadow=address)

    #-------------------------------------------------------------------------
    # Deprecated aliases
    #-------------------------------------------------------------------------

    @property
    def socket_type(self):
        """Deprecated alias for ``Socket.type``; emits a DeprecationWarning."""
        warnings.warn("Socket.socket_type is deprecated, use Socket.type",
            DeprecationWarning
        )
        return self.type

    #-------------------------------------------------------------------------
    # Hooks for sockopt completion
    #-------------------------------------------------------------------------

    def __dir__(self):
        """Return the class attributes plus all known socket option names."""
        keys = dir(self.__class__)
        for collection in (
            bytes_sockopt_names,
            int_sockopt_names,
            int64_sockopt_names,
            fd_sockopt_names,
        ):
            keys.extend(collection)
        return keys

    #-------------------------------------------------------------------------
    # Getting/Setting options
    #-------------------------------------------------------------------------
    setsockopt = SocketBase.set
    getsockopt = SocketBase.get

    def set_string(self, option, optval, encoding='utf-8'):
        """set socket options with a unicode object

        This is simply a wrapper for setsockopt to protect from encoding ambiguity.

        See the 0MQ documentation for details on specific options.

        Parameters
        ----------
        option : int
            The name of the option to set. Can be any of: SUBSCRIBE,
            UNSUBSCRIBE, IDENTITY
        optval : unicode string (unicode on py2, str on py3)
            The value of the option to set.
        encoding : str
            The encoding to be used, default is utf8

        Raises
        ------
        TypeError
            if `optval` is not a unicode string.
        """
        if not isinstance(optval, unicode):
            raise TypeError("unicode strings only")
        return self.set(option, optval.encode(encoding))

    setsockopt_unicode = setsockopt_string = set_string

    def get_string(self, option, encoding='utf-8'):
        """get the value of a socket option

        See the 0MQ documentation for details on specific options.

        Parameters
        ----------
        option : int
            The option to retrieve.
        encoding : str
            The encoding used to decode the option value, default is utf8

        Returns
        -------
        optval : unicode string (unicode on py2, str on py3)
            The value of the option as a unicode string.

        Raises
        ------
        TypeError
            if `option` is not one of the bytes-valued socket options.
        """
        if option not in constants.bytes_sockopts:
            raise TypeError("option %i will not return a string to be decoded"%option)
        return self.getsockopt(option).decode(encoding)

    getsockopt_unicode = getsockopt_string = get_string

    def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=100):
        """bind this socket to a random port in a range

        Parameters
        ----------
        addr : str
            The address string without the port to pass to ``Socket.bind()``.
        min_port : int, optional
            The minimum port in the range of ports to try (inclusive).
        max_port : int, optional
            The maximum port in the range of ports to try (exclusive).
        max_tries : int, optional
            The maximum number of bind attempts to make.

        Returns
        -------
        port : int
            The port the socket was bound to.

        Raises
        ------
        ZMQBindError
            if `max_tries` reached before successful bind
        ZMQError
            if a bind attempt fails for any reason other than EADDRINUSE
        """
        for _ in range(max_tries):
            port = random.randrange(min_port, max_port)
            try:
                self.bind('%s:%s' % (addr, port))
            except ZMQError as exception:
                # Only "address in use" justifies trying another port.
                if exception.errno != zmq.EADDRINUSE:
                    raise
            else:
                return port
        raise ZMQBindError("Could not bind socket to random port.")

    def get_hwm(self):
        """get the High Water Mark

        On libzmq ≥ 3, this gets SNDHWM if available, otherwise RCVHWM
        """
        major = zmq.zmq_version_info()[0]
        if major >= 3:
            # return sndhwm, fallback on rcvhwm
            try:
                return self.getsockopt(zmq.SNDHWM)
            except zmq.ZMQError:
                pass

            return self.getsockopt(zmq.RCVHWM)
        else:
            return self.getsockopt(zmq.HWM)

    def set_hwm(self, value):
        """set the High Water Mark

        On libzmq ≥ 3, this sets both SNDHWM and RCVHWM

        Both options are always attempted. If either assignment fails, the
        most recent exception is re-raised after both have been tried.
        """
        major = zmq.zmq_version_info()[0]
        if major >= 3:
            raised = None
            try:
                self.sndhwm = value
            except Exception as e:
                raised = e
            try:
                self.rcvhwm = value
            except Exception as e:
                # Bind the exception here: without ``as e`` this handler
                # would hit a NameError instead of recording the failure.
                raised = e

            if raised:
                raise raised
        else:
            return self.setsockopt(zmq.HWM, value)

    hwm = property(get_hwm, set_hwm,
        """property for High Water Mark

        Setting hwm sets both SNDHWM and RCVHWM as appropriate.
        It gets SNDHWM if available, otherwise RCVHWM.
        """
    )

    #-------------------------------------------------------------------------
    # Sending and receiving messages
    #-------------------------------------------------------------------------

    def send_multipart(self, msg_parts, flags=0, copy=True, track=False):
        """send a sequence of buffers as a multipart message

        The zmq.SNDMORE flag is added to all msg parts before the last.

        Parameters
        ----------
        msg_parts : sequence
            A non-empty sequence of objects to send as a multipart message.
            It is sliced and indexed, so it must support both (e.g. list or
            tuple). Each element can be any sendable object
            (Frame, bytes, buffer-providers)
        flags : int, optional
            SNDMORE is handled automatically for frames before the last.
        copy : bool, optional
            Should the frame(s) be sent in a copying or non-copying manner.
        track : bool, optional
            Should the frame(s) be tracked for notification that ZMQ has
            finished with it (ignored if copy=True).

        Returns
        -------
        None : if copy or not track
        MessageTracker : if track and not copy
            a MessageTracker object, whose `pending` property will
            be True until the last send is completed.
        """
        for msg in msg_parts[:-1]:
            self.send(msg, SNDMORE|flags, copy=copy, track=track)
        # Send the last part without the extra SNDMORE flag.
        return self.send(msg_parts[-1], flags, copy=copy, track=track)

    def recv_multipart(self, flags=0, copy=True, track=False):
        """receive a multipart message as a list of bytes or Frame objects

        Parameters
        ----------
        flags : int, optional
            Any supported flag: NOBLOCK. If NOBLOCK is set, this method
            will raise a ZMQError with EAGAIN if a message is not ready.
            If NOBLOCK is not set, then this method will block until a
            message arrives.
        copy : bool, optional
            Should the message frame(s) be received in a copying or non-copying manner?
            If False a Frame object is returned for each part, if True a copy of
            the bytes is made for each frame.
        track : bool, optional
            Should the message frame(s) be tracked for notification that ZMQ has
            finished with it? (ignored if copy=True)

        Returns
        -------
        msg_parts : list
            A list of frames in the multipart message; either Frames or bytes,
            depending on `copy`.

        """
        parts = [self.recv(flags, copy=copy, track=track)]
        # have first part already, only loop while more to receive
        while self.getsockopt(zmq.RCVMORE):
            part = self.recv(flags, copy=copy, track=track)
            parts.append(part)

        return parts

    def send_string(self, u, flags=0, copy=True, encoding='utf-8'):
        """send a Python unicode string as a message with an encoding

        0MQ communicates with raw bytes, so you must encode/decode
        text (unicode on py2, str on py3) around 0MQ.

        Parameters
        ----------
        u : Python unicode string (unicode on py2, str on py3)
            The unicode string to send.
        flags : int, optional
            Any valid send flag.
        copy : bool, optional
            Passed through to ``send``.
        encoding : str [default: 'utf-8']
            The encoding to be used

        Raises
        ------
        TypeError
            if `u` is not a string object.
        """
        if not isinstance(u, basestring):
            raise TypeError("unicode/str objects only")
        return self.send(u.encode(encoding), flags=flags, copy=copy)

    send_unicode = send_string

    def recv_string(self, flags=0, encoding='utf-8'):
        """receive a unicode string, as sent by send_string

        Parameters
        ----------
        flags : int
            Any valid recv flag.
        encoding : str [default: 'utf-8']
            The encoding to be used

        Returns
        -------
        s : unicode string (unicode on py2, str on py3)
            The Python unicode string that arrives as encoded bytes.
        """
        b = self.recv(flags=flags)
        return b.decode(encoding)

    recv_unicode = recv_string

    def send_pyobj(self, obj, flags=0, protocol=DEFAULT_PROTOCOL):
        """send a Python object as a message using pickle to serialize

        Parameters
        ----------
        obj : Python object
            The Python object to send.
        flags : int
            Any valid send flag.
        protocol : int
            The pickle protocol number to use. The default is pickle.DEFAULT_PROTOCOL
            where defined, and pickle.HIGHEST_PROTOCOL elsewhere.
        """
        msg = pickle.dumps(obj, protocol)
        return self.send(msg, flags)

    def recv_pyobj(self, flags=0):
        """receive a Python object as a message using pickle to serialize

        .. warning::

            The received bytes are passed straight to ``pickle.loads``.
            Unpickling data from an untrusted peer can execute arbitrary
            code; only use this with peers you trust.

        Parameters
        ----------
        flags : int
            Any valid recv flag.

        Returns
        -------
        obj : Python object
            The Python object that arrives as a message.
        """
        s = self.recv(flags)
        # SECURITY: pickle.loads on network data -- trusted peers only.
        return pickle.loads(s)

    def send_json(self, obj, flags=0, **kwargs):
        """send a Python object as a message using json to serialize

        Keyword arguments are passed on to json.dumps

        Parameters
        ----------
        obj : Python object
            The Python object to send
        flags : int
            Any valid send flag
        """
        msg = jsonapi.dumps(obj, **kwargs)
        return self.send(msg, flags)

    def recv_json(self, flags=0, **kwargs):
        """receive a Python object as a message using json to serialize

        Keyword arguments are passed on to json.loads

        Parameters
        ----------
        flags : int
            Any valid recv flag.

        Returns
        -------
        obj : Python object
            The Python object that arrives as a message.
        """
        msg = self.recv(flags)
        return jsonapi.loads(msg, **kwargs)

    # Poller implementation used by poll(); kept as a class attribute so it
    # can be replaced without overriding poll() itself.
    _poller_class = Poller

    def poll(self, timeout=None, flags=POLLIN):
        """poll the socket for events

        The default is to poll forever for incoming
        events.  Timeout is in milliseconds, if specified.

        Parameters
        ----------
        timeout : int [default: None]
            The timeout (in milliseconds) to wait for an event. If unspecified
            (or specified None), will wait forever for an event.
        flags : bitfield (int) [default: POLLIN]
            The event flags to poll for (any combination of POLLIN|POLLOUT).
            The default is to check for incoming events (POLLIN).

        Returns
        -------
        events : bitfield (int)
            The events that are ready and waiting.  Will be 0 if no events were ready
            by the time timeout was reached.

        Raises
        ------
        ZMQError
            with ENOTSUP if the socket is already closed.
        """

        if self.closed:
            raise ZMQError(ENOTSUP)

        p = self._poller_class()
        p.register(self, flags)
        evts = dict(p.poll(timeout))
        # return 0 if no events, otherwise return event bitfield
        return evts.get(self, 0)

    def get_monitor_socket(self, events=None, addr=None):
        """Return a connected PAIR socket ready to receive the event notifications.

        .. versionadded:: libzmq-4.0
        .. versionadded:: 14.0

        Parameters
        ----------
        events : bitfield (int) [default: ZMQ_EVENTS_ALL]
            The bitmask defining which events are wanted.
        addr :  string [default: None]
            The optional endpoint for the monitoring sockets.

        Returns
        -------
        socket :  (PAIR)
            The socket is already connected and ready to receive messages.

        Raises
        ------
        NotImplementedError
            if the linked libzmq is older than version 4.
        """
        # safe-guard, method only available on libzmq >= 4
        if zmq.zmq_version_info() < (4,):
            raise NotImplementedError("get_monitor_socket requires libzmq >= 4, have %s" % zmq.zmq_version())
        if addr is None:
            # create endpoint name from internal fd
            addr = "inproc://monitor.s-%d" % self.FD
        if events is None:
            # use all events
            events = zmq.EVENT_ALL
        # attach monitoring socket
        self.monitor(addr, events)
        # create new PAIR socket and connect it
        ret = self.context.socket(zmq.PAIR)
        ret.connect(addr)
        return ret

    def disable_monitor(self):
        """Shutdown the PAIR socket (created using get_monitor_socket)
        that is serving socket events.

        .. versionadded:: 14.4
        """
        self.monitor(None, 0)


__all__ = ['Socket']