diff options
152 files changed, 23246 insertions, 102 deletions
@@ -11,6 +11,8 @@ linux_dpdk* linux* scripts/_t-rex-* scripts/bp-sim-* +scripts/doc/* +scripts/mock-* *.pyc @@ -43,6 +45,11 @@ scripts/bp-sim-* ehthumbs.db Thumbs.db + # IDE/ Editors files # ###################### .idea/ +*.vpj +*.vpw +*.vtg +*.vpwhist @@ -1,3 +1,4 @@ -v1.74 +v1.75 + diff --git a/external_libs/json/json/json-forwards.h b/external_libs/json/json/json-forwards.h new file mode 100644 index 00000000..ccbdb2b1 --- /dev/null +++ b/external_libs/json/json/json-forwards.h @@ -0,0 +1,255 @@ +/// Json-cpp amalgated forward header (http://jsoncpp.sourceforge.net/). +/// It is intended to be used with #include "json/json-forwards.h" +/// This header provides forward declaration for all JsonCpp types. + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + +/* +The JsonCpp library's source code, including accompanying documentation, +tests and demonstration applications, are licensed under the following +conditions... + +The author (Baptiste Lepilleur) explicitly disclaims copyright in all +jurisdictions which recognize such a disclaimer. In such jurisdictions, +this software is released into the Public Domain. + +In jurisdictions which do not recognize Public Domain property (e.g. Germany as of +2010), this software is Copyright (c) 2007-2010 by Baptiste Lepilleur, and is +released under the terms of the MIT License (see below). + +In jurisdictions which recognize Public Domain property, the user of this +software may choose to accept it either as 1) Public Domain, 2) under the +conditions of the MIT License (see below), or 3) under the terms of dual +Public Domain/MIT License conditions described here, as they choose. + +The MIT License is about as close to Public Domain as a license can get, and is +described in clear, concise terms at: + + http://en.wikipedia.org/wiki/MIT_License + +The full text of the MIT License follows: + +======================================================================== +Copyright (c) 2007-2010 Baptiste Lepilleur + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, copy, +modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +======================================================================== +(END LICENSE TEXT) + +The MIT license is compatible with both the GPL and commercial +software, affording one all of the rights of Public Domain with the +minor nuisance of being required to keep the above copyright notice +and license text in the source code. Note also that by accepting the +Public Domain "license" you can re-license your copy using whatever +license you like. + +*/ + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + + + + + +#ifndef JSON_FORWARD_AMALGATED_H_INCLUDED +# define JSON_FORWARD_AMALGATED_H_INCLUDED +/// If defined, indicates that the source file is amalgated +/// to prevent private header inclusion. +#define JSON_IS_AMALGAMATION + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/config.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_CONFIG_H_INCLUDED +#define JSON_CONFIG_H_INCLUDED + +/// If defined, indicates that json library is embedded in CppTL library. +//# define JSON_IN_CPPTL 1 + +/// If defined, indicates that json may leverage CppTL library +//# define JSON_USE_CPPTL 1 +/// If defined, indicates that cpptl vector based map should be used instead of +/// std::map +/// as Value container. +//# define JSON_USE_CPPTL_SMALLMAP 1 + +// If non-zero, the library uses exceptions to report bad input instead of C +// assertion macros. The default is to use exceptions. +#ifndef JSON_USE_EXCEPTION +#define JSON_USE_EXCEPTION 1 +#endif + +/// If defined, indicates that the source file is amalgated +/// to prevent private header inclusion. +/// Remarks: it is automatically defined in the generated amalgated header. +// #define JSON_IS_AMALGAMATION + +#ifdef JSON_IN_CPPTL +#include <cpptl/config.h> +#ifndef JSON_USE_CPPTL +#define JSON_USE_CPPTL 1 +#endif +#endif + +#ifdef JSON_IN_CPPTL +#define JSON_API CPPTL_API +#elif defined(JSON_DLL_BUILD) +#if defined(_MSC_VER) +#define JSON_API __declspec(dllexport) +#define JSONCPP_DISABLE_DLL_INTERFACE_WARNING +#endif // if defined(_MSC_VER) +#elif defined(JSON_DLL) +#if defined(_MSC_VER) +#define JSON_API __declspec(dllimport) +#define JSONCPP_DISABLE_DLL_INTERFACE_WARNING +#endif // if defined(_MSC_VER) +#endif // ifdef JSON_IN_CPPTL +#if !defined(JSON_API) +#define JSON_API +#endif + +// If JSON_NO_INT64 is defined, then Json only support C++ "int" type for +// integer +// Storages, and 64 bits integer support is disabled. +// #define JSON_NO_INT64 1 + +#if defined(_MSC_VER) && _MSC_VER <= 1200 // MSVC 6 +// Microsoft Visual Studio 6 only support conversion from __int64 to double +// (no conversion from unsigned __int64). +#define JSON_USE_INT64_DOUBLE_CONVERSION 1 +// Disable warning 4786 for VS6 caused by STL (identifier was truncated to '255' +// characters in the debug information) +// All projects I've ever seen with VS6 were using this globally (not bothering +// with pragma push/pop). +#pragma warning(disable : 4786) +#endif // if defined(_MSC_VER) && _MSC_VER < 1200 // MSVC 6 + +#if defined(_MSC_VER) && _MSC_VER >= 1500 // MSVC 2008 +/// Indicates that the following function is deprecated. +#define JSONCPP_DEPRECATED(message) __declspec(deprecated(message)) +#elif defined(__clang__) && defined(__has_feature) +#if __has_feature(attribute_deprecated_with_message) +#define JSONCPP_DEPRECATED(message) __attribute__ ((deprecated(message))) +#endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5)) +#define JSONCPP_DEPRECATED(message) __attribute__ ((deprecated(message))) +#elif defined(__GNUC__) && (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1)) +#define JSONCPP_DEPRECATED(message) __attribute__((__deprecated__)) +#endif + +#if !defined(JSONCPP_DEPRECATED) +#define JSONCPP_DEPRECATED(message) +#endif // if !defined(JSONCPP_DEPRECATED) + +namespace Json { +typedef int Int; +typedef unsigned int UInt; +#if defined(JSON_NO_INT64) +typedef int LargestInt; +typedef unsigned int LargestUInt; +#undef JSON_HAS_INT64 +#else // if defined(JSON_NO_INT64) +// For Microsoft Visual use specific types as long long is not supported +#if defined(_MSC_VER) // Microsoft Visual Studio +typedef __int64 Int64; +typedef unsigned __int64 UInt64; +#else // if defined(_MSC_VER) // Other platforms, use long long +typedef long long int Int64; +typedef unsigned long long int UInt64; +#endif // if defined(_MSC_VER) +typedef Int64 LargestInt; +typedef UInt64 LargestUInt; +#define JSON_HAS_INT64 +#endif // if defined(JSON_NO_INT64) +} // end namespace Json + +#endif // JSON_CONFIG_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/config.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/forwards.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_FORWARDS_H_INCLUDED +#define JSON_FORWARDS_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "config.h" +#endif // if !defined(JSON_IS_AMALGAMATION) + +namespace Json { + +// writer.h +class FastWriter; +class StyledWriter; + +// reader.h +class Reader; + +// features.h +class Features; + +// value.h +typedef unsigned int ArrayIndex; +class StaticString; +class Path; +class PathArgument; +class Value; +class ValueIteratorBase; +class ValueIterator; +class ValueConstIterator; + +} // namespace Json + +#endif // JSON_FORWARDS_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/forwards.h +// ////////////////////////////////////////////////////////////////////// + + + + + +#endif //ifndef JSON_FORWARD_AMALGATED_H_INCLUDED diff --git a/external_libs/json/json/json.h b/external_libs/json/json/json.h new file mode 100644 index 00000000..e01991e0 --- /dev/null +++ b/external_libs/json/json/json.h @@ -0,0 +1,2017 @@ +/// Json-cpp amalgated header (http://jsoncpp.sourceforge.net/). +/// It is intended to be used with #include "json/json.h" + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + +/* +The JsonCpp library's source code, including accompanying documentation, +tests and demonstration applications, are licensed under the following +conditions... + +The author (Baptiste Lepilleur) explicitly disclaims copyright in all +jurisdictions which recognize such a disclaimer. In such jurisdictions, +this software is released into the Public Domain. + +In jurisdictions which do not recognize Public Domain property (e.g. Germany as of +2010), this software is Copyright (c) 2007-2010 by Baptiste Lepilleur, and is +released under the terms of the MIT License (see below). + +In jurisdictions which recognize Public Domain property, the user of this +software may choose to accept it either as 1) Public Domain, 2) under the +conditions of the MIT License (see below), or 3) under the terms of dual +Public Domain/MIT License conditions described here, as they choose. + +The MIT License is about as close to Public Domain as a license can get, and is +described in clear, concise terms at: + + http://en.wikipedia.org/wiki/MIT_License + +The full text of the MIT License follows: + +======================================================================== +Copyright (c) 2007-2010 Baptiste Lepilleur + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, copy, +modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +======================================================================== +(END LICENSE TEXT) + +The MIT license is compatible with both the GPL and commercial +software, affording one all of the rights of Public Domain with the +minor nuisance of being required to keep the above copyright notice +and license text in the source code. Note also that by accepting the +Public Domain "license" you can re-license your copy using whatever +license you like. + +*/ + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + + + + + +#ifndef JSON_AMALGATED_H_INCLUDED +# define JSON_AMALGATED_H_INCLUDED +/// If defined, indicates that the source file is amalgated +/// to prevent private header inclusion. +#define JSON_IS_AMALGAMATION + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/version.h +// ////////////////////////////////////////////////////////////////////// + +// DO NOT EDIT. This file is generated by CMake from "version" +// and "version.h.in" files. +// Run CMake configure step to update it. +#ifndef JSON_VERSION_H_INCLUDED +# define JSON_VERSION_H_INCLUDED + +# define JSONCPP_VERSION_STRING "1.6.2" +# define JSONCPP_VERSION_MAJOR 1 +# define JSONCPP_VERSION_MINOR 6 +# define JSONCPP_VERSION_PATCH 2 +# define JSONCPP_VERSION_QUALIFIER +# define JSONCPP_VERSION_HEXA ((JSONCPP_VERSION_MAJOR << 24) | (JSONCPP_VERSION_MINOR << 16) | (JSONCPP_VERSION_PATCH << 8)) + +#endif // JSON_VERSION_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/version.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/config.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_CONFIG_H_INCLUDED +#define JSON_CONFIG_H_INCLUDED + +/// If defined, indicates that json library is embedded in CppTL library. +//# define JSON_IN_CPPTL 1 + +/// If defined, indicates that json may leverage CppTL library +//# define JSON_USE_CPPTL 1 +/// If defined, indicates that cpptl vector based map should be used instead of +/// std::map +/// as Value container. +//# define JSON_USE_CPPTL_SMALLMAP 1 + +// If non-zero, the library uses exceptions to report bad input instead of C +// assertion macros. The default is to use exceptions. +#ifndef JSON_USE_EXCEPTION +#define JSON_USE_EXCEPTION 1 +#endif + +/// If defined, indicates that the source file is amalgated +/// to prevent private header inclusion. +/// Remarks: it is automatically defined in the generated amalgated header. +// #define JSON_IS_AMALGAMATION + +#ifdef JSON_IN_CPPTL +#include <cpptl/config.h> +#ifndef JSON_USE_CPPTL +#define JSON_USE_CPPTL 1 +#endif +#endif + +#ifdef JSON_IN_CPPTL +#define JSON_API CPPTL_API +#elif defined(JSON_DLL_BUILD) +#if defined(_MSC_VER) +#define JSON_API __declspec(dllexport) +#define JSONCPP_DISABLE_DLL_INTERFACE_WARNING +#endif // if defined(_MSC_VER) +#elif defined(JSON_DLL) +#if defined(_MSC_VER) +#define JSON_API __declspec(dllimport) +#define JSONCPP_DISABLE_DLL_INTERFACE_WARNING +#endif // if defined(_MSC_VER) +#endif // ifdef JSON_IN_CPPTL +#if !defined(JSON_API) +#define JSON_API +#endif + +// If JSON_NO_INT64 is defined, then Json only support C++ "int" type for +// integer +// Storages, and 64 bits integer support is disabled. +// #define JSON_NO_INT64 1 + +#if defined(_MSC_VER) && _MSC_VER <= 1200 // MSVC 6 +// Microsoft Visual Studio 6 only support conversion from __int64 to double +// (no conversion from unsigned __int64). +#define JSON_USE_INT64_DOUBLE_CONVERSION 1 +// Disable warning 4786 for VS6 caused by STL (identifier was truncated to '255' +// characters in the debug information) +// All projects I've ever seen with VS6 were using this globally (not bothering +// with pragma push/pop). +#pragma warning(disable : 4786) +#endif // if defined(_MSC_VER) && _MSC_VER < 1200 // MSVC 6 + +#if defined(_MSC_VER) && _MSC_VER >= 1500 // MSVC 2008 +/// Indicates that the following function is deprecated. +#define JSONCPP_DEPRECATED(message) __declspec(deprecated(message)) +#elif defined(__clang__) && defined(__has_feature) +#if __has_feature(attribute_deprecated_with_message) +#define JSONCPP_DEPRECATED(message) __attribute__ ((deprecated(message))) +#endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5)) +#define JSONCPP_DEPRECATED(message) __attribute__ ((deprecated(message))) +#elif defined(__GNUC__) && (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1)) +#define JSONCPP_DEPRECATED(message) __attribute__((__deprecated__)) +#endif + +#if !defined(JSONCPP_DEPRECATED) +#define JSONCPP_DEPRECATED(message) +#endif // if !defined(JSONCPP_DEPRECATED) + +namespace Json { +typedef int Int; +typedef unsigned int UInt; +#if defined(JSON_NO_INT64) +typedef int LargestInt; +typedef unsigned int LargestUInt; +#undef JSON_HAS_INT64 +#else // if defined(JSON_NO_INT64) +// For Microsoft Visual use specific types as long long is not supported +#if defined(_MSC_VER) // Microsoft Visual Studio +typedef __int64 Int64; +typedef unsigned __int64 UInt64; +#else // if defined(_MSC_VER) // Other platforms, use long long +typedef long long int Int64; +typedef unsigned long long int UInt64; +#endif // if defined(_MSC_VER) +typedef Int64 LargestInt; +typedef UInt64 LargestUInt; +#define JSON_HAS_INT64 +#endif // if defined(JSON_NO_INT64) +} // end namespace Json + +#endif // JSON_CONFIG_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/config.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/forwards.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_FORWARDS_H_INCLUDED +#define JSON_FORWARDS_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "config.h" +#endif // if !defined(JSON_IS_AMALGAMATION) + +namespace Json { + +// writer.h +class FastWriter; +class StyledWriter; + +// reader.h +class Reader; + +// features.h +class Features; + +// value.h +typedef unsigned int ArrayIndex; +class StaticString; +class Path; +class PathArgument; +class Value; +class ValueIteratorBase; +class ValueIterator; +class ValueConstIterator; + +} // namespace Json + +#endif // JSON_FORWARDS_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/forwards.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/features.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef CPPTL_JSON_FEATURES_H_INCLUDED +#define CPPTL_JSON_FEATURES_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "forwards.h" +#endif // if !defined(JSON_IS_AMALGAMATION) + +namespace Json { + +/** \brief Configuration passed to reader and writer. + * This configuration object can be used to force the Reader or Writer + * to behave in a standard conforming way. + */ +class JSON_API Features { +public: + /** \brief A configuration that allows all features and assumes all strings + * are UTF-8. + * - C & C++ comments are allowed + * - Root object can be any JSON value + * - Assumes Value strings are encoded in UTF-8 + */ + static Features all(); + + /** \brief A configuration that is strictly compatible with the JSON + * specification. + * - Comments are forbidden. + * - Root object must be either an array or an object value. + * - Assumes Value strings are encoded in UTF-8 + */ + static Features strictMode(); + + /** \brief Initialize the configuration like JsonConfig::allFeatures; + */ + Features(); + + /// \c true if comments are allowed. Default: \c true. + bool allowComments_; + + /// \c true if root must be either an array or an object value. Default: \c + /// false. + bool strictRoot_; + + /// \c true if dropped null placeholders are allowed. Default: \c false. + bool allowDroppedNullPlaceholders_; + + /// \c true if numeric object key are allowed. Default: \c false. + bool allowNumericKeys_; +}; + +} // namespace Json + +#endif // CPPTL_JSON_FEATURES_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/features.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/value.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef CPPTL_JSON_H_INCLUDED +#define CPPTL_JSON_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "forwards.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include <string> +#include <vector> +#include <exception> + +#ifndef JSON_USE_CPPTL_SMALLMAP +#include <map> +#else +#include <cpptl/smallmap.h> +#endif +#ifdef JSON_USE_CPPTL +#include <cpptl/forwards.h> +#endif + +// Disable warning C4251: <data member>: <type> needs to have dll-interface to +// be used by... +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(push) +#pragma warning(disable : 4251) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +/** \brief JSON (JavaScript Object Notation). + */ +namespace Json { + +/** Base class for all exceptions we throw. + * + * We use nothing but these internally. Of course, STL can throw others. + */ +class JSON_API Exception; +/** Exceptions which the user cannot easily avoid. + * + * E.g. out-of-memory (when we use malloc), stack-overflow, malicious input + * + * \remark derived from Json::Exception + */ +class JSON_API RuntimeError; +/** Exceptions thrown by JSON_ASSERT/JSON_FAIL macros. + * + * These are precondition-violations (user bugs) and internal errors (our bugs). + * + * \remark derived from Json::Exception + */ +class JSON_API LogicError; + +/// used internally +void throwRuntimeError(std::string const& msg); +/// used internally +void throwLogicError(std::string const& msg); + +/** \brief Type of the value held by a Value object. + */ +enum ValueType { + nullValue = 0, ///< 'null' value + intValue, ///< signed integer value + uintValue, ///< unsigned integer value + realValue, ///< double value + stringValue, ///< UTF-8 string value + booleanValue, ///< bool value + arrayValue, ///< array value (ordered list) + objectValue ///< object value (collection of name/value pairs). +}; + +enum CommentPlacement { + commentBefore = 0, ///< a comment placed on the line before a value + commentAfterOnSameLine, ///< a comment just after a value on the same line + commentAfter, ///< a comment on the line after a value (only make sense for + /// root value) + numberOfCommentPlacement +}; + +//# ifdef JSON_USE_CPPTL +// typedef CppTL::AnyEnumerator<const char *> EnumMemberNames; +// typedef CppTL::AnyEnumerator<const Value &> EnumValues; +//# endif + +/** \brief Lightweight wrapper to tag static string. + * + * Value constructor and objectValue member assignement takes advantage of the + * StaticString and avoid the cost of string duplication when storing the + * string or the member name. + * + * Example of usage: + * \code + * Json::Value aValue( StaticString("some text") ); + * Json::Value object; + * static const StaticString code("code"); + * object[code] = 1234; + * \endcode + */ +class JSON_API StaticString { +public: + explicit StaticString(const char* czstring) : c_str_(czstring) {} + + operator const char*() const { return c_str_; } + + const char* c_str() const { return c_str_; } + +private: + const char* c_str_; +}; + +/** \brief Represents a <a HREF="http://www.json.org">JSON</a> value. + * + * This class is a discriminated union wrapper that can represents a: + * - signed integer [range: Value::minInt - Value::maxInt] + * - unsigned integer (range: 0 - Value::maxUInt) + * - double + * - UTF-8 string + * - boolean + * - 'null' + * - an ordered list of Value + * - collection of name/value pairs (javascript object) + * + * The type of the held value is represented by a #ValueType and + * can be obtained using type(). + * + * Values of an #objectValue or #arrayValue can be accessed using operator[]() + * methods. + * Non-const methods will automatically create the a #nullValue element + * if it does not exist. + * The sequence of an #arrayValue will be automatically resized and initialized + * with #nullValue. resize() can be used to enlarge or truncate an #arrayValue. + * + * The get() methods can be used to obtain default value in the case the + * required element does not exist. + * + * It is possible to iterate over the list of a #objectValue values using + * the getMemberNames() method. + * + * \note #Value string-length fit in size_t, but keys must be < 2^30. + * (The reason is an implementation detail.) A #CharReader will raise an + * exception if a bound is exceeded to avoid security holes in your app, + * but the Value API does *not* check bounds. That is the responsibility + * of the caller. + */ +class JSON_API Value { + friend class ValueIteratorBase; +public: + typedef std::vector<std::string> Members; + typedef ValueIterator iterator; + typedef ValueConstIterator const_iterator; + typedef Json::UInt UInt; + typedef Json::Int Int; +#if defined(JSON_HAS_INT64) + typedef Json::UInt64 UInt64; + typedef Json::Int64 Int64; +#endif // defined(JSON_HAS_INT64) + typedef Json::LargestInt LargestInt; + typedef Json::LargestUInt LargestUInt; + typedef Json::ArrayIndex ArrayIndex; + + static const Value& null; ///< We regret this reference to a global instance; prefer the simpler Value(). + static const Value& nullRef; ///< just a kludge for binary-compatibility; same as null + /// Minimum signed integer value that can be stored in a Json::Value. + static const LargestInt minLargestInt; + /// Maximum signed integer value that can be stored in a Json::Value. + static const LargestInt maxLargestInt; + /// Maximum unsigned integer value that can be stored in a Json::Value. + static const LargestUInt maxLargestUInt; + + /// Minimum signed int value that can be stored in a Json::Value. + static const Int minInt; + /// Maximum signed int value that can be stored in a Json::Value. + static const Int maxInt; + /// Maximum unsigned int value that can be stored in a Json::Value. + static const UInt maxUInt; + +#if defined(JSON_HAS_INT64) + /// Minimum signed 64 bits int value that can be stored in a Json::Value. + static const Int64 minInt64; + /// Maximum signed 64 bits int value that can be stored in a Json::Value. + static const Int64 maxInt64; + /// Maximum unsigned 64 bits int value that can be stored in a Json::Value. + static const UInt64 maxUInt64; +#endif // defined(JSON_HAS_INT64) + +private: +#ifndef JSONCPP_DOC_EXCLUDE_IMPLEMENTATION + class CZString { + public: + enum DuplicationPolicy { + noDuplication = 0, + duplicate, + duplicateOnCopy + }; + CZString(ArrayIndex index); + CZString(char const* str, unsigned length, DuplicationPolicy allocate); + CZString(CZString const& other); + ~CZString(); + CZString& operator=(CZString other); + bool operator<(CZString const& other) const; + bool operator==(CZString const& other) const; + ArrayIndex index() const; + //const char* c_str() const; ///< \deprecated + char const* data() const; + unsigned length() const; + bool isStaticString() const; + + private: + void swap(CZString& other); + + struct StringStorage { + unsigned policy_: 2; + unsigned length_: 30; // 1GB max + }; + + char const* cstr_; // actually, a prefixed string, unless policy is noDup + union { + ArrayIndex index_; + StringStorage storage_; + }; + }; + +public: +#ifndef JSON_USE_CPPTL_SMALLMAP + typedef std::map<CZString, Value> ObjectValues; +#else + typedef CppTL::SmallMap<CZString, Value> ObjectValues; +#endif // ifndef JSON_USE_CPPTL_SMALLMAP +#endif // ifndef JSONCPP_DOC_EXCLUDE_IMPLEMENTATION + +public: + /** \brief Create a default Value of the given type. + + This is a very useful constructor. + To create an empty array, pass arrayValue. + To create an empty object, pass objectValue. + Another Value can then be set to this one by assignment. +This is useful since clear() and resize() will not alter types. + + Examples: +\code +Json::Value null_value; // null +Json::Value arr_value(Json::arrayValue); // [] +Json::Value obj_value(Json::objectValue); // {} +\endcode + */ + Value(ValueType type = nullValue); + Value(Int value); + Value(UInt value); +#if defined(JSON_HAS_INT64) + Value(Int64 value); + Value(UInt64 value); +#endif // if defined(JSON_HAS_INT64) + Value(double value); + Value(const char* value); ///< Copy til first 0. (NULL causes to seg-fault.) + Value(const char* beginValue, const char* endValue); ///< Copy all, incl zeroes. + /** \brief Constructs a value from a static string. + + * Like other value string constructor but do not duplicate the string for + * internal storage. The given string must remain alive after the call to this + * constructor. + * \note This works only for null-terminated strings. (We cannot change the + * size of this class, so we have nowhere to store the length, + * which might be computed later for various operations.) + * + * Example of usage: + * \code + * static StaticString foo("some text"); + * Json::Value aValue(foo); + * \endcode + */ + Value(const StaticString& value); + Value(const std::string& value); ///< Copy data() til size(). Embedded zeroes too. +#ifdef JSON_USE_CPPTL + Value(const CppTL::ConstString& value); +#endif + Value(bool value); + /// Deep copy. + Value(const Value& other); + ~Value(); + + /// Deep copy, then swap(other). + /// \note Over-write existing comments. To preserve comments, use #swapPayload(). + Value& operator=(Value other); + /// Swap everything. + void swap(Value& other); + /// Swap values but leave comments and source offsets in place. + void swapPayload(Value& other); + + ValueType type() const; + + /// Compare payload only, not comments etc. + bool operator<(const Value& other) const; + bool operator<=(const Value& other) const; + bool operator>=(const Value& other) const; + bool operator>(const Value& other) const; + bool operator==(const Value& other) const; + bool operator!=(const Value& other) const; + int compare(const Value& other) const; + + const char* asCString() const; ///< Embedded zeroes could cause you trouble! + std::string asString() const; ///< Embedded zeroes are possible. + /** Get raw char* of string-value. + * \return false if !string. (Seg-fault if str or end are NULL.) + */ + bool getString( + char const** str, char const** end) const; +#ifdef JSON_USE_CPPTL + CppTL::ConstString asConstString() const; +#endif + Int asInt() const; + UInt asUInt() const; +#if defined(JSON_HAS_INT64) + Int64 asInt64() const; + UInt64 asUInt64() const; +#endif // if defined(JSON_HAS_INT64) + LargestInt asLargestInt() const; + LargestUInt asLargestUInt() const; + float asFloat() const; + double asDouble() const; + bool asBool() const; + + bool isNull() const; + bool isBool() const; + bool isInt() const; + bool isInt64() const; + bool isUInt() const; + bool isUInt64() const; + bool isIntegral() const; + bool isDouble() const; + bool isNumeric() const; + bool isString() const; + bool isArray() const; + bool isObject() const; + + bool isConvertibleTo(ValueType other) const; + + /// Number of values in array or object + ArrayIndex size() const; + + /// \brief Return true if empty array, empty object, or null; + /// otherwise, false. + bool empty() const; + + /// Return isNull() + bool operator!() const; + + /// Remove all object members and array elements. + /// \pre type() is arrayValue, objectValue, or nullValue + /// \post type() is unchanged + void clear(); + + /// Resize the array to size elements. + /// New elements are initialized to null. + /// May only be called on nullValue or arrayValue. + /// \pre type() is arrayValue or nullValue + /// \post type() is arrayValue + void resize(ArrayIndex size); + + /// Access an array element (zero based index ). + /// If the array contains less than index element, then null value are + /// inserted + /// in the array so that its size is index+1. + /// (You may need to say 'value[0u]' to get your compiler to distinguish + /// this from the operator[] which takes a string.) + Value& operator[](ArrayIndex index); + + /// Access an array element (zero based index ). + /// If the array contains less than index element, then null value are + /// inserted + /// in the array so that its size is index+1. + /// (You may need to say 'value[0u]' to get your compiler to distinguish + /// this from the operator[] which takes a string.) + Value& operator[](int index); + + /// Access an array element (zero based index ) + /// (You may need to say 'value[0u]' to get your compiler to distinguish + /// this from the operator[] which takes a string.) + const Value& operator[](ArrayIndex index) const; + + /// Access an array element (zero based index ) + /// (You may need to say 'value[0u]' to get your compiler to distinguish + /// this from the operator[] which takes a string.) + const Value& operator[](int index) const; + + /// If the array contains at least index+1 elements, returns the element + /// value, + /// otherwise returns defaultValue. + Value get(ArrayIndex index, const Value& defaultValue) const; + /// Return true if index < size(). + bool isValidIndex(ArrayIndex index) const; + /// \brief Append value to array at the end. + /// + /// Equivalent to jsonvalue[jsonvalue.size()] = value; + Value& append(const Value& value); + + /// Access an object value by name, create a null member if it does not exist. + /// \note Because of our implementation, keys are limited to 2^30 -1 chars. + /// Exceeding that will cause an exception. + Value& operator[](const char* key); + /// Access an object value by name, returns null if there is no member with + /// that name. + const Value& operator[](const char* key) const; + /// Access an object value by name, create a null member if it does not exist. + /// \param key may contain embedded nulls. + Value& operator[](const std::string& key); + /// Access an object value by name, returns null if there is no member with + /// that name. + /// \param key may contain embedded nulls. + const Value& operator[](const std::string& key) const; + /** \brief Access an object value by name, create a null member if it does not + exist. + + * If the object has no entry for that name, then the member name used to store + * the new entry is not duplicated. + * Example of use: + * \code + * Json::Value object; + * static const StaticString code("code"); + * object[code] = 1234; + * \endcode + */ + Value& operator[](const StaticString& key); +#ifdef JSON_USE_CPPTL + /// Access an object value by name, create a null member if it does not exist. + Value& operator[](const CppTL::ConstString& key); + /// Access an object value by name, returns null if there is no member with + /// that name. + const Value& operator[](const CppTL::ConstString& key) const; +#endif + /// Return the member named key if it exist, defaultValue otherwise. + /// \note deep copy + Value get(const char* key, const Value& defaultValue) const; + /// Return the member named key if it exist, defaultValue otherwise. + /// \note deep copy + /// \param key may contain embedded nulls. + Value get(const char* key, const char* end, const Value& defaultValue) const; + /// Return the member named key if it exist, defaultValue otherwise. + /// \note deep copy + /// \param key may contain embedded nulls. + Value get(const std::string& key, const Value& defaultValue) const; +#ifdef JSON_USE_CPPTL + /// Return the member named key if it exist, defaultValue otherwise. + /// \note deep copy + Value get(const CppTL::ConstString& key, const Value& defaultValue) const; +#endif + /// Most general and efficient version of isMember()const, get()const, + /// and operator[]const + /// \note As stated elsewhere, behavior is undefined if (end-key) >= 2^30 + Value const* find(char const* key, char const* end) const; + /// Most general and efficient version of object-mutators. + /// \note As stated elsewhere, behavior is undefined if (end-key) >= 2^30 + /// \return non-zero, but JSON_ASSERT if this is neither object nor nullValue. + Value const* demand(char const* key, char const* end); + /// \brief Remove and return the named member. + /// + /// Do nothing if it did not exist. + /// \return the removed Value, or null. + /// \pre type() is objectValue or nullValue + /// \post type() is unchanged + /// \deprecated + Value removeMember(const char* key); + /// Same as removeMember(const char*) + /// \param key may contain embedded nulls. + /// \deprecated + Value removeMember(const std::string& key); + /// Same as removeMember(const char* key, const char* end, Value* removed), + /// but 'key' is null-terminated. + bool removeMember(const char* key, Value* removed); + /** \brief Remove the named map member. + + Update 'removed' iff removed. + \param key may contain embedded nulls. + \return true iff removed (no exceptions) + */ + bool removeMember(std::string const& key, Value* removed); + /// Same as removeMember(std::string const& key, Value* removed) + bool removeMember(const char* key, const char* end, Value* removed); + /** \brief Remove the indexed array element. + + O(n) expensive operations. + Update 'removed' iff removed. + \return true iff removed (no exceptions) + */ + bool removeIndex(ArrayIndex i, Value* removed); + + /// Return true if the object has a member named key. + /// \note 'key' must be null-terminated. + bool isMember(const char* key) const; + /// Return true if the object has a member named key. + /// \param key may contain embedded nulls. + bool isMember(const std::string& key) const; + /// Same as isMember(std::string const& key)const + bool isMember(const char* key, const char* end) const; +#ifdef JSON_USE_CPPTL + /// Return true if the object has a member named key. + bool isMember(const CppTL::ConstString& key) const; +#endif + + /// \brief Return a list of the member names. + /// + /// If null, return an empty list. + /// \pre type() is objectValue or nullValue + /// \post if type() was nullValue, it remains nullValue + Members getMemberNames() const; + + //# ifdef JSON_USE_CPPTL + // EnumMemberNames enumMemberNames() const; + // EnumValues enumValues() const; + //# endif + + /// \deprecated Always pass len. + JSONCPP_DEPRECATED("Use setComment(std::string const&) instead.") + void setComment(const char* comment, CommentPlacement placement); + /// Comments must be //... or /* ... */ + void setComment(const char* comment, size_t len, CommentPlacement placement); + /// Comments must be //... or /* ... */ + void setComment(const std::string& comment, CommentPlacement placement); + bool hasComment(CommentPlacement placement) const; + /// Include delimiters and embedded newlines. + std::string getComment(CommentPlacement placement) const; + + std::string toStyledString() const; + + const_iterator begin() const; + const_iterator end() const; + + iterator begin(); + iterator end(); + + // Accessors for the [start, limit) range of bytes within the JSON text from + // which this value was parsed, if any. + void setOffsetStart(size_t start); + void setOffsetLimit(size_t limit); + size_t getOffsetStart() const; + size_t getOffsetLimit() const; + +private: + void initBasic(ValueType type, bool allocated = false); + + Value& resolveReference(const char* key); + Value& resolveReference(const char* key, const char* end); + + struct CommentInfo { + CommentInfo(); + ~CommentInfo(); + + void setComment(const char* text, size_t len); + + char* comment_; + }; + + // struct MemberNamesTransform + //{ + // typedef const char *result_type; + // const char *operator()( const CZString &name ) const + // { + // return name.c_str(); + // } + //}; + + union ValueHolder { + LargestInt int_; + LargestUInt uint_; + double real_; + bool bool_; + char* string_; // actually ptr to unsigned, followed by str, unless !allocated_ + ObjectValues* map_; + } value_; + ValueType type_ : 8; + unsigned int allocated_ : 1; // Notes: if declared as bool, bitfield is useless. + // If not allocated_, string_ must be null-terminated. + CommentInfo* comments_; + + // [start, limit) byte offsets in the source JSON text from which this Value + // was extracted. + size_t start_; + size_t limit_; +}; + +/** \brief Experimental and untested: represents an element of the "path" to + * access a node. + */ +class JSON_API PathArgument { +public: + friend class Path; + + PathArgument(); + PathArgument(ArrayIndex index); + PathArgument(const char* key); + PathArgument(const std::string& key); + +private: + enum Kind { + kindNone = 0, + kindIndex, + kindKey + }; + std::string key_; + ArrayIndex index_; + Kind kind_; +}; + +/** \brief Experimental and untested: represents a "path" to access a node. + * + * Syntax: + * - "." => root node + * - ".[n]" => elements at index 'n' of root node (an array value) + * - ".name" => member named 'name' of root node (an object value) + * - ".name1.name2.name3" + * - ".[0][1][2].name1[3]" + * - ".%" => member name is provided as parameter + * - ".[%]" => index is provied as parameter + */ +class JSON_API Path { +public: + Path(const std::string& path, + const PathArgument& a1 = PathArgument(), + const PathArgument& a2 = PathArgument(), + const PathArgument& a3 = PathArgument(), + const PathArgument& a4 = PathArgument(), + const PathArgument& a5 = PathArgument()); + + const Value& resolve(const Value& root) const; + Value resolve(const Value& root, const Value& defaultValue) const; + /// Creates the "path" to access the specified node and returns a reference on + /// the node. + Value& make(Value& root) const; + +private: + typedef std::vector<const PathArgument*> InArgs; + typedef std::vector<PathArgument> Args; + + void makePath(const std::string& path, const InArgs& in); + void addPathInArg(const std::string& path, + const InArgs& in, + InArgs::const_iterator& itInArg, + PathArgument::Kind kind); + void invalidPath(const std::string& path, int location); + + Args args_; +}; + +/** \brief base class for Value iterators. + * + */ +class JSON_API ValueIteratorBase { +public: + typedef std::bidirectional_iterator_tag iterator_category; + typedef unsigned int size_t; + typedef int difference_type; + typedef ValueIteratorBase SelfType; + + bool operator==(const SelfType& other) const { return isEqual(other); } + + bool operator!=(const SelfType& other) const { return !isEqual(other); } + + difference_type operator-(const SelfType& other) const { + return other.computeDistance(*this); + } + + /// Return either the index or the member name of the referenced value as a + /// Value. + Value key() const; + + /// Return the index of the referenced Value, or -1 if it is not an arrayValue. + UInt index() const; + + /// Return the member name of the referenced Value, or "" if it is not an + /// objectValue. + /// \note Avoid `c_str()` on result, as embedded zeroes are possible. + std::string name() const; + + /// Return the member name of the referenced Value. "" if it is not an + /// objectValue. + /// \deprecated This cannot be used for UTF-8 strings, since there can be embedded nulls. + JSONCPP_DEPRECATED("Use `key = name();` instead.") + char const* memberName() const; + /// Return the member name of the referenced Value, or NULL if it is not an + /// objectValue. + /// \note Better version than memberName(). Allows embedded nulls. + char const* memberName(char const** end) const; + +protected: + Value& deref() const; + + void increment(); + + void decrement(); + + difference_type computeDistance(const SelfType& other) const; + + bool isEqual(const SelfType& other) const; + + void copy(const SelfType& other); + +private: + Value::ObjectValues::iterator current_; + // Indicates that iterator is for a null value. + bool isNull_; + +public: + // For some reason, BORLAND needs these at the end, rather + // than earlier. No idea why. + ValueIteratorBase(); + explicit ValueIteratorBase(const Value::ObjectValues::iterator& current); +}; + +/** \brief const iterator for object and array value. + * + */ +class JSON_API ValueConstIterator : public ValueIteratorBase { + friend class Value; + +public: + typedef const Value value_type; + //typedef unsigned int size_t; + //typedef int difference_type; + typedef const Value& reference; + typedef const Value* pointer; + typedef ValueConstIterator SelfType; + + ValueConstIterator(); + +private: +/*! \internal Use by Value to create an iterator. + */ + explicit ValueConstIterator(const Value::ObjectValues::iterator& current); +public: + SelfType& operator=(const ValueIteratorBase& other); + + SelfType operator++(int) { + SelfType temp(*this); + ++*this; + return temp; + } + + SelfType operator--(int) { + SelfType temp(*this); + --*this; + return temp; + } + + SelfType& operator--() { + decrement(); + return *this; + } + + SelfType& operator++() { + increment(); + return *this; + } + + reference operator*() const { return deref(); } + + pointer operator->() const { return &deref(); } +}; + +/** \brief Iterator for object and array value. + */ +class JSON_API ValueIterator : public ValueIteratorBase { + friend class Value; + +public: + typedef Value value_type; + typedef unsigned int size_t; + typedef int difference_type; + typedef Value& reference; + typedef Value* pointer; + typedef ValueIterator SelfType; + + ValueIterator(); + ValueIterator(const ValueConstIterator& other); + ValueIterator(const ValueIterator& other); + +private: +/*! \internal Use by Value to create an iterator. + */ + explicit ValueIterator(const Value::ObjectValues::iterator& current); +public: + SelfType& operator=(const SelfType& other); + + SelfType operator++(int) { + SelfType temp(*this); + ++*this; + return temp; + } + + SelfType operator--(int) { + SelfType temp(*this); + --*this; + return temp; + } + + SelfType& operator--() { + decrement(); + return *this; + } + + SelfType& operator++() { + increment(); + return *this; + } + + reference operator*() const { return deref(); } + + pointer operator->() const { return &deref(); } +}; + +} // namespace Json + + +namespace std { +/// Specialize std::swap() for Json::Value. +template<> +inline void swap(Json::Value& a, Json::Value& b) { a.swap(b); } +} + + +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(pop) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +#endif // CPPTL_JSON_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/value.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/reader.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef CPPTL_JSON_READER_H_INCLUDED +#define CPPTL_JSON_READER_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "features.h" +#include "value.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include <deque> +#include <iosfwd> +#include <stack> +#include <string> +#include <istream> + +// Disable warning C4251: <data member>: <type> needs to have dll-interface to +// be used by... +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(push) +#pragma warning(disable : 4251) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +namespace Json { + +/** \brief Unserialize a <a HREF="http://www.json.org">JSON</a> document into a + *Value. + * + * \deprecated Use CharReader and CharReaderBuilder. + */ +class JSON_API Reader { +public: + typedef char Char; + typedef const Char* Location; + + /** \brief An error tagged with where in the JSON text it was encountered. + * + * The offsets give the [start, limit) range of bytes within the text. Note + * that this is bytes, not codepoints. + * + */ + struct StructuredError { + size_t offset_start; + size_t offset_limit; + std::string message; + }; + + /** \brief Constructs a Reader allowing all features + * for parsing. + */ + Reader(); + + /** \brief Constructs a Reader allowing the specified feature set + * for parsing. + */ + Reader(const Features& features); + + /** \brief Read a Value from a <a HREF="http://www.json.org">JSON</a> + * document. + * \param document UTF-8 encoded string containing the document to read. + * \param root [out] Contains the root value of the document if it was + * successfully parsed. + * \param collectComments \c true to collect comment and allow writing them + * back during + * serialization, \c false to discard comments. + * This parameter is ignored if + * Features::allowComments_ + * is \c false. + * \return \c true if the document was successfully parsed, \c false if an + * error occurred. + */ + bool + parse(const std::string& document, Value& root, bool collectComments = true); + + /** \brief Read a Value from a <a HREF="http://www.json.org">JSON</a> + document. + * \param beginDoc Pointer on the beginning of the UTF-8 encoded string of the + document to read. + * \param endDoc Pointer on the end of the UTF-8 encoded string of the + document to read. + * Must be >= beginDoc. + * \param root [out] Contains the root value of the document if it was + * successfully parsed. + * \param collectComments \c true to collect comment and allow writing them + back during + * serialization, \c false to discard comments. + * This parameter is ignored if + Features::allowComments_ + * is \c false. + * \return \c true if the document was successfully parsed, \c false if an + error occurred. + */ + bool parse(const char* beginDoc, + const char* endDoc, + Value& root, + bool collectComments = true); + + /// \brief Parse from input stream. + /// \see Json::operator>>(std::istream&, Json::Value&). + bool parse(std::istream& is, Value& root, bool collectComments = true); + + /** \brief Returns a user friendly string that list errors in the parsed + * document. + * \return Formatted error message with the list of errors with their location + * in + * the parsed document. An empty string is returned if no error + * occurred + * during parsing. + * \deprecated Use getFormattedErrorMessages() instead (typo fix). + */ + JSONCPP_DEPRECATED("Use getFormattedErrorMessages() instead.") + std::string getFormatedErrorMessages() const; + + /** \brief Returns a user friendly string that list errors in the parsed + * document. + * \return Formatted error message with the list of errors with their location + * in + * the parsed document. An empty string is returned if no error + * occurred + * during parsing. + */ + std::string getFormattedErrorMessages() const; + + /** \brief Returns a vector of structured erros encounted while parsing. + * \return A (possibly empty) vector of StructuredError objects. Currently + * only one error can be returned, but the caller should tolerate + * multiple + * errors. This can occur if the parser recovers from a non-fatal + * parse error and then encounters additional errors. + */ + std::vector<StructuredError> getStructuredErrors() const; + + /** \brief Add a semantic error message. + * \param value JSON Value location associated with the error + * \param message The error message. + * \return \c true if the error was successfully added, \c false if the + * Value offset exceeds the document size. + */ + bool pushError(const Value& value, const std::string& message); + + /** \brief Add a semantic error message with extra context. + * \param value JSON Value location associated with the error + * \param message The error message. + * \param extra Additional JSON Value location to contextualize the error + * \return \c true if the error was successfully added, \c false if either + * Value offset exceeds the document size. + */ + bool pushError(const Value& value, const std::string& message, const Value& extra); + + /** \brief Return whether there are any errors. + * \return \c true if there are no errors to report \c false if + * errors have occurred. + */ + bool good() const; + +private: + enum TokenType { + tokenEndOfStream = 0, + tokenObjectBegin, + tokenObjectEnd, + tokenArrayBegin, + tokenArrayEnd, + tokenString, + tokenNumber, + tokenTrue, + tokenFalse, + tokenNull, + tokenArraySeparator, + tokenMemberSeparator, + tokenComment, + tokenError + }; + + class Token { + public: + TokenType type_; + Location start_; + Location end_; + }; + + class ErrorInfo { + public: + Token token_; + std::string message_; + Location extra_; + }; + + typedef std::deque<ErrorInfo> Errors; + + bool readToken(Token& token); + void skipSpaces(); + bool match(Location pattern, int patternLength); + bool readComment(); + bool readCStyleComment(); + bool readCppStyleComment(); + bool readString(); + void readNumber(); + bool readValue(); + bool readObject(Token& token); + bool readArray(Token& token); + bool decodeNumber(Token& token); + bool decodeNumber(Token& token, Value& decoded); + bool decodeString(Token& token); + bool decodeString(Token& token, std::string& decoded); + bool decodeDouble(Token& token); + bool decodeDouble(Token& token, Value& decoded); + bool decodeUnicodeCodePoint(Token& token, + Location& current, + Location end, + unsigned int& unicode); + bool decodeUnicodeEscapeSequence(Token& token, + Location& current, + Location end, + unsigned int& unicode); + bool addError(const std::string& message, Token& token, Location extra = 0); + bool recoverFromError(TokenType skipUntilToken); + bool addErrorAndRecover(const std::string& message, + Token& token, + TokenType skipUntilToken); + void skipUntilSpace(); + Value& currentValue(); + Char getNextChar(); + void + getLocationLineAndColumn(Location location, int& line, int& column) const; + std::string getLocationLineAndColumn(Location location) const; + void addComment(Location begin, Location end, CommentPlacement placement); + void skipCommentTokens(Token& token); + + typedef std::stack<Value*> Nodes; + Nodes nodes_; + Errors errors_; + std::string document_; + Location begin_; + Location end_; + Location current_; + Location lastValueEnd_; + Value* lastValue_; + std::string commentsBefore_; + Features features_; + bool collectComments_; +}; // Reader + +/** Interface for reading JSON from a char array. + */ +class JSON_API CharReader { +public: + virtual ~CharReader() {} + /** \brief Read a Value from a <a HREF="http://www.json.org">JSON</a> + document. + * The document must be a UTF-8 encoded string containing the document to read. + * + * \param beginDoc Pointer on the beginning of the UTF-8 encoded string of the + document to read. + * \param endDoc Pointer on the end of the UTF-8 encoded string of the + document to read. + * Must be >= beginDoc. + * \param root [out] Contains the root value of the document if it was + * successfully parsed. + * \param errs [out] Formatted error messages (if not NULL) + * a user friendly string that lists errors in the parsed + * document. + * \return \c true if the document was successfully parsed, \c false if an + error occurred. + */ + virtual bool parse( + char const* beginDoc, char const* endDoc, + Value* root, std::string* errs) = 0; + + class Factory { + public: + virtual ~Factory() {} + /** \brief Allocate a CharReader via operator new(). + * \throw std::exception if something goes wrong (e.g. invalid settings) + */ + virtual CharReader* newCharReader() const = 0; + }; // Factory +}; // CharReader + +/** \brief Build a CharReader implementation. + +Usage: +\code + using namespace Json; + CharReaderBuilder builder; + builder["collectComments"] = false; + Value value; + std::string errs; + bool ok = parseFromStream(builder, std::cin, &value, &errs); +\endcode +*/ +class JSON_API CharReaderBuilder : public CharReader::Factory { +public: + // Note: We use a Json::Value so that we can add data-members to this class + // without a major version bump. + /** Configuration of this builder. + These are case-sensitive. + Available settings (case-sensitive): + - `"collectComments": false or true` + - true to collect comment and allow writing them + back during serialization, false to discard comments. + This parameter is ignored if allowComments is false. + - `"allowComments": false or true` + - true if comments are allowed. + - `"strictRoot": false or true` + - true if root must be either an array or an object value + - `"allowDroppedNullPlaceholders": false or true` + - true if dropped null placeholders are allowed. (See StreamWriterBuilder.) + - `"allowNumericKeys": false or true` + - true if numeric object keys are allowed. + - `"allowSingleQuotes": false or true` + - true if '' are allowed for strings (both keys and values) + - `"stackLimit": integer` + - Exceeding stackLimit (recursive depth of `readValue()`) will + cause an exception. + - This is a security issue (seg-faults caused by deeply nested JSON), + so the default is low. + - `"failIfExtra": false or true` + - If true, `parse()` returns false when extra non-whitespace trails + the JSON value in the input string. + - `"rejectDupKeys": false or true` + - If true, `parse()` returns false when a key is duplicated within an object. + + You can examine 'settings_` yourself + to see the defaults. You can also write and read them just like any + JSON Value. + \sa setDefaults() + */ + Json::Value settings_; + + CharReaderBuilder(); + virtual ~CharReaderBuilder(); + + virtual CharReader* newCharReader() const; + + /** \return true if 'settings' are legal and consistent; + * otherwise, indicate bad settings via 'invalid'. + */ + bool validate(Json::Value* invalid) const; + + /** A simple way to update a specific setting. + */ + Value& operator[](std::string key); + + /** Called by ctor, but you can use this to reset settings_. + * \pre 'settings' != NULL (but Json::null is fine) + * \remark Defaults: + * \snippet src/lib_json/json_reader.cpp CharReaderBuilderDefaults + */ + static void setDefaults(Json::Value* settings); + /** Same as old Features::strictMode(). + * \pre 'settings' != NULL (but Json::null is fine) + * \remark Defaults: + * \snippet src/lib_json/json_reader.cpp CharReaderBuilderStrictMode + */ + static void strictMode(Json::Value* settings); +}; + +/** Consume entire stream and use its begin/end. + * Someday we might have a real StreamReader, but for now this + * is convenient. + */ +bool JSON_API parseFromStream( + CharReader::Factory const&, + std::istream&, + Value* root, std::string* errs); + +/** \brief Read from 'sin' into 'root'. + + Always keep comments from the input JSON. + + This can be used to read a file into a particular sub-object. + For example: + \code + Json::Value root; + cin >> root["dir"]["file"]; + cout << root; + \endcode + Result: + \verbatim + { + "dir": { + "file": { + // The input stream JSON would be nested here. + } + } + } + \endverbatim + \throw std::exception on parse error. + \see Json::operator<<() +*/ +JSON_API std::istream& operator>>(std::istream&, Value&); + +} // namespace Json + +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(pop) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +#endif // CPPTL_JSON_READER_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/reader.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/writer.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef JSON_WRITER_H_INCLUDED +#define JSON_WRITER_H_INCLUDED + +#if !defined(JSON_IS_AMALGAMATION) +#include "value.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include <vector> +#include <string> +#include <ostream> + +// Disable warning C4251: <data member>: <type> needs to have dll-interface to +// be used by... +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(push) +#pragma warning(disable : 4251) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +namespace Json { + +class Value; + +/** + +Usage: +\code + using namespace Json; + void writeToStdout(StreamWriter::Factory const& factory, Value const& value) { + std::unique_ptr<StreamWriter> const writer( + factory.newStreamWriter()); + writer->write(value, &std::cout); + std::cout << std::endl; // add lf and flush + } +\endcode +*/ +class JSON_API StreamWriter { +protected: + std::ostream* sout_; // not owned; will not delete +public: + StreamWriter(); + virtual ~StreamWriter(); + /** Write Value into document as configured in sub-class. + Do not take ownership of sout, but maintain a reference during function. + \pre sout != NULL + \return zero on success (For now, we always return zero, so check the stream instead.) + \throw std::exception possibly, depending on configuration + */ + virtual int write(Value const& root, std::ostream* sout) = 0; + + /** \brief A simple abstract factory. + */ + class JSON_API Factory { + public: + virtual ~Factory(); + /** \brief Allocate a CharReader via operator new(). + * \throw std::exception if something goes wrong (e.g. invalid settings) + */ + virtual StreamWriter* newStreamWriter() const = 0; + }; // Factory +}; // StreamWriter + +/** \brief Write into stringstream, then return string, for convenience. + * A StreamWriter will be created from the factory, used, and then deleted. + */ +std::string JSON_API writeString(StreamWriter::Factory const& factory, Value const& root); + + +/** \brief Build a StreamWriter implementation. + +Usage: +\code + using namespace Json; + Value value = ...; + StreamWriterBuilder builder; + builder["commentStyle"] = "None"; + builder["indentation"] = " "; // or whatever you like + std::unique_ptr<Json::StreamWriter> writer( + builder.newStreamWriter()); + writer->write(value, &std::cout); + std::cout << std::endl; // add lf and flush +\endcode +*/ +class JSON_API StreamWriterBuilder : public StreamWriter::Factory { +public: + // Note: We use a Json::Value so that we can add data-members to this class + // without a major version bump. + /** Configuration of this builder. + Available settings (case-sensitive): + - "commentStyle": "None" or "All" + - "indentation": "<anything>" + - "enableYAMLCompatibility": false or true + - slightly change the whitespace around colons + - "dropNullPlaceholders": false or true + - Drop the "null" string from the writer's output for nullValues. + Strictly speaking, this is not valid JSON. But when the output is being + fed to a browser's Javascript, it makes for smaller output and the + browser can handle the output just fine. + + You can examine 'settings_` yourself + to see the defaults. You can also write and read them just like any + JSON Value. + \sa setDefaults() + */ + Json::Value settings_; + + StreamWriterBuilder(); + virtual ~StreamWriterBuilder(); + + /** + * \throw std::exception if something goes wrong (e.g. invalid settings) + */ + virtual StreamWriter* newStreamWriter() const; + + /** \return true if 'settings' are legal and consistent; + * otherwise, indicate bad settings via 'invalid'. + */ + bool validate(Json::Value* invalid) const; + /** A simple way to update a specific setting. + */ + Value& operator[](std::string key); + + /** Called by ctor, but you can use this to reset settings_. + * \pre 'settings' != NULL (but Json::null is fine) + * \remark Defaults: + * \snippet src/lib_json/json_writer.cpp StreamWriterBuilderDefaults + */ + static void setDefaults(Json::Value* settings); +}; + +/** \brief Abstract class for writers. + * \deprecated Use StreamWriter. (And really, this is an implementation detail.) + */ +class JSON_API Writer { +public: + virtual ~Writer(); + + virtual std::string write(const Value& root) = 0; +}; + +/** \brief Outputs a Value in <a HREF="http://www.json.org">JSON</a> format + *without formatting (not human friendly). + * + * The JSON document is written in a single line. It is not intended for 'human' + *consumption, + * but may be usefull to support feature such as RPC where bandwith is limited. + * \sa Reader, Value + * \deprecated Use StreamWriterBuilder. + */ +class JSON_API FastWriter : public Writer { + +public: + FastWriter(); + virtual ~FastWriter() {} + + void enableYAMLCompatibility(); + + /** \brief Drop the "null" string from the writer's output for nullValues. + * Strictly speaking, this is not valid JSON. But when the output is being + * fed to a browser's Javascript, it makes for smaller output and the + * browser can handle the output just fine. + */ + void dropNullPlaceholders(); + + void omitEndingLineFeed(); + +public: // overridden from Writer + virtual std::string write(const Value& root); + +private: + void writeValue(const Value& value); + + std::string document_; + bool yamlCompatiblityEnabled_; + bool dropNullPlaceholders_; + bool omitEndingLineFeed_; +}; + +/** \brief Writes a Value in <a HREF="http://www.json.org">JSON</a> format in a + *human friendly way. + * + * The rules for line break and indent are as follow: + * - Object value: + * - if empty then print {} without indent and line break + * - if not empty the print '{', line break & indent, print one value per + *line + * and then unindent and line break and print '}'. + * - Array value: + * - if empty then print [] without indent and line break + * - if the array contains no object value, empty array or some other value + *types, + * and all the values fit on one lines, then print the array on a single + *line. + * - otherwise, it the values do not fit on one line, or the array contains + * object or non empty array, then print one value per line. + * + * If the Value have comments then they are outputed according to their + *#CommentPlacement. + * + * \sa Reader, Value, Value::setComment() + * \deprecated Use StreamWriterBuilder. + */ +class JSON_API StyledWriter : public Writer { +public: + StyledWriter(); + virtual ~StyledWriter() {} + +public: // overridden from Writer + /** \brief Serialize a Value in <a HREF="http://www.json.org">JSON</a> format. + * \param root Value to serialize. + * \return String containing the JSON document that represents the root value. + */ + virtual std::string write(const Value& root); + +private: + void writeValue(const Value& value); + void writeArrayValue(const Value& value); + bool isMultineArray(const Value& value); + void pushValue(const std::string& value); + void writeIndent(); + void writeWithIndent(const std::string& value); + void indent(); + void unindent(); + void writeCommentBeforeValue(const Value& root); + void writeCommentAfterValueOnSameLine(const Value& root); + bool hasCommentForValue(const Value& value); + static std::string normalizeEOL(const std::string& text); + + typedef std::vector<std::string> ChildValues; + + ChildValues childValues_; + std::string document_; + std::string indentString_; + int rightMargin_; + int indentSize_; + bool addChildValues_; +}; + +/** \brief Writes a Value in <a HREF="http://www.json.org">JSON</a> format in a + human friendly way, + to a stream rather than to a string. + * + * The rules for line break and indent are as follow: + * - Object value: + * - if empty then print {} without indent and line break + * - if not empty the print '{', line break & indent, print one value per + line + * and then unindent and line break and print '}'. + * - Array value: + * - if empty then print [] without indent and line break + * - if the array contains no object value, empty array or some other value + types, + * and all the values fit on one lines, then print the array on a single + line. + * - otherwise, it the values do not fit on one line, or the array contains + * object or non empty array, then print one value per line. + * + * If the Value have comments then they are outputed according to their + #CommentPlacement. + * + * \param indentation Each level will be indented by this amount extra. + * \sa Reader, Value, Value::setComment() + * \deprecated Use StreamWriterBuilder. + */ +class JSON_API StyledStreamWriter { +public: + StyledStreamWriter(std::string indentation = "\t"); + ~StyledStreamWriter() {} + +public: + /** \brief Serialize a Value in <a HREF="http://www.json.org">JSON</a> format. + * \param out Stream to write to. (Can be ostringstream, e.g.) + * \param root Value to serialize. + * \note There is no point in deriving from Writer, since write() should not + * return a value. + */ + void write(std::ostream& out, const Value& root); + +private: + void writeValue(const Value& value); + void writeArrayValue(const Value& value); + bool isMultineArray(const Value& value); + void pushValue(const std::string& value); + void writeIndent(); + void writeWithIndent(const std::string& value); + void indent(); + void unindent(); + void writeCommentBeforeValue(const Value& root); + void writeCommentAfterValueOnSameLine(const Value& root); + bool hasCommentForValue(const Value& value); + static std::string normalizeEOL(const std::string& text); + + typedef std::vector<std::string> ChildValues; + + ChildValues childValues_; + std::ostream* document_; + std::string indentString_; + int rightMargin_; + std::string indentation_; + bool addChildValues_ : 1; + bool indented_ : 1; +}; + +#if defined(JSON_HAS_INT64) +std::string JSON_API valueToString(Int value); +std::string JSON_API valueToString(UInt value); +#endif // if defined(JSON_HAS_INT64) +std::string JSON_API valueToString(LargestInt value); +std::string JSON_API valueToString(LargestUInt value); +std::string JSON_API valueToString(double value); +std::string JSON_API valueToString(bool value); +std::string JSON_API valueToQuotedString(const char* value); + +/// \brief Output using the StyledStreamWriter. +/// \see Json::operator>>() +JSON_API std::ostream& operator<<(std::ostream&, const Value& root); + +} // namespace Json + +#if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) +#pragma warning(pop) +#endif // if defined(JSONCPP_DISABLE_DLL_INTERFACE_WARNING) + +#endif // JSON_WRITER_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/writer.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: include/json/assertions.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef CPPTL_JSON_ASSERTIONS_H_INCLUDED +#define CPPTL_JSON_ASSERTIONS_H_INCLUDED + +#include <stdlib.h> +#include <sstream> + +#if !defined(JSON_IS_AMALGAMATION) +#include "config.h" +#endif // if !defined(JSON_IS_AMALGAMATION) + +/** It should not be possible for a maliciously designed file to + * cause an abort() or seg-fault, so these macros are used only + * for pre-condition violations and internal logic errors. + */ +#if JSON_USE_EXCEPTION + +// @todo <= add detail about condition in exception +# define JSON_ASSERT(condition) \ + {if (!(condition)) {Json::throwLogicError( "assert json failed" );}} + +# define JSON_FAIL_MESSAGE(message) \ + { \ + std::ostringstream oss; oss << message; \ + Json::throwLogicError(oss.str()); \ + abort(); \ + } + +#else // JSON_USE_EXCEPTION + +# define JSON_ASSERT(condition) assert(condition) + +// The call to assert() will show the failure message in debug builds. In +// release builds we abort, for a core-dump or debugger. +# define JSON_FAIL_MESSAGE(message) \ + { \ + std::ostringstream oss; oss << message; \ + assert(false && oss.str().c_str()); \ + abort(); \ + } + + +#endif + +#define JSON_ASSERT_MESSAGE(condition, message) \ + if (!(condition)) { \ + JSON_FAIL_MESSAGE(message); \ + } + +#endif // CPPTL_JSON_ASSERTIONS_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: include/json/assertions.h +// ////////////////////////////////////////////////////////////////////// + + + + + +#endif //ifndef JSON_AMALGATED_H_INCLUDED diff --git a/external_libs/json/jsoncpp.cpp b/external_libs/json/jsoncpp.cpp new file mode 100644 index 00000000..e9f7e986 --- /dev/null +++ b/external_libs/json/jsoncpp.cpp @@ -0,0 +1,5100 @@ +/// Json-cpp amalgated source (http://jsoncpp.sourceforge.net/). +/// It is intended to be used with #include "json/json.h" + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + +/* +The JsonCpp library's source code, including accompanying documentation, +tests and demonstration applications, are licensed under the following +conditions... + +The author (Baptiste Lepilleur) explicitly disclaims copyright in all +jurisdictions which recognize such a disclaimer. In such jurisdictions, +this software is released into the Public Domain. + +In jurisdictions which do not recognize Public Domain property (e.g. Germany as of +2010), this software is Copyright (c) 2007-2010 by Baptiste Lepilleur, and is +released under the terms of the MIT License (see below). + +In jurisdictions which recognize Public Domain property, the user of this +software may choose to accept it either as 1) Public Domain, 2) under the +conditions of the MIT License (see below), or 3) under the terms of dual +Public Domain/MIT License conditions described here, as they choose. + +The MIT License is about as close to Public Domain as a license can get, and is +described in clear, concise terms at: + + http://en.wikipedia.org/wiki/MIT_License + +The full text of the MIT License follows: + +======================================================================== +Copyright (c) 2007-2010 Baptiste Lepilleur + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, copy, +modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +======================================================================== +(END LICENSE TEXT) + +The MIT license is compatible with both the GPL and commercial +software, affording one all of the rights of Public Domain with the +minor nuisance of being required to keep the above copyright notice +and license text in the source code. Note also that by accepting the +Public Domain "license" you can re-license your copy using whatever +license you like. + +*/ + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: LICENSE +// ////////////////////////////////////////////////////////////////////// + + + + + + +#include "json/json.h" + +#ifndef JSON_IS_AMALGAMATION +#error "Compile with -I PATH_TO_JSON_DIRECTORY" +#endif + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_tool.h +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#ifndef LIB_JSONCPP_JSON_TOOL_H_INCLUDED +#define LIB_JSONCPP_JSON_TOOL_H_INCLUDED + +/* This header provides common string manipulation support, such as UTF-8, + * portable conversion from/to string... + * + * It is an internal header that must not be exposed. + */ + +namespace Json { + +/// Converts a unicode code-point to UTF-8. +static inline std::string codePointToUTF8(unsigned int cp) { + std::string result; + + // based on description from http://en.wikipedia.org/wiki/UTF-8 + + if (cp <= 0x7f) { + result.resize(1); + result[0] = static_cast<char>(cp); + } else if (cp <= 0x7FF) { + result.resize(2); + result[1] = static_cast<char>(0x80 | (0x3f & cp)); + result[0] = static_cast<char>(0xC0 | (0x1f & (cp >> 6))); + } else if (cp <= 0xFFFF) { + result.resize(3); + result[2] = static_cast<char>(0x80 | (0x3f & cp)); + result[1] = 0x80 | static_cast<char>((0x3f & (cp >> 6))); + result[0] = 0xE0 | static_cast<char>((0xf & (cp >> 12))); + } else if (cp <= 0x10FFFF) { + result.resize(4); + result[3] = static_cast<char>(0x80 | (0x3f & cp)); + result[2] = static_cast<char>(0x80 | (0x3f & (cp >> 6))); + result[1] = static_cast<char>(0x80 | (0x3f & (cp >> 12))); + result[0] = static_cast<char>(0xF0 | (0x7 & (cp >> 18))); + } + + return result; +} + +/// Returns true if ch is a control character (in range [0,32[). +static inline bool isControlCharacter(char ch) { return ch > 0 && ch <= 0x1F; } + +enum { + /// Constant that specify the size of the buffer that must be passed to + /// uintToString. + uintToStringBufferSize = 3 * sizeof(LargestUInt) + 1 +}; + +// Defines a char buffer for use with uintToString(). +typedef char UIntToStringBuffer[uintToStringBufferSize]; + +/** Converts an unsigned integer to string. + * @param value Unsigned interger to convert to string + * @param current Input/Output string buffer. + * Must have at least uintToStringBufferSize chars free. + */ +static inline void uintToString(LargestUInt value, char*& current) { + *--current = 0; + do { + *--current = char(value % 10) + '0'; + value /= 10; + } while (value != 0); +} + +/** Change ',' to '.' everywhere in buffer. + * + * We had a sophisticated way, but it did not work in WinCE. + * @see https://github.com/open-source-parsers/jsoncpp/pull/9 + */ +static inline void fixNumericLocale(char* begin, char* end) { + while (begin < end) { + if (*begin == ',') { + *begin = '.'; + } + ++begin; + } +} + +} // namespace Json { + +#endif // LIB_JSONCPP_JSON_TOOL_H_INCLUDED + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_tool.h +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_reader.cpp +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2011 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#if !defined(JSON_IS_AMALGAMATION) +#include <json/assertions.h> +#include <json/reader.h> +#include <json/value.h> +#include "json_tool.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include <utility> +#include <cstdio> +#include <cassert> +#include <cstring> +#include <istream> +#include <sstream> +#include <memory> +#include <set> + +#if defined(_MSC_VER) && _MSC_VER < 1500 // VC++ 8.0 and below +#define snprintf _snprintf +#endif + +#if defined(_MSC_VER) && _MSC_VER >= 1400 // VC++ 8.0 +// Disable warning about strdup being deprecated. +#pragma warning(disable : 4996) +#endif + +static int const stackLimit_g = 1000; +static int stackDepth_g = 0; // see readValue() + +namespace Json { + +#if __cplusplus >= 201103L +typedef std::unique_ptr<CharReader> CharReaderPtr; +#else +typedef std::auto_ptr<CharReader> CharReaderPtr; +#endif + +// Implementation of class Features +// //////////////////////////////// + +Features::Features() + : allowComments_(true), strictRoot_(false), + allowDroppedNullPlaceholders_(false), allowNumericKeys_(false) {} + +Features Features::all() { return Features(); } + +Features Features::strictMode() { + Features features; + features.allowComments_ = false; + features.strictRoot_ = true; + features.allowDroppedNullPlaceholders_ = false; + features.allowNumericKeys_ = false; + return features; +} + +// Implementation of class Reader +// //////////////////////////////// + +static bool containsNewLine(Reader::Location begin, Reader::Location end) { + for (; begin < end; ++begin) + if (*begin == '\n' || *begin == '\r') + return true; + return false; +} + +// Class Reader +// ////////////////////////////////////////////////////////////////// + +Reader::Reader() + : errors_(), document_(), begin_(), end_(), current_(), lastValueEnd_(), + lastValue_(), commentsBefore_(), features_(Features::all()), + collectComments_() {} + +Reader::Reader(const Features& features) + : errors_(), document_(), begin_(), end_(), current_(), lastValueEnd_(), + lastValue_(), commentsBefore_(), features_(features), collectComments_() { +} + +bool +Reader::parse(const std::string& document, Value& root, bool collectComments) { + document_ = document; + const char* begin = document_.c_str(); + const char* end = begin + document_.length(); + return parse(begin, end, root, collectComments); +} + +bool Reader::parse(std::istream& sin, Value& root, bool collectComments) { + // std::istream_iterator<char> begin(sin); + // std::istream_iterator<char> end; + // Those would allow streamed input from a file, if parse() were a + // template function. + + // Since std::string is reference-counted, this at least does not + // create an extra copy. + std::string doc; + std::getline(sin, doc, (char)EOF); + return parse(doc, root, collectComments); +} + +bool Reader::parse(const char* beginDoc, + const char* endDoc, + Value& root, + bool collectComments) { + if (!features_.allowComments_) { + collectComments = false; + } + + begin_ = beginDoc; + end_ = endDoc; + collectComments_ = collectComments; + current_ = begin_; + lastValueEnd_ = 0; + lastValue_ = 0; + commentsBefore_ = ""; + errors_.clear(); + while (!nodes_.empty()) + nodes_.pop(); + nodes_.push(&root); + + stackDepth_g = 0; // Yes, this is bad coding, but options are limited. + bool successful = readValue(); + Token token; + skipCommentTokens(token); + if (collectComments_ && !commentsBefore_.empty()) + root.setComment(commentsBefore_, commentAfter); + if (features_.strictRoot_) { + if (!root.isArray() && !root.isObject()) { + // Set error location to start of doc, ideally should be first token found + // in doc + token.type_ = tokenError; + token.start_ = beginDoc; + token.end_ = endDoc; + addError( + "A valid JSON document must be either an array or an object value.", + token); + return false; + } + } + return successful; +} + +bool Reader::readValue() { + // This is a non-reentrant way to support a stackLimit. Terrible! + // But this deprecated class has a security problem: Bad input can + // cause a seg-fault. This seems like a fair, binary-compatible way + // to prevent the problem. + if (stackDepth_g >= stackLimit_g) throwRuntimeError("Exceeded stackLimit in readValue()."); + ++stackDepth_g; + + Token token; + skipCommentTokens(token); + bool successful = true; + + if (collectComments_ && !commentsBefore_.empty()) { + currentValue().setComment(commentsBefore_, commentBefore); + commentsBefore_ = ""; + } + + switch (token.type_) { + case tokenObjectBegin: + successful = readObject(token); + currentValue().setOffsetLimit(current_ - begin_); + break; + case tokenArrayBegin: + successful = readArray(token); + currentValue().setOffsetLimit(current_ - begin_); + break; + case tokenNumber: + successful = decodeNumber(token); + break; + case tokenString: + successful = decodeString(token); + break; + case tokenTrue: + { + Value v(true); + currentValue().swapPayload(v); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + } + break; + case tokenFalse: + { + Value v(false); + currentValue().swapPayload(v); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + } + break; + case tokenNull: + { + Value v; + currentValue().swapPayload(v); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + } + break; + case tokenArraySeparator: + case tokenObjectEnd: + case tokenArrayEnd: + if (features_.allowDroppedNullPlaceholders_) { + // "Un-read" the current token and mark the current value as a null + // token. + current_--; + Value v; + currentValue().swapPayload(v); + currentValue().setOffsetStart(current_ - begin_ - 1); + currentValue().setOffsetLimit(current_ - begin_); + break; + } // Else, fall through... + default: + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + return addError("Syntax error: value, object or array expected.", token); + } + + if (collectComments_) { + lastValueEnd_ = current_; + lastValue_ = ¤tValue(); + } + + --stackDepth_g; + return successful; +} + +void Reader::skipCommentTokens(Token& token) { + if (features_.allowComments_) { + do { + readToken(token); + } while (token.type_ == tokenComment); + } else { + readToken(token); + } +} + +bool Reader::readToken(Token& token) { + skipSpaces(); + token.start_ = current_; + Char c = getNextChar(); + bool ok = true; + switch (c) { + case '{': + token.type_ = tokenObjectBegin; + break; + case '}': + token.type_ = tokenObjectEnd; + break; + case '[': + token.type_ = tokenArrayBegin; + break; + case ']': + token.type_ = tokenArrayEnd; + break; + case '"': + token.type_ = tokenString; + ok = readString(); + break; + case '/': + token.type_ = tokenComment; + ok = readComment(); + break; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '-': + token.type_ = tokenNumber; + readNumber(); + break; + case 't': + token.type_ = tokenTrue; + ok = match("rue", 3); + break; + case 'f': + token.type_ = tokenFalse; + ok = match("alse", 4); + break; + case 'n': + token.type_ = tokenNull; + ok = match("ull", 3); + break; + case ',': + token.type_ = tokenArraySeparator; + break; + case ':': + token.type_ = tokenMemberSeparator; + break; + case 0: + token.type_ = tokenEndOfStream; + break; + default: + ok = false; + break; + } + if (!ok) + token.type_ = tokenError; + token.end_ = current_; + return true; +} + +void Reader::skipSpaces() { + while (current_ != end_) { + Char c = *current_; + if (c == ' ' || c == '\t' || c == '\r' || c == '\n') + ++current_; + else + break; + } +} + +bool Reader::match(Location pattern, int patternLength) { + if (end_ - current_ < patternLength) + return false; + int index = patternLength; + while (index--) + if (current_[index] != pattern[index]) + return false; + current_ += patternLength; + return true; +} + +bool Reader::readComment() { + Location commentBegin = current_ - 1; + Char c = getNextChar(); + bool successful = false; + if (c == '*') + successful = readCStyleComment(); + else if (c == '/') + successful = readCppStyleComment(); + if (!successful) + return false; + + if (collectComments_) { + CommentPlacement placement = commentBefore; + if (lastValueEnd_ && !containsNewLine(lastValueEnd_, commentBegin)) { + if (c != '*' || !containsNewLine(commentBegin, current_)) + placement = commentAfterOnSameLine; + } + + addComment(commentBegin, current_, placement); + } + return true; +} + +static std::string normalizeEOL(Reader::Location begin, Reader::Location end) { + std::string normalized; + normalized.reserve(end - begin); + Reader::Location current = begin; + while (current != end) { + char c = *current++; + if (c == '\r') { + if (current != end && *current == '\n') + // convert dos EOL + ++current; + // convert Mac EOL + normalized += '\n'; + } else { + normalized += c; + } + } + return normalized; +} + +void +Reader::addComment(Location begin, Location end, CommentPlacement placement) { + assert(collectComments_); + const std::string& normalized = normalizeEOL(begin, end); + if (placement == commentAfterOnSameLine) { + assert(lastValue_ != 0); + lastValue_->setComment(normalized, placement); + } else { + commentsBefore_ += normalized; + } +} + +bool Reader::readCStyleComment() { + while (current_ != end_) { + Char c = getNextChar(); + if (c == '*' && *current_ == '/') + break; + } + return getNextChar() == '/'; +} + +bool Reader::readCppStyleComment() { + while (current_ != end_) { + Char c = getNextChar(); + if (c == '\n') + break; + if (c == '\r') { + // Consume DOS EOL. It will be normalized in addComment. + if (current_ != end_ && *current_ == '\n') + getNextChar(); + // Break on Moc OS 9 EOL. + break; + } + } + return true; +} + +void Reader::readNumber() { + const char *p = current_; + char c = '0'; // stopgap for already consumed character + // integral part + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + // fractional part + if (c == '.') { + c = (current_ = p) < end_ ? *p++ : 0; + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + } + // exponential part + if (c == 'e' || c == 'E') { + c = (current_ = p) < end_ ? *p++ : 0; + if (c == '+' || c == '-') + c = (current_ = p) < end_ ? *p++ : 0; + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + } +} + +bool Reader::readString() { + Char c = 0; + while (current_ != end_) { + c = getNextChar(); + if (c == '\\') + getNextChar(); + else if (c == '"') + break; + } + return c == '"'; +} + +bool Reader::readObject(Token& tokenStart) { + Token tokenName; + std::string name; + Value init(objectValue); + currentValue().swapPayload(init); + currentValue().setOffsetStart(tokenStart.start_ - begin_); + while (readToken(tokenName)) { + bool initialTokenOk = true; + while (tokenName.type_ == tokenComment && initialTokenOk) + initialTokenOk = readToken(tokenName); + if (!initialTokenOk) + break; + if (tokenName.type_ == tokenObjectEnd && name.empty()) // empty object + return true; + name = ""; + if (tokenName.type_ == tokenString) { + if (!decodeString(tokenName, name)) + return recoverFromError(tokenObjectEnd); + } else if (tokenName.type_ == tokenNumber && features_.allowNumericKeys_) { + Value numberName; + if (!decodeNumber(tokenName, numberName)) + return recoverFromError(tokenObjectEnd); + name = numberName.asString(); + } else { + break; + } + + Token colon; + if (!readToken(colon) || colon.type_ != tokenMemberSeparator) { + return addErrorAndRecover( + "Missing ':' after object member name", colon, tokenObjectEnd); + } + Value& value = currentValue()[name]; + nodes_.push(&value); + bool ok = readValue(); + nodes_.pop(); + if (!ok) // error already set + return recoverFromError(tokenObjectEnd); + + Token comma; + if (!readToken(comma) || + (comma.type_ != tokenObjectEnd && comma.type_ != tokenArraySeparator && + comma.type_ != tokenComment)) { + return addErrorAndRecover( + "Missing ',' or '}' in object declaration", comma, tokenObjectEnd); + } + bool finalizeTokenOk = true; + while (comma.type_ == tokenComment && finalizeTokenOk) + finalizeTokenOk = readToken(comma); + if (comma.type_ == tokenObjectEnd) + return true; + } + return addErrorAndRecover( + "Missing '}' or object member name", tokenName, tokenObjectEnd); +} + +bool Reader::readArray(Token& tokenStart) { + Value init(arrayValue); + currentValue().swapPayload(init); + currentValue().setOffsetStart(tokenStart.start_ - begin_); + skipSpaces(); + if (*current_ == ']') // empty array + { + Token endArray; + readToken(endArray); + return true; + } + int index = 0; + for (;;) { + Value& value = currentValue()[index++]; + nodes_.push(&value); + bool ok = readValue(); + nodes_.pop(); + if (!ok) // error already set + return recoverFromError(tokenArrayEnd); + + Token token; + // Accept Comment after last item in the array. + ok = readToken(token); + while (token.type_ == tokenComment && ok) { + ok = readToken(token); + } + bool badTokenType = + (token.type_ != tokenArraySeparator && token.type_ != tokenArrayEnd); + if (!ok || badTokenType) { + return addErrorAndRecover( + "Missing ',' or ']' in array declaration", token, tokenArrayEnd); + } + if (token.type_ == tokenArrayEnd) + break; + } + return true; +} + +bool Reader::decodeNumber(Token& token) { + Value decoded; + if (!decodeNumber(token, decoded)) + return false; + currentValue().swapPayload(decoded); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + return true; +} + +bool Reader::decodeNumber(Token& token, Value& decoded) { + // Attempts to parse the number as an integer. If the number is + // larger than the maximum supported value of an integer then + // we decode the number as a double. + Location current = token.start_; + bool isNegative = *current == '-'; + if (isNegative) + ++current; + // TODO: Help the compiler do the div and mod at compile time or get rid of them. + Value::LargestUInt maxIntegerValue = + isNegative ? Value::LargestUInt(-Value::minLargestInt) + : Value::maxLargestUInt; + Value::LargestUInt threshold = maxIntegerValue / 10; + Value::LargestUInt value = 0; + while (current < token.end_) { + Char c = *current++; + if (c < '0' || c > '9') + return decodeDouble(token, decoded); + Value::UInt digit(c - '0'); + if (value >= threshold) { + // We've hit or exceeded the max value divided by 10 (rounded down). If + // a) we've only just touched the limit, b) this is the last digit, and + // c) it's small enough to fit in that rounding delta, we're okay. + // Otherwise treat this number as a double to avoid overflow. + if (value > threshold || current != token.end_ || + digit > maxIntegerValue % 10) { + return decodeDouble(token, decoded); + } + } + value = value * 10 + digit; + } + if (isNegative) + decoded = -Value::LargestInt(value); + else if (value <= Value::LargestUInt(Value::maxInt)) + decoded = Value::LargestInt(value); + else + decoded = value; + return true; +} + +bool Reader::decodeDouble(Token& token) { + Value decoded; + if (!decodeDouble(token, decoded)) + return false; + currentValue().swapPayload(decoded); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + return true; +} + +bool Reader::decodeDouble(Token& token, Value& decoded) { + double value = 0; + std::string buffer(token.start_, token.end_); + std::istringstream is(buffer); + if (!(is >> value)) + return addError("'" + std::string(token.start_, token.end_) + + "' is not a number.", + token); + decoded = value; + return true; +} + +bool Reader::decodeString(Token& token) { + std::string decoded_string; + if (!decodeString(token, decoded_string)) + return false; + Value decoded(decoded_string); + currentValue().swapPayload(decoded); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + return true; +} + +bool Reader::decodeString(Token& token, std::string& decoded) { + decoded.reserve(token.end_ - token.start_ - 2); + Location current = token.start_ + 1; // skip '"' + Location end = token.end_ - 1; // do not include '"' + while (current != end) { + Char c = *current++; + if (c == '"') + break; + else if (c == '\\') { + if (current == end) + return addError("Empty escape sequence in string", token, current); + Char escape = *current++; + switch (escape) { + case '"': + decoded += '"'; + break; + case '/': + decoded += '/'; + break; + case '\\': + decoded += '\\'; + break; + case 'b': + decoded += '\b'; + break; + case 'f': + decoded += '\f'; + break; + case 'n': + decoded += '\n'; + break; + case 'r': + decoded += '\r'; + break; + case 't': + decoded += '\t'; + break; + case 'u': { + unsigned int unicode; + if (!decodeUnicodeCodePoint(token, current, end, unicode)) + return false; + decoded += codePointToUTF8(unicode); + } break; + default: + return addError("Bad escape sequence in string", token, current); + } + } else { + decoded += c; + } + } + return true; +} + +bool Reader::decodeUnicodeCodePoint(Token& token, + Location& current, + Location end, + unsigned int& unicode) { + + if (!decodeUnicodeEscapeSequence(token, current, end, unicode)) + return false; + if (unicode >= 0xD800 && unicode <= 0xDBFF) { + // surrogate pairs + if (end - current < 6) + return addError( + "additional six characters expected to parse unicode surrogate pair.", + token, + current); + unsigned int surrogatePair; + if (*(current++) == '\\' && *(current++) == 'u') { + if (decodeUnicodeEscapeSequence(token, current, end, surrogatePair)) { + unicode = 0x10000 + ((unicode & 0x3FF) << 10) + (surrogatePair & 0x3FF); + } else + return false; + } else + return addError("expecting another \\u token to begin the second half of " + "a unicode surrogate pair", + token, + current); + } + return true; +} + +bool Reader::decodeUnicodeEscapeSequence(Token& token, + Location& current, + Location end, + unsigned int& unicode) { + if (end - current < 4) + return addError( + "Bad unicode escape sequence in string: four digits expected.", + token, + current); + unicode = 0; + for (int index = 0; index < 4; ++index) { + Char c = *current++; + unicode *= 16; + if (c >= '0' && c <= '9') + unicode += c - '0'; + else if (c >= 'a' && c <= 'f') + unicode += c - 'a' + 10; + else if (c >= 'A' && c <= 'F') + unicode += c - 'A' + 10; + else + return addError( + "Bad unicode escape sequence in string: hexadecimal digit expected.", + token, + current); + } + return true; +} + +bool +Reader::addError(const std::string& message, Token& token, Location extra) { + ErrorInfo info; + info.token_ = token; + info.message_ = message; + info.extra_ = extra; + errors_.push_back(info); + return false; +} + +bool Reader::recoverFromError(TokenType skipUntilToken) { + int errorCount = int(errors_.size()); + Token skip; + for (;;) { + if (!readToken(skip)) + errors_.resize(errorCount); // discard errors caused by recovery + if (skip.type_ == skipUntilToken || skip.type_ == tokenEndOfStream) + break; + } + errors_.resize(errorCount); + return false; +} + +bool Reader::addErrorAndRecover(const std::string& message, + Token& token, + TokenType skipUntilToken) { + addError(message, token); + return recoverFromError(skipUntilToken); +} + +Value& Reader::currentValue() { return *(nodes_.top()); } + +Reader::Char Reader::getNextChar() { + if (current_ == end_) + return 0; + return *current_++; +} + +void Reader::getLocationLineAndColumn(Location location, + int& line, + int& column) const { + Location current = begin_; + Location lastLineStart = current; + line = 0; + while (current < location && current != end_) { + Char c = *current++; + if (c == '\r') { + if (*current == '\n') + ++current; + lastLineStart = current; + ++line; + } else if (c == '\n') { + lastLineStart = current; + ++line; + } + } + // column & line start at 1 + column = int(location - lastLineStart) + 1; + ++line; +} + +std::string Reader::getLocationLineAndColumn(Location location) const { + int line, column; + getLocationLineAndColumn(location, line, column); + char buffer[18 + 16 + 16 + 1]; +#if defined(_MSC_VER) && defined(__STDC_SECURE_LIB__) +#if defined(WINCE) + _snprintf(buffer, sizeof(buffer), "Line %d, Column %d", line, column); +#else + sprintf_s(buffer, sizeof(buffer), "Line %d, Column %d", line, column); +#endif +#else + snprintf(buffer, sizeof(buffer), "Line %d, Column %d", line, column); +#endif + return buffer; +} + +// Deprecated. Preserved for backward compatibility +std::string Reader::getFormatedErrorMessages() const { + return getFormattedErrorMessages(); +} + +std::string Reader::getFormattedErrorMessages() const { + std::string formattedMessage; + for (Errors::const_iterator itError = errors_.begin(); + itError != errors_.end(); + ++itError) { + const ErrorInfo& error = *itError; + formattedMessage += + "* " + getLocationLineAndColumn(error.token_.start_) + "\n"; + formattedMessage += " " + error.message_ + "\n"; + if (error.extra_) + formattedMessage += + "See " + getLocationLineAndColumn(error.extra_) + " for detail.\n"; + } + return formattedMessage; +} + +std::vector<Reader::StructuredError> Reader::getStructuredErrors() const { + std::vector<Reader::StructuredError> allErrors; + for (Errors::const_iterator itError = errors_.begin(); + itError != errors_.end(); + ++itError) { + const ErrorInfo& error = *itError; + Reader::StructuredError structured; + structured.offset_start = error.token_.start_ - begin_; + structured.offset_limit = error.token_.end_ - begin_; + structured.message = error.message_; + allErrors.push_back(structured); + } + return allErrors; +} + +bool Reader::pushError(const Value& value, const std::string& message) { + size_t length = end_ - begin_; + if(value.getOffsetStart() > length + || value.getOffsetLimit() > length) + return false; + Token token; + token.type_ = tokenError; + token.start_ = begin_ + value.getOffsetStart(); + token.end_ = end_ + value.getOffsetLimit(); + ErrorInfo info; + info.token_ = token; + info.message_ = message; + info.extra_ = 0; + errors_.push_back(info); + return true; +} + +bool Reader::pushError(const Value& value, const std::string& message, const Value& extra) { + size_t length = end_ - begin_; + if(value.getOffsetStart() > length + || value.getOffsetLimit() > length + || extra.getOffsetLimit() > length) + return false; + Token token; + token.type_ = tokenError; + token.start_ = begin_ + value.getOffsetStart(); + token.end_ = begin_ + value.getOffsetLimit(); + ErrorInfo info; + info.token_ = token; + info.message_ = message; + info.extra_ = begin_ + extra.getOffsetStart(); + errors_.push_back(info); + return true; +} + +bool Reader::good() const { + return !errors_.size(); +} + +// exact copy of Features +class OurFeatures { +public: + static OurFeatures all(); + OurFeatures(); + bool allowComments_; + bool strictRoot_; + bool allowDroppedNullPlaceholders_; + bool allowNumericKeys_; + bool allowSingleQuotes_; + bool failIfExtra_; + bool rejectDupKeys_; + int stackLimit_; +}; // OurFeatures + +// exact copy of Implementation of class Features +// //////////////////////////////// + +OurFeatures::OurFeatures() + : allowComments_(true), strictRoot_(false) + , allowDroppedNullPlaceholders_(false), allowNumericKeys_(false) + , allowSingleQuotes_(false) + , failIfExtra_(false) +{ +} + +OurFeatures OurFeatures::all() { return OurFeatures(); } + +// Implementation of class Reader +// //////////////////////////////// + +// exact copy of Reader, renamed to OurReader +class OurReader { +public: + typedef char Char; + typedef const Char* Location; + struct StructuredError { + size_t offset_start; + size_t offset_limit; + std::string message; + }; + + OurReader(OurFeatures const& features); + bool parse(const char* beginDoc, + const char* endDoc, + Value& root, + bool collectComments = true); + std::string getFormattedErrorMessages() const; + std::vector<StructuredError> getStructuredErrors() const; + bool pushError(const Value& value, const std::string& message); + bool pushError(const Value& value, const std::string& message, const Value& extra); + bool good() const; + +private: + OurReader(OurReader const&); // no impl + void operator=(OurReader const&); // no impl + + enum TokenType { + tokenEndOfStream = 0, + tokenObjectBegin, + tokenObjectEnd, + tokenArrayBegin, + tokenArrayEnd, + tokenString, + tokenNumber, + tokenTrue, + tokenFalse, + tokenNull, + tokenArraySeparator, + tokenMemberSeparator, + tokenComment, + tokenError + }; + + class Token { + public: + TokenType type_; + Location start_; + Location end_; + }; + + class ErrorInfo { + public: + Token token_; + std::string message_; + Location extra_; + }; + + typedef std::deque<ErrorInfo> Errors; + + bool readToken(Token& token); + void skipSpaces(); + bool match(Location pattern, int patternLength); + bool readComment(); + bool readCStyleComment(); + bool readCppStyleComment(); + bool readString(); + bool readStringSingleQuote(); + void readNumber(); + bool readValue(); + bool readObject(Token& token); + bool readArray(Token& token); + bool decodeNumber(Token& token); + bool decodeNumber(Token& token, Value& decoded); + bool decodeString(Token& token); + bool decodeString(Token& token, std::string& decoded); + bool decodeDouble(Token& token); + bool decodeDouble(Token& token, Value& decoded); + bool decodeUnicodeCodePoint(Token& token, + Location& current, + Location end, + unsigned int& unicode); + bool decodeUnicodeEscapeSequence(Token& token, + Location& current, + Location end, + unsigned int& unicode); + bool addError(const std::string& message, Token& token, Location extra = 0); + bool recoverFromError(TokenType skipUntilToken); + bool addErrorAndRecover(const std::string& message, + Token& token, + TokenType skipUntilToken); + void skipUntilSpace(); + Value& currentValue(); + Char getNextChar(); + void + getLocationLineAndColumn(Location location, int& line, int& column) const; + std::string getLocationLineAndColumn(Location location) const; + void addComment(Location begin, Location end, CommentPlacement placement); + void skipCommentTokens(Token& token); + + typedef std::stack<Value*> Nodes; + Nodes nodes_; + Errors errors_; + std::string document_; + Location begin_; + Location end_; + Location current_; + Location lastValueEnd_; + Value* lastValue_; + std::string commentsBefore_; + int stackDepth_; + + OurFeatures const features_; + bool collectComments_; +}; // OurReader + +// complete copy of Read impl, for OurReader + +OurReader::OurReader(OurFeatures const& features) + : errors_(), document_(), begin_(), end_(), current_(), lastValueEnd_(), + lastValue_(), commentsBefore_(), features_(features), collectComments_() { +} + +bool OurReader::parse(const char* beginDoc, + const char* endDoc, + Value& root, + bool collectComments) { + if (!features_.allowComments_) { + collectComments = false; + } + + begin_ = beginDoc; + end_ = endDoc; + collectComments_ = collectComments; + current_ = begin_; + lastValueEnd_ = 0; + lastValue_ = 0; + commentsBefore_ = ""; + errors_.clear(); + while (!nodes_.empty()) + nodes_.pop(); + nodes_.push(&root); + + stackDepth_ = 0; + bool successful = readValue(); + Token token; + skipCommentTokens(token); + if (features_.failIfExtra_) { + if (token.type_ != tokenError && token.type_ != tokenEndOfStream) { + addError("Extra non-whitespace after JSON value.", token); + return false; + } + } + if (collectComments_ && !commentsBefore_.empty()) + root.setComment(commentsBefore_, commentAfter); + if (features_.strictRoot_) { + if (!root.isArray() && !root.isObject()) { + // Set error location to start of doc, ideally should be first token found + // in doc + token.type_ = tokenError; + token.start_ = beginDoc; + token.end_ = endDoc; + addError( + "A valid JSON document must be either an array or an object value.", + token); + return false; + } + } + return successful; +} + +bool OurReader::readValue() { + if (stackDepth_ >= features_.stackLimit_) throwRuntimeError("Exceeded stackLimit in readValue()."); + ++stackDepth_; + Token token; + skipCommentTokens(token); + bool successful = true; + + if (collectComments_ && !commentsBefore_.empty()) { + currentValue().setComment(commentsBefore_, commentBefore); + commentsBefore_ = ""; + } + + switch (token.type_) { + case tokenObjectBegin: + successful = readObject(token); + currentValue().setOffsetLimit(current_ - begin_); + break; + case tokenArrayBegin: + successful = readArray(token); + currentValue().setOffsetLimit(current_ - begin_); + break; + case tokenNumber: + successful = decodeNumber(token); + break; + case tokenString: + successful = decodeString(token); + break; + case tokenTrue: + { + Value v(true); + currentValue().swapPayload(v); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + } + break; + case tokenFalse: + { + Value v(false); + currentValue().swapPayload(v); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + } + break; + case tokenNull: + { + Value v; + currentValue().swapPayload(v); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + } + break; + case tokenArraySeparator: + case tokenObjectEnd: + case tokenArrayEnd: + if (features_.allowDroppedNullPlaceholders_) { + // "Un-read" the current token and mark the current value as a null + // token. + current_--; + Value v; + currentValue().swapPayload(v); + currentValue().setOffsetStart(current_ - begin_ - 1); + currentValue().setOffsetLimit(current_ - begin_); + break; + } // else, fall through ... + default: + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + return addError("Syntax error: value, object or array expected.", token); + } + + if (collectComments_) { + lastValueEnd_ = current_; + lastValue_ = ¤tValue(); + } + + --stackDepth_; + return successful; +} + +void OurReader::skipCommentTokens(Token& token) { + if (features_.allowComments_) { + do { + readToken(token); + } while (token.type_ == tokenComment); + } else { + readToken(token); + } +} + +bool OurReader::readToken(Token& token) { + skipSpaces(); + token.start_ = current_; + Char c = getNextChar(); + bool ok = true; + switch (c) { + case '{': + token.type_ = tokenObjectBegin; + break; + case '}': + token.type_ = tokenObjectEnd; + break; + case '[': + token.type_ = tokenArrayBegin; + break; + case ']': + token.type_ = tokenArrayEnd; + break; + case '"': + token.type_ = tokenString; + ok = readString(); + break; + case '\'': + if (features_.allowSingleQuotes_) { + token.type_ = tokenString; + ok = readStringSingleQuote(); + break; + } // else continue + case '/': + token.type_ = tokenComment; + ok = readComment(); + break; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '-': + token.type_ = tokenNumber; + readNumber(); + break; + case 't': + token.type_ = tokenTrue; + ok = match("rue", 3); + break; + case 'f': + token.type_ = tokenFalse; + ok = match("alse", 4); + break; + case 'n': + token.type_ = tokenNull; + ok = match("ull", 3); + break; + case ',': + token.type_ = tokenArraySeparator; + break; + case ':': + token.type_ = tokenMemberSeparator; + break; + case 0: + token.type_ = tokenEndOfStream; + break; + default: + ok = false; + break; + } + if (!ok) + token.type_ = tokenError; + token.end_ = current_; + return true; +} + +void OurReader::skipSpaces() { + while (current_ != end_) { + Char c = *current_; + if (c == ' ' || c == '\t' || c == '\r' || c == '\n') + ++current_; + else + break; + } +} + +bool OurReader::match(Location pattern, int patternLength) { + if (end_ - current_ < patternLength) + return false; + int index = patternLength; + while (index--) + if (current_[index] != pattern[index]) + return false; + current_ += patternLength; + return true; +} + +bool OurReader::readComment() { + Location commentBegin = current_ - 1; + Char c = getNextChar(); + bool successful = false; + if (c == '*') + successful = readCStyleComment(); + else if (c == '/') + successful = readCppStyleComment(); + if (!successful) + return false; + + if (collectComments_) { + CommentPlacement placement = commentBefore; + if (lastValueEnd_ && !containsNewLine(lastValueEnd_, commentBegin)) { + if (c != '*' || !containsNewLine(commentBegin, current_)) + placement = commentAfterOnSameLine; + } + + addComment(commentBegin, current_, placement); + } + return true; +} + +void +OurReader::addComment(Location begin, Location end, CommentPlacement placement) { + assert(collectComments_); + const std::string& normalized = normalizeEOL(begin, end); + if (placement == commentAfterOnSameLine) { + assert(lastValue_ != 0); + lastValue_->setComment(normalized, placement); + } else { + commentsBefore_ += normalized; + } +} + +bool OurReader::readCStyleComment() { + while (current_ != end_) { + Char c = getNextChar(); + if (c == '*' && *current_ == '/') + break; + } + return getNextChar() == '/'; +} + +bool OurReader::readCppStyleComment() { + while (current_ != end_) { + Char c = getNextChar(); + if (c == '\n') + break; + if (c == '\r') { + // Consume DOS EOL. It will be normalized in addComment. + if (current_ != end_ && *current_ == '\n') + getNextChar(); + // Break on Moc OS 9 EOL. + break; + } + } + return true; +} + +void OurReader::readNumber() { + const char *p = current_; + char c = '0'; // stopgap for already consumed character + // integral part + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + // fractional part + if (c == '.') { + c = (current_ = p) < end_ ? *p++ : 0; + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + } + // exponential part + if (c == 'e' || c == 'E') { + c = (current_ = p) < end_ ? *p++ : 0; + if (c == '+' || c == '-') + c = (current_ = p) < end_ ? *p++ : 0; + while (c >= '0' && c <= '9') + c = (current_ = p) < end_ ? *p++ : 0; + } +} +bool OurReader::readString() { + Char c = 0; + while (current_ != end_) { + c = getNextChar(); + if (c == '\\') + getNextChar(); + else if (c == '"') + break; + } + return c == '"'; +} + + +bool OurReader::readStringSingleQuote() { + Char c = 0; + while (current_ != end_) { + c = getNextChar(); + if (c == '\\') + getNextChar(); + else if (c == '\'') + break; + } + return c == '\''; +} + +bool OurReader::readObject(Token& tokenStart) { + Token tokenName; + std::string name; + Value init(objectValue); + currentValue().swapPayload(init); + currentValue().setOffsetStart(tokenStart.start_ - begin_); + while (readToken(tokenName)) { + bool initialTokenOk = true; + while (tokenName.type_ == tokenComment && initialTokenOk) + initialTokenOk = readToken(tokenName); + if (!initialTokenOk) + break; + if (tokenName.type_ == tokenObjectEnd && name.empty()) // empty object + return true; + name = ""; + if (tokenName.type_ == tokenString) { + if (!decodeString(tokenName, name)) + return recoverFromError(tokenObjectEnd); + } else if (tokenName.type_ == tokenNumber && features_.allowNumericKeys_) { + Value numberName; + if (!decodeNumber(tokenName, numberName)) + return recoverFromError(tokenObjectEnd); + name = numberName.asString(); + } else { + break; + } + + Token colon; + if (!readToken(colon) || colon.type_ != tokenMemberSeparator) { + return addErrorAndRecover( + "Missing ':' after object member name", colon, tokenObjectEnd); + } + if (name.length() >= (1U<<30)) throwRuntimeError("keylength >= 2^30"); + if (features_.rejectDupKeys_ && currentValue().isMember(name)) { + std::string msg = "Duplicate key: '" + name + "'"; + return addErrorAndRecover( + msg, tokenName, tokenObjectEnd); + } + Value& value = currentValue()[name]; + nodes_.push(&value); + bool ok = readValue(); + nodes_.pop(); + if (!ok) // error already set + return recoverFromError(tokenObjectEnd); + + Token comma; + if (!readToken(comma) || + (comma.type_ != tokenObjectEnd && comma.type_ != tokenArraySeparator && + comma.type_ != tokenComment)) { + return addErrorAndRecover( + "Missing ',' or '}' in object declaration", comma, tokenObjectEnd); + } + bool finalizeTokenOk = true; + while (comma.type_ == tokenComment && finalizeTokenOk) + finalizeTokenOk = readToken(comma); + if (comma.type_ == tokenObjectEnd) + return true; + } + return addErrorAndRecover( + "Missing '}' or object member name", tokenName, tokenObjectEnd); +} + +bool OurReader::readArray(Token& tokenStart) { + Value init(arrayValue); + currentValue().swapPayload(init); + currentValue().setOffsetStart(tokenStart.start_ - begin_); + skipSpaces(); + if (*current_ == ']') // empty array + { + Token endArray; + readToken(endArray); + return true; + } + int index = 0; + for (;;) { + Value& value = currentValue()[index++]; + nodes_.push(&value); + bool ok = readValue(); + nodes_.pop(); + if (!ok) // error already set + return recoverFromError(tokenArrayEnd); + + Token token; + // Accept Comment after last item in the array. + ok = readToken(token); + while (token.type_ == tokenComment && ok) { + ok = readToken(token); + } + bool badTokenType = + (token.type_ != tokenArraySeparator && token.type_ != tokenArrayEnd); + if (!ok || badTokenType) { + return addErrorAndRecover( + "Missing ',' or ']' in array declaration", token, tokenArrayEnd); + } + if (token.type_ == tokenArrayEnd) + break; + } + return true; +} + +bool OurReader::decodeNumber(Token& token) { + Value decoded; + if (!decodeNumber(token, decoded)) + return false; + currentValue().swapPayload(decoded); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + return true; +} + +bool OurReader::decodeNumber(Token& token, Value& decoded) { + // Attempts to parse the number as an integer. If the number is + // larger than the maximum supported value of an integer then + // we decode the number as a double. + Location current = token.start_; + bool isNegative = *current == '-'; + if (isNegative) + ++current; + // TODO: Help the compiler do the div and mod at compile time or get rid of them. + Value::LargestUInt maxIntegerValue = + isNegative ? Value::LargestUInt(-Value::minLargestInt) + : Value::maxLargestUInt; + Value::LargestUInt threshold = maxIntegerValue / 10; + Value::LargestUInt value = 0; + while (current < token.end_) { + Char c = *current++; + if (c < '0' || c > '9') + return decodeDouble(token, decoded); + Value::UInt digit(c - '0'); + if (value >= threshold) { + // We've hit or exceeded the max value divided by 10 (rounded down). If + // a) we've only just touched the limit, b) this is the last digit, and + // c) it's small enough to fit in that rounding delta, we're okay. + // Otherwise treat this number as a double to avoid overflow. + if (value > threshold || current != token.end_ || + digit > maxIntegerValue % 10) { + return decodeDouble(token, decoded); + } + } + value = value * 10 + digit; + } + if (isNegative) + decoded = -Value::LargestInt(value); + else if (value <= Value::LargestUInt(Value::maxInt)) + decoded = Value::LargestInt(value); + else + decoded = value; + return true; +} + +bool OurReader::decodeDouble(Token& token) { + Value decoded; + if (!decodeDouble(token, decoded)) + return false; + currentValue().swapPayload(decoded); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + return true; +} + +bool OurReader::decodeDouble(Token& token, Value& decoded) { + double value = 0; + const int bufferSize = 32; + int count; + int length = int(token.end_ - token.start_); + + // Sanity check to avoid buffer overflow exploits. + if (length < 0) { + return addError("Unable to parse token length", token); + } + + // Avoid using a string constant for the format control string given to + // sscanf, as this can cause hard to debug crashes on OS X. See here for more + // info: + // + // http://developer.apple.com/library/mac/#DOCUMENTATION/DeveloperTools/gcc-4.0.1/gcc/Incompatibilities.html + char format[] = "%lf"; + + if (length <= bufferSize) { + Char buffer[bufferSize + 1]; + memcpy(buffer, token.start_, length); + buffer[length] = 0; + count = sscanf(buffer, format, &value); + } else { + std::string buffer(token.start_, token.end_); + count = sscanf(buffer.c_str(), format, &value); + } + + if (count != 1) + return addError("'" + std::string(token.start_, token.end_) + + "' is not a number.", + token); + decoded = value; + return true; +} + +bool OurReader::decodeString(Token& token) { + std::string decoded_string; + if (!decodeString(token, decoded_string)) + return false; + Value decoded(decoded_string); + currentValue().swapPayload(decoded); + currentValue().setOffsetStart(token.start_ - begin_); + currentValue().setOffsetLimit(token.end_ - begin_); + return true; +} + +bool OurReader::decodeString(Token& token, std::string& decoded) { + decoded.reserve(token.end_ - token.start_ - 2); + Location current = token.start_ + 1; // skip '"' + Location end = token.end_ - 1; // do not include '"' + while (current != end) { + Char c = *current++; + if (c == '"') + break; + else if (c == '\\') { + if (current == end) + return addError("Empty escape sequence in string", token, current); + Char escape = *current++; + switch (escape) { + case '"': + decoded += '"'; + break; + case '/': + decoded += '/'; + break; + case '\\': + decoded += '\\'; + break; + case 'b': + decoded += '\b'; + break; + case 'f': + decoded += '\f'; + break; + case 'n': + decoded += '\n'; + break; + case 'r': + decoded += '\r'; + break; + case 't': + decoded += '\t'; + break; + case 'u': { + unsigned int unicode; + if (!decodeUnicodeCodePoint(token, current, end, unicode)) + return false; + decoded += codePointToUTF8(unicode); + } break; + default: + return addError("Bad escape sequence in string", token, current); + } + } else { + decoded += c; + } + } + return true; +} + +bool OurReader::decodeUnicodeCodePoint(Token& token, + Location& current, + Location end, + unsigned int& unicode) { + + if (!decodeUnicodeEscapeSequence(token, current, end, unicode)) + return false; + if (unicode >= 0xD800 && unicode <= 0xDBFF) { + // surrogate pairs + if (end - current < 6) + return addError( + "additional six characters expected to parse unicode surrogate pair.", + token, + current); + unsigned int surrogatePair; + if (*(current++) == '\\' && *(current++) == 'u') { + if (decodeUnicodeEscapeSequence(token, current, end, surrogatePair)) { + unicode = 0x10000 + ((unicode & 0x3FF) << 10) + (surrogatePair & 0x3FF); + } else + return false; + } else + return addError("expecting another \\u token to begin the second half of " + "a unicode surrogate pair", + token, + current); + } + return true; +} + +bool OurReader::decodeUnicodeEscapeSequence(Token& token, + Location& current, + Location end, + unsigned int& unicode) { + if (end - current < 4) + return addError( + "Bad unicode escape sequence in string: four digits expected.", + token, + current); + unicode = 0; + for (int index = 0; index < 4; ++index) { + Char c = *current++; + unicode *= 16; + if (c >= '0' && c <= '9') + unicode += c - '0'; + else if (c >= 'a' && c <= 'f') + unicode += c - 'a' + 10; + else if (c >= 'A' && c <= 'F') + unicode += c - 'A' + 10; + else + return addError( + "Bad unicode escape sequence in string: hexadecimal digit expected.", + token, + current); + } + return true; +} + +bool +OurReader::addError(const std::string& message, Token& token, Location extra) { + ErrorInfo info; + info.token_ = token; + info.message_ = message; + info.extra_ = extra; + errors_.push_back(info); + return false; +} + +bool OurReader::recoverFromError(TokenType skipUntilToken) { + int errorCount = int(errors_.size()); + Token skip; + for (;;) { + if (!readToken(skip)) + errors_.resize(errorCount); // discard errors caused by recovery + if (skip.type_ == skipUntilToken || skip.type_ == tokenEndOfStream) + break; + } + errors_.resize(errorCount); + return false; +} + +bool OurReader::addErrorAndRecover(const std::string& message, + Token& token, + TokenType skipUntilToken) { + addError(message, token); + return recoverFromError(skipUntilToken); +} + +Value& OurReader::currentValue() { return *(nodes_.top()); } + +OurReader::Char OurReader::getNextChar() { + if (current_ == end_) + return 0; + return *current_++; +} + +void OurReader::getLocationLineAndColumn(Location location, + int& line, + int& column) const { + Location current = begin_; + Location lastLineStart = current; + line = 0; + while (current < location && current != end_) { + Char c = *current++; + if (c == '\r') { + if (*current == '\n') + ++current; + lastLineStart = current; + ++line; + } else if (c == '\n') { + lastLineStart = current; + ++line; + } + } + // column & line start at 1 + column = int(location - lastLineStart) + 1; + ++line; +} + +std::string OurReader::getLocationLineAndColumn(Location location) const { + int line, column; + getLocationLineAndColumn(location, line, column); + char buffer[18 + 16 + 16 + 1]; +#if defined(_MSC_VER) && defined(__STDC_SECURE_LIB__) +#if defined(WINCE) + _snprintf(buffer, sizeof(buffer), "Line %d, Column %d", line, column); +#else + sprintf_s(buffer, sizeof(buffer), "Line %d, Column %d", line, column); +#endif +#else + snprintf(buffer, sizeof(buffer), "Line %d, Column %d", line, column); +#endif + return buffer; +} + +std::string OurReader::getFormattedErrorMessages() const { + std::string formattedMessage; + for (Errors::const_iterator itError = errors_.begin(); + itError != errors_.end(); + ++itError) { + const ErrorInfo& error = *itError; + formattedMessage += + "* " + getLocationLineAndColumn(error.token_.start_) + "\n"; + formattedMessage += " " + error.message_ + "\n"; + if (error.extra_) + formattedMessage += + "See " + getLocationLineAndColumn(error.extra_) + " for detail.\n"; + } + return formattedMessage; +} + +std::vector<OurReader::StructuredError> OurReader::getStructuredErrors() const { + std::vector<OurReader::StructuredError> allErrors; + for (Errors::const_iterator itError = errors_.begin(); + itError != errors_.end(); + ++itError) { + const ErrorInfo& error = *itError; + OurReader::StructuredError structured; + structured.offset_start = error.token_.start_ - begin_; + structured.offset_limit = error.token_.end_ - begin_; + structured.message = error.message_; + allErrors.push_back(structured); + } + return allErrors; +} + +bool OurReader::pushError(const Value& value, const std::string& message) { + size_t length = end_ - begin_; + if(value.getOffsetStart() > length + || value.getOffsetLimit() > length) + return false; + Token token; + token.type_ = tokenError; + token.start_ = begin_ + value.getOffsetStart(); + token.end_ = end_ + value.getOffsetLimit(); + ErrorInfo info; + info.token_ = token; + info.message_ = message; + info.extra_ = 0; + errors_.push_back(info); + return true; +} + +bool OurReader::pushError(const Value& value, const std::string& message, const Value& extra) { + size_t length = end_ - begin_; + if(value.getOffsetStart() > length + || value.getOffsetLimit() > length + || extra.getOffsetLimit() > length) + return false; + Token token; + token.type_ = tokenError; + token.start_ = begin_ + value.getOffsetStart(); + token.end_ = begin_ + value.getOffsetLimit(); + ErrorInfo info; + info.token_ = token; + info.message_ = message; + info.extra_ = begin_ + extra.getOffsetStart(); + errors_.push_back(info); + return true; +} + +bool OurReader::good() const { + return !errors_.size(); +} + + +class OurCharReader : public CharReader { + bool const collectComments_; + OurReader reader_; +public: + OurCharReader( + bool collectComments, + OurFeatures const& features) + : collectComments_(collectComments) + , reader_(features) + {} + virtual bool parse( + char const* beginDoc, char const* endDoc, + Value* root, std::string* errs) { + bool ok = reader_.parse(beginDoc, endDoc, *root, collectComments_); + if (errs) { + *errs = reader_.getFormattedErrorMessages(); + } + return ok; + } +}; + +CharReaderBuilder::CharReaderBuilder() +{ + setDefaults(&settings_); +} +CharReaderBuilder::~CharReaderBuilder() +{} +CharReader* CharReaderBuilder::newCharReader() const +{ + bool collectComments = settings_["collectComments"].asBool(); + OurFeatures features = OurFeatures::all(); + features.allowComments_ = settings_["allowComments"].asBool(); + features.strictRoot_ = settings_["strictRoot"].asBool(); + features.allowDroppedNullPlaceholders_ = settings_["allowDroppedNullPlaceholders"].asBool(); + features.allowNumericKeys_ = settings_["allowNumericKeys"].asBool(); + features.allowSingleQuotes_ = settings_["allowSingleQuotes"].asBool(); + features.stackLimit_ = settings_["stackLimit"].asInt(); + features.failIfExtra_ = settings_["failIfExtra"].asBool(); + features.rejectDupKeys_ = settings_["rejectDupKeys"].asBool(); + return new OurCharReader(collectComments, features); +} +static void getValidReaderKeys(std::set<std::string>* valid_keys) +{ + valid_keys->clear(); + valid_keys->insert("collectComments"); + valid_keys->insert("allowComments"); + valid_keys->insert("strictRoot"); + valid_keys->insert("allowDroppedNullPlaceholders"); + valid_keys->insert("allowNumericKeys"); + valid_keys->insert("allowSingleQuotes"); + valid_keys->insert("stackLimit"); + valid_keys->insert("failIfExtra"); + valid_keys->insert("rejectDupKeys"); +} +bool CharReaderBuilder::validate(Json::Value* invalid) const +{ + Json::Value my_invalid; + if (!invalid) invalid = &my_invalid; // so we do not need to test for NULL + Json::Value& inv = *invalid; + std::set<std::string> valid_keys; + getValidReaderKeys(&valid_keys); + Value::Members keys = settings_.getMemberNames(); + size_t n = keys.size(); + for (size_t i = 0; i < n; ++i) { + std::string const& key = keys[i]; + if (valid_keys.find(key) == valid_keys.end()) { + inv[key] = settings_[key]; + } + } + return 0u == inv.size(); +} +Value& CharReaderBuilder::operator[](std::string key) +{ + return settings_[key]; +} +// static +void CharReaderBuilder::strictMode(Json::Value* settings) +{ +//! [CharReaderBuilderStrictMode] + (*settings)["allowComments"] = false; + (*settings)["strictRoot"] = true; + (*settings)["allowDroppedNullPlaceholders"] = false; + (*settings)["allowNumericKeys"] = false; + (*settings)["allowSingleQuotes"] = false; + (*settings)["failIfExtra"] = true; + (*settings)["rejectDupKeys"] = true; +//! [CharReaderBuilderStrictMode] +} +// static +void CharReaderBuilder::setDefaults(Json::Value* settings) +{ +//! [CharReaderBuilderDefaults] + (*settings)["collectComments"] = true; + (*settings)["allowComments"] = true; + (*settings)["strictRoot"] = false; + (*settings)["allowDroppedNullPlaceholders"] = false; + (*settings)["allowNumericKeys"] = false; + (*settings)["allowSingleQuotes"] = false; + (*settings)["stackLimit"] = 1000; + (*settings)["failIfExtra"] = false; + (*settings)["rejectDupKeys"] = false; +//! [CharReaderBuilderDefaults] +} + +////////////////////////////////// +// global functions + +bool parseFromStream( + CharReader::Factory const& fact, std::istream& sin, + Value* root, std::string* errs) +{ + std::ostringstream ssin; + ssin << sin.rdbuf(); + std::string doc = ssin.str(); + char const* begin = doc.data(); + char const* end = begin + doc.size(); + // Note that we do not actually need a null-terminator. + CharReaderPtr const reader(fact.newCharReader()); + return reader->parse(begin, end, root, errs); +} + +std::istream& operator>>(std::istream& sin, Value& root) { + CharReaderBuilder b; + std::string errs; + bool ok = parseFromStream(b, sin, &root, &errs); + if (!ok) { + fprintf(stderr, + "Error from reader: %s", + errs.c_str()); + + throwRuntimeError("reader error"); + } + return sin; +} + +} // namespace Json + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_reader.cpp +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_valueiterator.inl +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2007-2010 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +// included by json_value.cpp + +namespace Json { + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class ValueIteratorBase +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +ValueIteratorBase::ValueIteratorBase() + : current_(), isNull_(true) { +} + +ValueIteratorBase::ValueIteratorBase( + const Value::ObjectValues::iterator& current) + : current_(current), isNull_(false) {} + +Value& ValueIteratorBase::deref() const { + return current_->second; +} + +void ValueIteratorBase::increment() { + ++current_; +} + +void ValueIteratorBase::decrement() { + --current_; +} + +ValueIteratorBase::difference_type +ValueIteratorBase::computeDistance(const SelfType& other) const { +#ifdef JSON_USE_CPPTL_SMALLMAP + return other.current_ - current_; +#else + // Iterator for null value are initialized using the default + // constructor, which initialize current_ to the default + // std::map::iterator. As begin() and end() are two instance + // of the default std::map::iterator, they can not be compared. + // To allow this, we handle this comparison specifically. + if (isNull_ && other.isNull_) { + return 0; + } + + // Usage of std::distance is not portable (does not compile with Sun Studio 12 + // RogueWave STL, + // which is the one used by default). + // Using a portable hand-made version for non random iterator instead: + // return difference_type( std::distance( current_, other.current_ ) ); + difference_type myDistance = 0; + for (Value::ObjectValues::iterator it = current_; it != other.current_; + ++it) { + ++myDistance; + } + return myDistance; +#endif +} + +bool ValueIteratorBase::isEqual(const SelfType& other) const { + if (isNull_) { + return other.isNull_; + } + return current_ == other.current_; +} + +void ValueIteratorBase::copy(const SelfType& other) { + current_ = other.current_; + isNull_ = other.isNull_; +} + +Value ValueIteratorBase::key() const { + const Value::CZString czstring = (*current_).first; + if (czstring.data()) { + if (czstring.isStaticString()) + return Value(StaticString(czstring.data())); + return Value(czstring.data(), czstring.data() + czstring.length()); + } + return Value(czstring.index()); +} + +UInt ValueIteratorBase::index() const { + const Value::CZString czstring = (*current_).first; + if (!czstring.data()) + return czstring.index(); + return Value::UInt(-1); +} + +std::string ValueIteratorBase::name() const { + char const* key; + char const* end; + key = memberName(&end); + if (!key) return std::string(); + return std::string(key, end); +} + +char const* ValueIteratorBase::memberName() const { + const char* name = (*current_).first.data(); + return name ? name : ""; +} + +char const* ValueIteratorBase::memberName(char const** end) const { + const char* name = (*current_).first.data(); + if (!name) { + *end = NULL; + return NULL; + } + *end = name + (*current_).first.length(); + return name; +} + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class ValueConstIterator +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +ValueConstIterator::ValueConstIterator() {} + +ValueConstIterator::ValueConstIterator( + const Value::ObjectValues::iterator& current) + : ValueIteratorBase(current) {} + +ValueConstIterator& ValueConstIterator:: +operator=(const ValueIteratorBase& other) { + copy(other); + return *this; +} + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class ValueIterator +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +ValueIterator::ValueIterator() {} + +ValueIterator::ValueIterator(const Value::ObjectValues::iterator& current) + : ValueIteratorBase(current) {} + +ValueIterator::ValueIterator(const ValueConstIterator& other) + : ValueIteratorBase(other) {} + +ValueIterator::ValueIterator(const ValueIterator& other) + : ValueIteratorBase(other) {} + +ValueIterator& ValueIterator::operator=(const SelfType& other) { + copy(other); + return *this; +} + +} // namespace Json + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_valueiterator.inl +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_value.cpp +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2011 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#if !defined(JSON_IS_AMALGAMATION) +#include <json/assertions.h> +#include <json/value.h> +#include <json/writer.h> +#endif // if !defined(JSON_IS_AMALGAMATION) +#include <math.h> +#include <sstream> +#include <utility> +#include <cstring> +#include <cassert> +#ifdef JSON_USE_CPPTL +#include <cpptl/conststring.h> +#endif +#include <cstddef> // size_t +#include <algorithm> // min() + +#define JSON_ASSERT_UNREACHABLE assert(false) + +namespace Json { + +// This is a walkaround to avoid the static initialization of Value::null. +// kNull must be word-aligned to avoid crashing on ARM. We use an alignment of +// 8 (instead of 4) as a bit of future-proofing. +#if defined(__ARMEL__) +#define ALIGNAS(byte_alignment) __attribute__((aligned(byte_alignment))) +#else +#define ALIGNAS(byte_alignment) +#endif +static const unsigned char ALIGNAS(8) kNull[sizeof(Value)] = { 0 }; +const unsigned char& kNullRef = kNull[0]; +const Value& Value::null = reinterpret_cast<const Value&>(kNullRef); +const Value& Value::nullRef = null; + +const Int Value::minInt = Int(~(UInt(-1) / 2)); +const Int Value::maxInt = Int(UInt(-1) / 2); +const UInt Value::maxUInt = UInt(-1); +#if defined(JSON_HAS_INT64) +const Int64 Value::minInt64 = Int64(~(UInt64(-1) / 2)); +const Int64 Value::maxInt64 = Int64(UInt64(-1) / 2); +const UInt64 Value::maxUInt64 = UInt64(-1); +// The constant is hard-coded because some compiler have trouble +// converting Value::maxUInt64 to a double correctly (AIX/xlC). +// Assumes that UInt64 is a 64 bits integer. +static const double maxUInt64AsDouble = 18446744073709551615.0; +#endif // defined(JSON_HAS_INT64) +const LargestInt Value::minLargestInt = LargestInt(~(LargestUInt(-1) / 2)); +const LargestInt Value::maxLargestInt = LargestInt(LargestUInt(-1) / 2); +const LargestUInt Value::maxLargestUInt = LargestUInt(-1); + +#if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) +template <typename T, typename U> +static inline bool InRange(double d, T min, U max) { + return d >= min && d <= max; +} +#else // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) +static inline double integerToDouble(Json::UInt64 value) { + return static_cast<double>(Int64(value / 2)) * 2.0 + Int64(value & 1); +} + +template <typename T> static inline double integerToDouble(T value) { + return static_cast<double>(value); +} + +template <typename T, typename U> +static inline bool InRange(double d, T min, U max) { + return d >= integerToDouble(min) && d <= integerToDouble(max); +} +#endif // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + +/** Duplicates the specified string value. + * @param value Pointer to the string to duplicate. Must be zero-terminated if + * length is "unknown". + * @param length Length of the value. if equals to unknown, then it will be + * computed using strlen(value). + * @return Pointer on the duplicate instance of string. + */ +static inline char* duplicateStringValue(const char* value, + size_t length) { + // Avoid an integer overflow in the call to malloc below by limiting length + // to a sane value. + if (length >= (size_t)Value::maxInt) + length = Value::maxInt - 1; + + char* newString = static_cast<char*>(malloc(length + 1)); + if (newString == NULL) { + throwRuntimeError( + "in Json::Value::duplicateStringValue(): " + "Failed to allocate string value buffer"); + } + memcpy(newString, value, length); + newString[length] = 0; + return newString; +} + +/* Record the length as a prefix. + */ +static inline char* duplicateAndPrefixStringValue( + const char* value, + unsigned int length) +{ + // Avoid an integer overflow in the call to malloc below by limiting length + // to a sane value. + JSON_ASSERT_MESSAGE(length <= (unsigned)Value::maxInt - sizeof(unsigned) - 1U, + "in Json::Value::duplicateAndPrefixStringValue(): " + "length too big for prefixing"); + unsigned actualLength = length + sizeof(unsigned) + 1U; + char* newString = static_cast<char*>(malloc(actualLength)); + if (newString == 0) { + throwRuntimeError( + "in Json::Value::duplicateAndPrefixStringValue(): " + "Failed to allocate string value buffer"); + } + *reinterpret_cast<unsigned*>(newString) = length; + memcpy(newString + sizeof(unsigned), value, length); + newString[actualLength - 1U] = 0; // to avoid buffer over-run accidents by users later + return newString; +} +inline static void decodePrefixedString( + bool isPrefixed, char const* prefixed, + unsigned* length, char const** value) +{ + if (!isPrefixed) { + *length = static_cast<unsigned>(strlen(prefixed)); + *value = prefixed; + } else { + *length = *reinterpret_cast<unsigned const*>(prefixed); + *value = prefixed + sizeof(unsigned); + } +} +/** Free the string duplicated by duplicateStringValue()/duplicateAndPrefixStringValue(). + */ +static inline void releaseStringValue(char* value) { free(value); } + +} // namespace Json + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ValueInternals... +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +#if !defined(JSON_IS_AMALGAMATION) + +#include "json_valueiterator.inl" +#endif // if !defined(JSON_IS_AMALGAMATION) + +namespace Json { + +class JSON_API Exception : public std::exception { +public: + Exception(std::string const& msg); + virtual ~Exception() throw(); + virtual char const* what() const throw(); +protected: + std::string const msg_; +}; +class JSON_API RuntimeError : public Exception { +public: + RuntimeError(std::string const& msg); +}; +class JSON_API LogicError : public Exception { +public: + LogicError(std::string const& msg); +}; + +Exception::Exception(std::string const& msg) + : msg_(msg) +{} +Exception::~Exception() throw() +{} +char const* Exception::what() const throw() +{ + return msg_.c_str(); +} +RuntimeError::RuntimeError(std::string const& msg) + : Exception(msg) +{} +LogicError::LogicError(std::string const& msg) + : Exception(msg) +{} +void throwRuntimeError(std::string const& msg) +{ + throw RuntimeError(msg); +} +void throwLogicError(std::string const& msg) +{ + throw LogicError(msg); +} + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class Value::CommentInfo +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +Value::CommentInfo::CommentInfo() : comment_(0) {} + +Value::CommentInfo::~CommentInfo() { + if (comment_) + releaseStringValue(comment_); +} + +void Value::CommentInfo::setComment(const char* text, size_t len) { + if (comment_) { + releaseStringValue(comment_); + comment_ = 0; + } + JSON_ASSERT(text != 0); + JSON_ASSERT_MESSAGE( + text[0] == '\0' || text[0] == '/', + "in Json::Value::setComment(): Comments must start with /"); + // It seems that /**/ style comments are acceptable as well. + comment_ = duplicateStringValue(text, len); +} + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class Value::CZString +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +// Notes: policy_ indicates if the string was allocated when +// a string is stored. + +Value::CZString::CZString(ArrayIndex index) : cstr_(0), index_(index) {} + +Value::CZString::CZString(char const* str, unsigned length, DuplicationPolicy allocate) + : cstr_(str) +{ + // allocate != duplicate + storage_.policy_ = allocate; + storage_.length_ = length; +} + +Value::CZString::CZString(const CZString& other) + : cstr_(other.storage_.policy_ != noDuplication && other.cstr_ != 0 + ? duplicateStringValue(other.cstr_, other.storage_.length_) + : other.cstr_) +{ + storage_.policy_ = (other.cstr_ + ? (other.storage_.policy_ == noDuplication + ? noDuplication : duplicate) + : other.storage_.policy_); + storage_.length_ = other.storage_.length_; +} + +Value::CZString::~CZString() { + if (cstr_ && storage_.policy_ == duplicate) + releaseStringValue(const_cast<char*>(cstr_)); +} + +void Value::CZString::swap(CZString& other) { + std::swap(cstr_, other.cstr_); + std::swap(index_, other.index_); +} + +Value::CZString& Value::CZString::operator=(CZString other) { + swap(other); + return *this; +} + +bool Value::CZString::operator<(const CZString& other) const { + if (!cstr_) return index_ < other.index_; + //return strcmp(cstr_, other.cstr_) < 0; + // Assume both are strings. + unsigned this_len = this->storage_.length_; + unsigned other_len = other.storage_.length_; + unsigned min_len = std::min(this_len, other_len); + int comp = memcmp(this->cstr_, other.cstr_, min_len); + if (comp < 0) return true; + if (comp > 0) return false; + return (this_len < other_len); +} + +bool Value::CZString::operator==(const CZString& other) const { + if (!cstr_) return index_ == other.index_; + //return strcmp(cstr_, other.cstr_) == 0; + // Assume both are strings. + unsigned this_len = this->storage_.length_; + unsigned other_len = other.storage_.length_; + if (this_len != other_len) return false; + int comp = memcmp(this->cstr_, other.cstr_, this_len); + return comp == 0; +} + +ArrayIndex Value::CZString::index() const { return index_; } + +//const char* Value::CZString::c_str() const { return cstr_; } +const char* Value::CZString::data() const { return cstr_; } +unsigned Value::CZString::length() const { return storage_.length_; } +bool Value::CZString::isStaticString() const { return storage_.policy_ == noDuplication; } + +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// class Value::Value +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// +// ////////////////////////////////////////////////////////////////// + +/*! \internal Default constructor initialization must be equivalent to: + * memset( this, 0, sizeof(Value) ) + * This optimization is used in ValueInternalMap fast allocator. + */ +Value::Value(ValueType type) { + initBasic(type); + switch (type) { + case nullValue: + break; + case intValue: + case uintValue: + value_.int_ = 0; + break; + case realValue: + value_.real_ = 0.0; + break; + case stringValue: + value_.string_ = 0; + break; + case arrayValue: + case objectValue: + value_.map_ = new ObjectValues(); + break; + case booleanValue: + value_.bool_ = false; + break; + default: + JSON_ASSERT_UNREACHABLE; + } +} + +Value::Value(Int value) { + initBasic(intValue); + value_.int_ = value; +} + +Value::Value(UInt value) { + initBasic(uintValue); + value_.uint_ = value; +} +#if defined(JSON_HAS_INT64) +Value::Value(Int64 value) { + initBasic(intValue); + value_.int_ = value; +} +Value::Value(UInt64 value) { + initBasic(uintValue); + value_.uint_ = value; +} +#endif // defined(JSON_HAS_INT64) + +Value::Value(double value) { + initBasic(realValue); + value_.real_ = value; +} + +Value::Value(const char* value) { + initBasic(stringValue, true); + value_.string_ = duplicateAndPrefixStringValue(value, static_cast<unsigned>(strlen(value))); +} + +Value::Value(const char* beginValue, const char* endValue) { + initBasic(stringValue, true); + value_.string_ = + duplicateAndPrefixStringValue(beginValue, static_cast<unsigned>(endValue - beginValue)); +} + +Value::Value(const std::string& value) { + initBasic(stringValue, true); + value_.string_ = + duplicateAndPrefixStringValue(value.data(), static_cast<unsigned>(value.length())); +} + +Value::Value(const StaticString& value) { + initBasic(stringValue); + value_.string_ = const_cast<char*>(value.c_str()); +} + +#ifdef JSON_USE_CPPTL +Value::Value(const CppTL::ConstString& value) { + initBasic(stringValue, true); + value_.string_ = duplicateAndPrefixStringValue(value, static_cast<unsigned>(value.length())); +} +#endif + +Value::Value(bool value) { + initBasic(booleanValue); + value_.bool_ = value; +} + +Value::Value(Value const& other) + : type_(other.type_), allocated_(false) + , + comments_(0), start_(other.start_), limit_(other.limit_) +{ + switch (type_) { + case nullValue: + case intValue: + case uintValue: + case realValue: + case booleanValue: + value_ = other.value_; + break; + case stringValue: + if (other.value_.string_ && other.allocated_) { + unsigned len; + char const* str; + decodePrefixedString(other.allocated_, other.value_.string_, + &len, &str); + value_.string_ = duplicateAndPrefixStringValue(str, len); + allocated_ = true; + } else { + value_.string_ = other.value_.string_; + allocated_ = false; + } + break; + case arrayValue: + case objectValue: + value_.map_ = new ObjectValues(*other.value_.map_); + break; + default: + JSON_ASSERT_UNREACHABLE; + } + if (other.comments_) { + comments_ = new CommentInfo[numberOfCommentPlacement]; + for (int comment = 0; comment < numberOfCommentPlacement; ++comment) { + const CommentInfo& otherComment = other.comments_[comment]; + if (otherComment.comment_) + comments_[comment].setComment( + otherComment.comment_, strlen(otherComment.comment_)); + } + } +} + +Value::~Value() { + switch (type_) { + case nullValue: + case intValue: + case uintValue: + case realValue: + case booleanValue: + break; + case stringValue: + if (allocated_) + releaseStringValue(value_.string_); + break; + case arrayValue: + case objectValue: + delete value_.map_; + break; + default: + JSON_ASSERT_UNREACHABLE; + } + + if (comments_) + delete[] comments_; +} + +Value& Value::operator=(Value other) { + swap(other); + return *this; +} + +void Value::swapPayload(Value& other) { + ValueType temp = type_; + type_ = other.type_; + other.type_ = temp; + std::swap(value_, other.value_); + int temp2 = allocated_; + allocated_ = other.allocated_; + other.allocated_ = temp2; +} + +void Value::swap(Value& other) { + swapPayload(other); + std::swap(comments_, other.comments_); + std::swap(start_, other.start_); + std::swap(limit_, other.limit_); +} + +ValueType Value::type() const { return type_; } + +int Value::compare(const Value& other) const { + if (*this < other) + return -1; + if (*this > other) + return 1; + return 0; +} + +bool Value::operator<(const Value& other) const { + int typeDelta = type_ - other.type_; + if (typeDelta) + return typeDelta < 0 ? true : false; + switch (type_) { + case nullValue: + return false; + case intValue: + return value_.int_ < other.value_.int_; + case uintValue: + return value_.uint_ < other.value_.uint_; + case realValue: + return value_.real_ < other.value_.real_; + case booleanValue: + return value_.bool_ < other.value_.bool_; + case stringValue: + { + if ((value_.string_ == 0) || (other.value_.string_ == 0)) { + if (other.value_.string_) return true; + else return false; + } + unsigned this_len; + unsigned other_len; + char const* this_str; + char const* other_str; + decodePrefixedString(this->allocated_, this->value_.string_, &this_len, &this_str); + decodePrefixedString(other.allocated_, other.value_.string_, &other_len, &other_str); + unsigned min_len = std::min(this_len, other_len); + int comp = memcmp(this_str, other_str, min_len); + if (comp < 0) return true; + if (comp > 0) return false; + return (this_len < other_len); + } + case arrayValue: + case objectValue: { + int delta = int(value_.map_->size() - other.value_.map_->size()); + if (delta) + return delta < 0; + return (*value_.map_) < (*other.value_.map_); + } + default: + JSON_ASSERT_UNREACHABLE; + } + return false; // unreachable +} + +bool Value::operator<=(const Value& other) const { return !(other < *this); } + +bool Value::operator>=(const Value& other) const { return !(*this < other); } + +bool Value::operator>(const Value& other) const { return other < *this; } + +bool Value::operator==(const Value& other) const { + // if ( type_ != other.type_ ) + // GCC 2.95.3 says: + // attempt to take address of bit-field structure member `Json::Value::type_' + // Beats me, but a temp solves the problem. + int temp = other.type_; + if (type_ != temp) + return false; + switch (type_) { + case nullValue: + return true; + case intValue: + return value_.int_ == other.value_.int_; + case uintValue: + return value_.uint_ == other.value_.uint_; + case realValue: + return value_.real_ == other.value_.real_; + case booleanValue: + return value_.bool_ == other.value_.bool_; + case stringValue: + { + if ((value_.string_ == 0) || (other.value_.string_ == 0)) { + return (value_.string_ == other.value_.string_); + } + unsigned this_len; + unsigned other_len; + char const* this_str; + char const* other_str; + decodePrefixedString(this->allocated_, this->value_.string_, &this_len, &this_str); + decodePrefixedString(other.allocated_, other.value_.string_, &other_len, &other_str); + if (this_len != other_len) return false; + int comp = memcmp(this_str, other_str, this_len); + return comp == 0; + } + case arrayValue: + case objectValue: + return value_.map_->size() == other.value_.map_->size() && + (*value_.map_) == (*other.value_.map_); + default: + JSON_ASSERT_UNREACHABLE; + } + return false; // unreachable +} + +bool Value::operator!=(const Value& other) const { return !(*this == other); } + +const char* Value::asCString() const { + JSON_ASSERT_MESSAGE(type_ == stringValue, + "in Json::Value::asCString(): requires stringValue"); + if (value_.string_ == 0) return 0; + unsigned this_len; + char const* this_str; + decodePrefixedString(this->allocated_, this->value_.string_, &this_len, &this_str); + return this_str; +} + +bool Value::getString(char const** str, char const** end) const { + if (type_ != stringValue) return false; + if (value_.string_ == 0) return false; + unsigned length; + decodePrefixedString(this->allocated_, this->value_.string_, &length, str); + *end = *str + length; + return true; +} + +std::string Value::asString() const { + switch (type_) { + case nullValue: + return ""; + case stringValue: + { + if (value_.string_ == 0) return ""; + unsigned this_len; + char const* this_str; + decodePrefixedString(this->allocated_, this->value_.string_, &this_len, &this_str); + return std::string(this_str, this_len); + } + case booleanValue: + return value_.bool_ ? "true" : "false"; + case intValue: + return valueToString(value_.int_); + case uintValue: + return valueToString(value_.uint_); + case realValue: + return valueToString(value_.real_); + default: + JSON_FAIL_MESSAGE("Type is not convertible to string"); + } +} + +#ifdef JSON_USE_CPPTL +CppTL::ConstString Value::asConstString() const { + unsigned len; + char const* str; + decodePrefixedString(allocated_, value_.string_, + &len, &str); + return CppTL::ConstString(str, len); +} +#endif + +Value::Int Value::asInt() const { + switch (type_) { + case intValue: + JSON_ASSERT_MESSAGE(isInt(), "LargestInt out of Int range"); + return Int(value_.int_); + case uintValue: + JSON_ASSERT_MESSAGE(isInt(), "LargestUInt out of Int range"); + return Int(value_.uint_); + case realValue: + JSON_ASSERT_MESSAGE(InRange(value_.real_, minInt, maxInt), + "double out of Int range"); + return Int(value_.real_); + case nullValue: + return 0; + case booleanValue: + return value_.bool_ ? 1 : 0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to Int."); +} + +Value::UInt Value::asUInt() const { + switch (type_) { + case intValue: + JSON_ASSERT_MESSAGE(isUInt(), "LargestInt out of UInt range"); + return UInt(value_.int_); + case uintValue: + JSON_ASSERT_MESSAGE(isUInt(), "LargestUInt out of UInt range"); + return UInt(value_.uint_); + case realValue: + JSON_ASSERT_MESSAGE(InRange(value_.real_, 0, maxUInt), + "double out of UInt range"); + return UInt(value_.real_); + case nullValue: + return 0; + case booleanValue: + return value_.bool_ ? 1 : 0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to UInt."); +} + +#if defined(JSON_HAS_INT64) + +Value::Int64 Value::asInt64() const { + switch (type_) { + case intValue: + return Int64(value_.int_); + case uintValue: + JSON_ASSERT_MESSAGE(isInt64(), "LargestUInt out of Int64 range"); + return Int64(value_.uint_); + case realValue: + JSON_ASSERT_MESSAGE(InRange(value_.real_, minInt64, maxInt64), + "double out of Int64 range"); + return Int64(value_.real_); + case nullValue: + return 0; + case booleanValue: + return value_.bool_ ? 1 : 0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to Int64."); +} + +Value::UInt64 Value::asUInt64() const { + switch (type_) { + case intValue: + JSON_ASSERT_MESSAGE(isUInt64(), "LargestInt out of UInt64 range"); + return UInt64(value_.int_); + case uintValue: + return UInt64(value_.uint_); + case realValue: + JSON_ASSERT_MESSAGE(InRange(value_.real_, 0, maxUInt64), + "double out of UInt64 range"); + return UInt64(value_.real_); + case nullValue: + return 0; + case booleanValue: + return value_.bool_ ? 1 : 0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to UInt64."); +} +#endif // if defined(JSON_HAS_INT64) + +LargestInt Value::asLargestInt() const { +#if defined(JSON_NO_INT64) + return asInt(); +#else + return asInt64(); +#endif +} + +LargestUInt Value::asLargestUInt() const { +#if defined(JSON_NO_INT64) + return asUInt(); +#else + return asUInt64(); +#endif +} + +double Value::asDouble() const { + switch (type_) { + case intValue: + return static_cast<double>(value_.int_); + case uintValue: +#if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + return static_cast<double>(value_.uint_); +#else // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + return integerToDouble(value_.uint_); +#endif // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + case realValue: + return value_.real_; + case nullValue: + return 0.0; + case booleanValue: + return value_.bool_ ? 1.0 : 0.0; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to double."); +} + +float Value::asFloat() const { + switch (type_) { + case intValue: + return static_cast<float>(value_.int_); + case uintValue: +#if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + return static_cast<float>(value_.uint_); +#else // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + return integerToDouble(value_.uint_); +#endif // if !defined(JSON_USE_INT64_DOUBLE_CONVERSION) + case realValue: + return static_cast<float>(value_.real_); + case nullValue: + return 0.0; + case booleanValue: + return value_.bool_ ? 1.0f : 0.0f; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to float."); +} + +bool Value::asBool() const { + switch (type_) { + case booleanValue: + return value_.bool_; + case nullValue: + return false; + case intValue: + return value_.int_ ? true : false; + case uintValue: + return value_.uint_ ? true : false; + case realValue: + return value_.real_ ? true : false; + default: + break; + } + JSON_FAIL_MESSAGE("Value is not convertible to bool."); +} + +bool Value::isConvertibleTo(ValueType other) const { + switch (other) { + case nullValue: + return (isNumeric() && asDouble() == 0.0) || + (type_ == booleanValue && value_.bool_ == false) || + (type_ == stringValue && asString() == "") || + (type_ == arrayValue && value_.map_->size() == 0) || + (type_ == objectValue && value_.map_->size() == 0) || + type_ == nullValue; + case intValue: + return isInt() || + (type_ == realValue && InRange(value_.real_, minInt, maxInt)) || + type_ == booleanValue || type_ == nullValue; + case uintValue: + return isUInt() || + (type_ == realValue && InRange(value_.real_, 0, maxUInt)) || + type_ == booleanValue || type_ == nullValue; + case realValue: + return isNumeric() || type_ == booleanValue || type_ == nullValue; + case booleanValue: + return isNumeric() || type_ == booleanValue || type_ == nullValue; + case stringValue: + return isNumeric() || type_ == booleanValue || type_ == stringValue || + type_ == nullValue; + case arrayValue: + return type_ == arrayValue || type_ == nullValue; + case objectValue: + return type_ == objectValue || type_ == nullValue; + } + JSON_ASSERT_UNREACHABLE; + return false; +} + +/// Number of values in array or object +ArrayIndex Value::size() const { + switch (type_) { + case nullValue: + case intValue: + case uintValue: + case realValue: + case booleanValue: + case stringValue: + return 0; + case arrayValue: // size of the array is highest index + 1 + if (!value_.map_->empty()) { + ObjectValues::const_iterator itLast = value_.map_->end(); + --itLast; + return (*itLast).first.index() + 1; + } + return 0; + case objectValue: + return ArrayIndex(value_.map_->size()); + } + JSON_ASSERT_UNREACHABLE; + return 0; // unreachable; +} + +bool Value::empty() const { + if (isNull() || isArray() || isObject()) + return size() == 0u; + else + return false; +} + +bool Value::operator!() const { return isNull(); } + +void Value::clear() { + JSON_ASSERT_MESSAGE(type_ == nullValue || type_ == arrayValue || + type_ == objectValue, + "in Json::Value::clear(): requires complex value"); + start_ = 0; + limit_ = 0; + switch (type_) { + case arrayValue: + case objectValue: + value_.map_->clear(); + break; + default: + break; + } +} + +void Value::resize(ArrayIndex newSize) { + JSON_ASSERT_MESSAGE(type_ == nullValue || type_ == arrayValue, + "in Json::Value::resize(): requires arrayValue"); + if (type_ == nullValue) + *this = Value(arrayValue); + ArrayIndex oldSize = size(); + if (newSize == 0) + clear(); + else if (newSize > oldSize) + (*this)[newSize - 1]; + else { + for (ArrayIndex index = newSize; index < oldSize; ++index) { + value_.map_->erase(index); + } + assert(size() == newSize); + } +} + +Value& Value::operator[](ArrayIndex index) { + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == arrayValue, + "in Json::Value::operator[](ArrayIndex): requires arrayValue"); + if (type_ == nullValue) + *this = Value(arrayValue); + CZString key(index); + ObjectValues::iterator it = value_.map_->lower_bound(key); + if (it != value_.map_->end() && (*it).first == key) + return (*it).second; + + ObjectValues::value_type defaultValue(key, nullRef); + it = value_.map_->insert(it, defaultValue); + return (*it).second; +} + +Value& Value::operator[](int index) { + JSON_ASSERT_MESSAGE( + index >= 0, + "in Json::Value::operator[](int index): index cannot be negative"); + return (*this)[ArrayIndex(index)]; +} + +const Value& Value::operator[](ArrayIndex index) const { + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == arrayValue, + "in Json::Value::operator[](ArrayIndex)const: requires arrayValue"); + if (type_ == nullValue) + return nullRef; + CZString key(index); + ObjectValues::const_iterator it = value_.map_->find(key); + if (it == value_.map_->end()) + return nullRef; + return (*it).second; +} + +const Value& Value::operator[](int index) const { + JSON_ASSERT_MESSAGE( + index >= 0, + "in Json::Value::operator[](int index) const: index cannot be negative"); + return (*this)[ArrayIndex(index)]; +} + +void Value::initBasic(ValueType type, bool allocated) { + type_ = type; + allocated_ = allocated; + comments_ = 0; + start_ = 0; + limit_ = 0; +} + +// Access an object value by name, create a null member if it does not exist. +// @pre Type of '*this' is object or null. +// @param key is null-terminated. +Value& Value::resolveReference(const char* key) { + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == objectValue, + "in Json::Value::resolveReference(): requires objectValue"); + if (type_ == nullValue) + *this = Value(objectValue); + CZString actualKey( + key, static_cast<unsigned>(strlen(key)), CZString::noDuplication); // NOTE! + ObjectValues::iterator it = value_.map_->lower_bound(actualKey); + if (it != value_.map_->end() && (*it).first == actualKey) + return (*it).second; + + ObjectValues::value_type defaultValue(actualKey, nullRef); + it = value_.map_->insert(it, defaultValue); + Value& value = (*it).second; + return value; +} + +// @param key is not null-terminated. +Value& Value::resolveReference(char const* key, char const* end) +{ + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == objectValue, + "in Json::Value::resolveReference(key, end): requires objectValue"); + if (type_ == nullValue) + *this = Value(objectValue); + CZString actualKey( + key, static_cast<unsigned>(end-key), CZString::duplicateOnCopy); + ObjectValues::iterator it = value_.map_->lower_bound(actualKey); + if (it != value_.map_->end() && (*it).first == actualKey) + return (*it).second; + + ObjectValues::value_type defaultValue(actualKey, nullRef); + it = value_.map_->insert(it, defaultValue); + Value& value = (*it).second; + return value; +} + +Value Value::get(ArrayIndex index, const Value& defaultValue) const { + const Value* value = &((*this)[index]); + return value == &nullRef ? defaultValue : *value; +} + +bool Value::isValidIndex(ArrayIndex index) const { return index < size(); } + +Value const* Value::find(char const* key, char const* end) const +{ + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == objectValue, + "in Json::Value::find(key, end, found): requires objectValue or nullValue"); + if (type_ == nullValue) return NULL; + CZString actualKey(key, static_cast<unsigned>(end-key), CZString::noDuplication); + ObjectValues::const_iterator it = value_.map_->find(actualKey); + if (it == value_.map_->end()) return NULL; + return &(*it).second; +} +const Value& Value::operator[](const char* key) const +{ + Value const* found = find(key, key + strlen(key)); + if (!found) return nullRef; + return *found; +} +Value const& Value::operator[](std::string const& key) const +{ + Value const* found = find(key.data(), key.data() + key.length()); + if (!found) return nullRef; + return *found; +} + +Value& Value::operator[](const char* key) { + return resolveReference(key, key + strlen(key)); +} + +Value& Value::operator[](const std::string& key) { + return resolveReference(key.data(), key.data() + key.length()); +} + +Value& Value::operator[](const StaticString& key) { + return resolveReference(key.c_str()); +} + +#ifdef JSON_USE_CPPTL +Value& Value::operator[](const CppTL::ConstString& key) { + return resolveReference(key.c_str(), key.end_c_str()); +} +Value const& Value::operator[](CppTL::ConstString const& key) const +{ + Value const* found = find(key.c_str(), key.end_c_str()); + if (!found) return nullRef; + return *found; +} +#endif + +Value& Value::append(const Value& value) { return (*this)[size()] = value; } + +Value Value::get(char const* key, char const* end, Value const& defaultValue) const +{ + Value const* found = find(key, end); + return !found ? defaultValue : *found; +} +Value Value::get(char const* key, Value const& defaultValue) const +{ + return get(key, key + strlen(key), defaultValue); +} +Value Value::get(std::string const& key, Value const& defaultValue) const +{ + return get(key.data(), key.data() + key.length(), defaultValue); +} + + +bool Value::removeMember(const char* key, const char* end, Value* removed) +{ + if (type_ != objectValue) { + return false; + } + CZString actualKey(key, static_cast<unsigned>(end-key), CZString::noDuplication); + ObjectValues::iterator it = value_.map_->find(actualKey); + if (it == value_.map_->end()) + return false; + *removed = it->second; + value_.map_->erase(it); + return true; +} +bool Value::removeMember(const char* key, Value* removed) +{ + return removeMember(key, key + strlen(key), removed); +} +bool Value::removeMember(std::string const& key, Value* removed) +{ + return removeMember(key.data(), key.data() + key.length(), removed); +} +Value Value::removeMember(const char* key) +{ + JSON_ASSERT_MESSAGE(type_ == nullValue || type_ == objectValue, + "in Json::Value::removeMember(): requires objectValue"); + if (type_ == nullValue) + return nullRef; + + Value removed; // null + removeMember(key, key + strlen(key), &removed); + return removed; // still null if removeMember() did nothing +} +Value Value::removeMember(const std::string& key) +{ + return removeMember(key.c_str()); +} + +bool Value::removeIndex(ArrayIndex index, Value* removed) { + if (type_ != arrayValue) { + return false; + } + CZString key(index); + ObjectValues::iterator it = value_.map_->find(key); + if (it == value_.map_->end()) { + return false; + } + *removed = it->second; + ArrayIndex oldSize = size(); + // shift left all items left, into the place of the "removed" + for (ArrayIndex i = index; i < (oldSize - 1); ++i){ + CZString key(i); + (*value_.map_)[key] = (*this)[i + 1]; + } + // erase the last one ("leftover") + CZString keyLast(oldSize - 1); + ObjectValues::iterator itLast = value_.map_->find(keyLast); + value_.map_->erase(itLast); + return true; +} + +#ifdef JSON_USE_CPPTL +Value Value::get(const CppTL::ConstString& key, + const Value& defaultValue) const { + return get(key.c_str(), key.end_c_str(), defaultValue); +} +#endif + +bool Value::isMember(char const* key, char const* end) const +{ + Value const* value = find(key, end); + return NULL != value; +} +bool Value::isMember(char const* key) const +{ + return isMember(key, key + strlen(key)); +} +bool Value::isMember(std::string const& key) const +{ + return isMember(key.data(), key.data() + key.length()); +} + +#ifdef JSON_USE_CPPTL +bool Value::isMember(const CppTL::ConstString& key) const { + return isMember(key.c_str(), key.end_c_str()); +} +#endif + +Value::Members Value::getMemberNames() const { + JSON_ASSERT_MESSAGE( + type_ == nullValue || type_ == objectValue, + "in Json::Value::getMemberNames(), value must be objectValue"); + if (type_ == nullValue) + return Value::Members(); + Members members; + members.reserve(value_.map_->size()); + ObjectValues::const_iterator it = value_.map_->begin(); + ObjectValues::const_iterator itEnd = value_.map_->end(); + for (; it != itEnd; ++it) { + members.push_back(std::string((*it).first.data(), + (*it).first.length())); + } + return members; +} +// +//# ifdef JSON_USE_CPPTL +// EnumMemberNames +// Value::enumMemberNames() const +//{ +// if ( type_ == objectValue ) +// { +// return CppTL::Enum::any( CppTL::Enum::transform( +// CppTL::Enum::keys( *(value_.map_), CppTL::Type<const CZString &>() ), +// MemberNamesTransform() ) ); +// } +// return EnumMemberNames(); +//} +// +// +// EnumValues +// Value::enumValues() const +//{ +// if ( type_ == objectValue || type_ == arrayValue ) +// return CppTL::Enum::anyValues( *(value_.map_), +// CppTL::Type<const Value &>() ); +// return EnumValues(); +//} +// +//# endif + +static bool IsIntegral(double d) { + double integral_part; + return modf(d, &integral_part) == 0.0; +} + +bool Value::isNull() const { return type_ == nullValue; } + +bool Value::isBool() const { return type_ == booleanValue; } + +bool Value::isInt() const { + switch (type_) { + case intValue: + return value_.int_ >= minInt && value_.int_ <= maxInt; + case uintValue: + return value_.uint_ <= UInt(maxInt); + case realValue: + return value_.real_ >= minInt && value_.real_ <= maxInt && + IsIntegral(value_.real_); + default: + break; + } + return false; +} + +bool Value::isUInt() const { + switch (type_) { + case intValue: + return value_.int_ >= 0 && LargestUInt(value_.int_) <= LargestUInt(maxUInt); + case uintValue: + return value_.uint_ <= maxUInt; + case realValue: + return value_.real_ >= 0 && value_.real_ <= maxUInt && + IsIntegral(value_.real_); + default: + break; + } + return false; +} + +bool Value::isInt64() const { +#if defined(JSON_HAS_INT64) + switch (type_) { + case intValue: + return true; + case uintValue: + return value_.uint_ <= UInt64(maxInt64); + case realValue: + // Note that maxInt64 (= 2^63 - 1) is not exactly representable as a + // double, so double(maxInt64) will be rounded up to 2^63. Therefore we + // require the value to be strictly less than the limit. + return value_.real_ >= double(minInt64) && + value_.real_ < double(maxInt64) && IsIntegral(value_.real_); + default: + break; + } +#endif // JSON_HAS_INT64 + return false; +} + +bool Value::isUInt64() const { +#if defined(JSON_HAS_INT64) + switch (type_) { + case intValue: + return value_.int_ >= 0; + case uintValue: + return true; + case realValue: + // Note that maxUInt64 (= 2^64 - 1) is not exactly representable as a + // double, so double(maxUInt64) will be rounded up to 2^64. Therefore we + // require the value to be strictly less than the limit. + return value_.real_ >= 0 && value_.real_ < maxUInt64AsDouble && + IsIntegral(value_.real_); + default: + break; + } +#endif // JSON_HAS_INT64 + return false; +} + +bool Value::isIntegral() const { +#if defined(JSON_HAS_INT64) + return isInt64() || isUInt64(); +#else + return isInt() || isUInt(); +#endif +} + +bool Value::isDouble() const { return type_ == realValue || isIntegral(); } + +bool Value::isNumeric() const { return isIntegral() || isDouble(); } + +bool Value::isString() const { return type_ == stringValue; } + +bool Value::isArray() const { return type_ == arrayValue; } + +bool Value::isObject() const { return type_ == objectValue; } + +void Value::setComment(const char* comment, size_t len, CommentPlacement placement) { + if (!comments_) + comments_ = new CommentInfo[numberOfCommentPlacement]; + if ((len > 0) && (comment[len-1] == '\n')) { + // Always discard trailing newline, to aid indentation. + len -= 1; + } + comments_[placement].setComment(comment, len); +} + +void Value::setComment(const char* comment, CommentPlacement placement) { + setComment(comment, strlen(comment), placement); +} + +void Value::setComment(const std::string& comment, CommentPlacement placement) { + setComment(comment.c_str(), comment.length(), placement); +} + +bool Value::hasComment(CommentPlacement placement) const { + return comments_ != 0 && comments_[placement].comment_ != 0; +} + +std::string Value::getComment(CommentPlacement placement) const { + if (hasComment(placement)) + return comments_[placement].comment_; + return ""; +} + +void Value::setOffsetStart(size_t start) { start_ = start; } + +void Value::setOffsetLimit(size_t limit) { limit_ = limit; } + +size_t Value::getOffsetStart() const { return start_; } + +size_t Value::getOffsetLimit() const { return limit_; } + +std::string Value::toStyledString() const { + StyledWriter writer; + return writer.write(*this); +} + +Value::const_iterator Value::begin() const { + switch (type_) { + case arrayValue: + case objectValue: + if (value_.map_) + return const_iterator(value_.map_->begin()); + break; + default: + break; + } + return const_iterator(); +} + +Value::const_iterator Value::end() const { + switch (type_) { + case arrayValue: + case objectValue: + if (value_.map_) + return const_iterator(value_.map_->end()); + break; + default: + break; + } + return const_iterator(); +} + +Value::iterator Value::begin() { + switch (type_) { + case arrayValue: + case objectValue: + if (value_.map_) + return iterator(value_.map_->begin()); + break; + default: + break; + } + return iterator(); +} + +Value::iterator Value::end() { + switch (type_) { + case arrayValue: + case objectValue: + if (value_.map_) + return iterator(value_.map_->end()); + break; + default: + break; + } + return iterator(); +} + +// class PathArgument +// ////////////////////////////////////////////////////////////////// + +PathArgument::PathArgument() : key_(), index_(), kind_(kindNone) {} + +PathArgument::PathArgument(ArrayIndex index) + : key_(), index_(index), kind_(kindIndex) {} + +PathArgument::PathArgument(const char* key) + : key_(key), index_(), kind_(kindKey) {} + +PathArgument::PathArgument(const std::string& key) + : key_(key.c_str()), index_(), kind_(kindKey) {} + +// class Path +// ////////////////////////////////////////////////////////////////// + +Path::Path(const std::string& path, + const PathArgument& a1, + const PathArgument& a2, + const PathArgument& a3, + const PathArgument& a4, + const PathArgument& a5) { + InArgs in; + in.push_back(&a1); + in.push_back(&a2); + in.push_back(&a3); + in.push_back(&a4); + in.push_back(&a5); + makePath(path, in); +} + +void Path::makePath(const std::string& path, const InArgs& in) { + const char* current = path.c_str(); + const char* end = current + path.length(); + InArgs::const_iterator itInArg = in.begin(); + while (current != end) { + if (*current == '[') { + ++current; + if (*current == '%') + addPathInArg(path, in, itInArg, PathArgument::kindIndex); + else { + ArrayIndex index = 0; + for (; current != end && *current >= '0' && *current <= '9'; ++current) + index = index * 10 + ArrayIndex(*current - '0'); + args_.push_back(index); + } + if (current == end || *current++ != ']') + invalidPath(path, int(current - path.c_str())); + } else if (*current == '%') { + addPathInArg(path, in, itInArg, PathArgument::kindKey); + ++current; + } else if (*current == '.') { + ++current; + } else { + const char* beginName = current; + while (current != end && !strchr("[.", *current)) + ++current; + args_.push_back(std::string(beginName, current)); + } + } +} + +void Path::addPathInArg(const std::string& /*path*/, + const InArgs& in, + InArgs::const_iterator& itInArg, + PathArgument::Kind kind) { + if (itInArg == in.end()) { + // Error: missing argument %d + } else if ((*itInArg)->kind_ != kind) { + // Error: bad argument type + } else { + args_.push_back(**itInArg); + } +} + +void Path::invalidPath(const std::string& /*path*/, int /*location*/) { + // Error: invalid path. +} + +const Value& Path::resolve(const Value& root) const { + const Value* node = &root; + for (Args::const_iterator it = args_.begin(); it != args_.end(); ++it) { + const PathArgument& arg = *it; + if (arg.kind_ == PathArgument::kindIndex) { + if (!node->isArray() || !node->isValidIndex(arg.index_)) { + // Error: unable to resolve path (array value expected at position... + } + node = &((*node)[arg.index_]); + } else if (arg.kind_ == PathArgument::kindKey) { + if (!node->isObject()) { + // Error: unable to resolve path (object value expected at position...) + } + node = &((*node)[arg.key_]); + if (node == &Value::nullRef) { + // Error: unable to resolve path (object has no member named '' at + // position...) + } + } + } + return *node; +} + +Value Path::resolve(const Value& root, const Value& defaultValue) const { + const Value* node = &root; + for (Args::const_iterator it = args_.begin(); it != args_.end(); ++it) { + const PathArgument& arg = *it; + if (arg.kind_ == PathArgument::kindIndex) { + if (!node->isArray() || !node->isValidIndex(arg.index_)) + return defaultValue; + node = &((*node)[arg.index_]); + } else if (arg.kind_ == PathArgument::kindKey) { + if (!node->isObject()) + return defaultValue; + node = &((*node)[arg.key_]); + if (node == &Value::nullRef) + return defaultValue; + } + } + return *node; +} + +Value& Path::make(Value& root) const { + Value* node = &root; + for (Args::const_iterator it = args_.begin(); it != args_.end(); ++it) { + const PathArgument& arg = *it; + if (arg.kind_ == PathArgument::kindIndex) { + if (!node->isArray()) { + // Error: node is not an array at position ... + } + node = &((*node)[arg.index_]); + } else if (arg.kind_ == PathArgument::kindKey) { + if (!node->isObject()) { + // Error: node is not an object at position... + } + node = &((*node)[arg.key_]); + } + } + return *node; +} + +} // namespace Json + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_value.cpp +// ////////////////////////////////////////////////////////////////////// + + + + + + +// ////////////////////////////////////////////////////////////////////// +// Beginning of content of file: src/lib_json/json_writer.cpp +// ////////////////////////////////////////////////////////////////////// + +// Copyright 2011 Baptiste Lepilleur +// Distributed under MIT license, or public domain if desired and +// recognized in your jurisdiction. +// See file LICENSE for detail or copy at http://jsoncpp.sourceforge.net/LICENSE + +#if !defined(JSON_IS_AMALGAMATION) +#include <json/writer.h> +#include "json_tool.h" +#endif // if !defined(JSON_IS_AMALGAMATION) +#include <iomanip> +#include <memory> +#include <sstream> +#include <utility> +#include <set> +#include <cassert> +#include <cstring> +#include <cstdio> + +#if defined(_MSC_VER) && _MSC_VER >= 1200 && _MSC_VER < 1800 // Between VC++ 6.0 and VC++ 11.0 +#include <float.h> +#define isfinite _finite +#elif defined(__sun) && defined(__SVR4) //Solaris +#include <ieeefp.h> +#define isfinite finite +#else +#include <cmath> +#define isfinite std::isfinite +#endif + +#if defined(_MSC_VER) && _MSC_VER < 1500 // VC++ 8.0 and below +#define snprintf _snprintf +#elif defined(__ANDROID__) +#define snprintf snprintf +#elif __cplusplus >= 201103L +#define snprintf std::snprintf +#endif + +#if defined(__BORLANDC__) +#include <float.h> +#define isfinite _finite +#define snprintf _snprintf +#endif + +#if defined(_MSC_VER) && _MSC_VER >= 1400 // VC++ 8.0 +// Disable warning about strdup being deprecated. +#pragma warning(disable : 4996) +#endif + +namespace Json { + +#if __cplusplus >= 201103L +typedef std::unique_ptr<StreamWriter> StreamWriterPtr; +#else +typedef std::auto_ptr<StreamWriter> StreamWriterPtr; +#endif + +static bool containsControlCharacter(const char* str) { + while (*str) { + if (isControlCharacter(*(str++))) + return true; + } + return false; +} + +static bool containsControlCharacter0(const char* str, unsigned len) { + char const* end = str + len; + while (end != str) { + if (isControlCharacter(*str) || 0==*str) + return true; + ++str; + } + return false; +} + +std::string valueToString(LargestInt value) { + UIntToStringBuffer buffer; + char* current = buffer + sizeof(buffer); + bool isNegative = value < 0; + if (isNegative) + value = -value; + uintToString(LargestUInt(value), current); + if (isNegative) + *--current = '-'; + assert(current >= buffer); + return current; +} + +std::string valueToString(LargestUInt value) { + UIntToStringBuffer buffer; + char* current = buffer + sizeof(buffer); + uintToString(value, current); + assert(current >= buffer); + return current; +} + +#if defined(JSON_HAS_INT64) + +std::string valueToString(Int value) { + return valueToString(LargestInt(value)); +} + +std::string valueToString(UInt value) { + return valueToString(LargestUInt(value)); +} + +#endif // # if defined(JSON_HAS_INT64) + +std::string valueToString(double value) { + // Allocate a buffer that is more than large enough to store the 16 digits of + // precision requested below. + char buffer[32]; + int len = -1; + +// Print into the buffer. We need not request the alternative representation +// that always has a decimal point because JSON doesn't distingish the +// concepts of reals and integers. +#if defined(_MSC_VER) && defined(__STDC_SECURE_LIB__) // Use secure version with + // visual studio 2005 to + // avoid warning. +#if defined(WINCE) + len = _snprintf(buffer, sizeof(buffer), "%.17g", value); +#else + len = sprintf_s(buffer, sizeof(buffer), "%.17g", value); +#endif +#else + if (isfinite(value)) { + len = snprintf(buffer, sizeof(buffer), "%.17g", value); + } else { + // IEEE standard states that NaN values will not compare to themselves + if (value != value) { + len = snprintf(buffer, sizeof(buffer), "null"); + } else if (value < 0) { + len = snprintf(buffer, sizeof(buffer), "-1e+9999"); + } else { + len = snprintf(buffer, sizeof(buffer), "1e+9999"); + } + // For those, we do not need to call fixNumLoc, but it is fast. + } +#endif + assert(len >= 0); + fixNumericLocale(buffer, buffer + len); + return buffer; +} + +std::string valueToString(bool value) { return value ? "true" : "false"; } + +std::string valueToQuotedString(const char* value) { + if (value == NULL) + return ""; + // Not sure how to handle unicode... + if (strpbrk(value, "\"\\\b\f\n\r\t") == NULL && + !containsControlCharacter(value)) + return std::string("\"") + value + "\""; + // We have to walk value and escape any special characters. + // Appending to std::string is not efficient, but this should be rare. + // (Note: forward slashes are *not* rare, but I am not escaping them.) + std::string::size_type maxsize = + strlen(value) * 2 + 3; // allescaped+quotes+NULL + std::string result; + result.reserve(maxsize); // to avoid lots of mallocs + result += "\""; + for (const char* c = value; *c != 0; ++c) { + switch (*c) { + case '\"': + result += "\\\""; + break; + case '\\': + result += "\\\\"; + break; + case '\b': + result += "\\b"; + break; + case '\f': + result += "\\f"; + break; + case '\n': + result += "\\n"; + break; + case '\r': + result += "\\r"; + break; + case '\t': + result += "\\t"; + break; + // case '/': + // Even though \/ is considered a legal escape in JSON, a bare + // slash is also legal, so I see no reason to escape it. + // (I hope I am not misunderstanding something. + // blep notes: actually escaping \/ may be useful in javascript to avoid </ + // sequence. + // Should add a flag to allow this compatibility mode and prevent this + // sequence from occurring. + default: + if (isControlCharacter(*c)) { + std::ostringstream oss; + oss << "\\u" << std::hex << std::uppercase << std::setfill('0') + << std::setw(4) << static_cast<int>(*c); + result += oss.str(); + } else { + result += *c; + } + break; + } + } + result += "\""; + return result; +} + +// https://github.com/upcaste/upcaste/blob/master/src/upcore/src/cstring/strnpbrk.cpp +static char const* strnpbrk(char const* s, char const* accept, size_t n) { + assert((s || !n) && accept); + + char const* const end = s + n; + for (char const* cur = s; cur < end; ++cur) { + int const c = *cur; + for (char const* a = accept; *a; ++a) { + if (*a == c) { + return cur; + } + } + } + return NULL; +} +static std::string valueToQuotedStringN(const char* value, unsigned length) { + if (value == NULL) + return ""; + // Not sure how to handle unicode... + if (strnpbrk(value, "\"\\\b\f\n\r\t", length) == NULL && + !containsControlCharacter0(value, length)) + return std::string("\"") + value + "\""; + // We have to walk value and escape any special characters. + // Appending to std::string is not efficient, but this should be rare. + // (Note: forward slashes are *not* rare, but I am not escaping them.) + std::string::size_type maxsize = + length * 2 + 3; // allescaped+quotes+NULL + std::string result; + result.reserve(maxsize); // to avoid lots of mallocs + result += "\""; + char const* end = value + length; + for (const char* c = value; c != end; ++c) { + switch (*c) { + case '\"': + result += "\\\""; + break; + case '\\': + result += "\\\\"; + break; + case '\b': + result += "\\b"; + break; + case '\f': + result += "\\f"; + break; + case '\n': + result += "\\n"; + break; + case '\r': + result += "\\r"; + break; + case '\t': + result += "\\t"; + break; + // case '/': + // Even though \/ is considered a legal escape in JSON, a bare + // slash is also legal, so I see no reason to escape it. + // (I hope I am not misunderstanding something.) + // blep notes: actually escaping \/ may be useful in javascript to avoid </ + // sequence. + // Should add a flag to allow this compatibility mode and prevent this + // sequence from occurring. + default: + if ((isControlCharacter(*c)) || (*c == 0)) { + std::ostringstream oss; + oss << "\\u" << std::hex << std::uppercase << std::setfill('0') + << std::setw(4) << static_cast<int>(*c); + result += oss.str(); + } else { + result += *c; + } + break; + } + } + result += "\""; + return result; +} + +// Class Writer +// ////////////////////////////////////////////////////////////////// +Writer::~Writer() {} + +// Class FastWriter +// ////////////////////////////////////////////////////////////////// + +FastWriter::FastWriter() + : yamlCompatiblityEnabled_(false), dropNullPlaceholders_(false), + omitEndingLineFeed_(false) {} + +void FastWriter::enableYAMLCompatibility() { yamlCompatiblityEnabled_ = true; } + +void FastWriter::dropNullPlaceholders() { dropNullPlaceholders_ = true; } + +void FastWriter::omitEndingLineFeed() { omitEndingLineFeed_ = true; } + +std::string FastWriter::write(const Value& root) { + document_ = ""; + writeValue(root); + if (!omitEndingLineFeed_) + document_ += "\n"; + return document_; +} + +void FastWriter::writeValue(const Value& value) { + switch (value.type()) { + case nullValue: + if (!dropNullPlaceholders_) + document_ += "null"; + break; + case intValue: + document_ += valueToString(value.asLargestInt()); + break; + case uintValue: + document_ += valueToString(value.asLargestUInt()); + break; + case realValue: + document_ += valueToString(value.asDouble()); + break; + case stringValue: + { + // Is NULL possible for value.string_? + char const* str; + char const* end; + bool ok = value.getString(&str, &end); + if (ok) document_ += valueToQuotedStringN(str, static_cast<unsigned>(end-str)); + break; + } + case booleanValue: + document_ += valueToString(value.asBool()); + break; + case arrayValue: { + document_ += '['; + int size = value.size(); + for (int index = 0; index < size; ++index) { + if (index > 0) + document_ += ','; + writeValue(value[index]); + } + document_ += ']'; + } break; + case objectValue: { + Value::Members members(value.getMemberNames()); + document_ += '{'; + for (Value::Members::iterator it = members.begin(); it != members.end(); + ++it) { + const std::string& name = *it; + if (it != members.begin()) + document_ += ','; + document_ += valueToQuotedStringN(name.data(), static_cast<unsigned>(name.length())); + document_ += yamlCompatiblityEnabled_ ? ": " : ":"; + writeValue(value[name]); + } + document_ += '}'; + } break; + } +} + +// Class StyledWriter +// ////////////////////////////////////////////////////////////////// + +StyledWriter::StyledWriter() + : rightMargin_(74), indentSize_(3), addChildValues_() {} + +std::string StyledWriter::write(const Value& root) { + document_ = ""; + addChildValues_ = false; + indentString_ = ""; + writeCommentBeforeValue(root); + writeValue(root); + writeCommentAfterValueOnSameLine(root); + document_ += "\n"; + return document_; +} + +void StyledWriter::writeValue(const Value& value) { + switch (value.type()) { + case nullValue: + pushValue("null"); + break; + case intValue: + pushValue(valueToString(value.asLargestInt())); + break; + case uintValue: + pushValue(valueToString(value.asLargestUInt())); + break; + case realValue: + pushValue(valueToString(value.asDouble())); + break; + case stringValue: + { + // Is NULL possible for value.string_? + char const* str; + char const* end; + bool ok = value.getString(&str, &end); + if (ok) pushValue(valueToQuotedStringN(str, static_cast<unsigned>(end-str))); + else pushValue(""); + break; + } + case booleanValue: + pushValue(valueToString(value.asBool())); + break; + case arrayValue: + writeArrayValue(value); + break; + case objectValue: { + Value::Members members(value.getMemberNames()); + if (members.empty()) + pushValue("{}"); + else { + writeWithIndent("{"); + indent(); + Value::Members::iterator it = members.begin(); + for (;;) { + const std::string& name = *it; + const Value& childValue = value[name]; + writeCommentBeforeValue(childValue); + writeWithIndent(valueToQuotedString(name.c_str())); + document_ += " : "; + writeValue(childValue); + if (++it == members.end()) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + document_ += ','; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("}"); + } + } break; + } +} + +void StyledWriter::writeArrayValue(const Value& value) { + unsigned size = value.size(); + if (size == 0) + pushValue("[]"); + else { + bool isArrayMultiLine = isMultineArray(value); + if (isArrayMultiLine) { + writeWithIndent("["); + indent(); + bool hasChildValue = !childValues_.empty(); + unsigned index = 0; + for (;;) { + const Value& childValue = value[index]; + writeCommentBeforeValue(childValue); + if (hasChildValue) + writeWithIndent(childValues_[index]); + else { + writeIndent(); + writeValue(childValue); + } + if (++index == size) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + document_ += ','; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("]"); + } else // output on a single line + { + assert(childValues_.size() == size); + document_ += "[ "; + for (unsigned index = 0; index < size; ++index) { + if (index > 0) + document_ += ", "; + document_ += childValues_[index]; + } + document_ += " ]"; + } + } +} + +bool StyledWriter::isMultineArray(const Value& value) { + int size = value.size(); + bool isMultiLine = size * 3 >= rightMargin_; + childValues_.clear(); + for (int index = 0; index < size && !isMultiLine; ++index) { + const Value& childValue = value[index]; + isMultiLine = + isMultiLine || ((childValue.isArray() || childValue.isObject()) && + childValue.size() > 0); + } + if (!isMultiLine) // check if line length > max line length + { + childValues_.reserve(size); + addChildValues_ = true; + int lineLength = 4 + (size - 1) * 2; // '[ ' + ', '*n + ' ]' + for (int index = 0; index < size; ++index) { + if (hasCommentForValue(value[index])) { + isMultiLine = true; + } + writeValue(value[index]); + lineLength += int(childValues_[index].length()); + } + addChildValues_ = false; + isMultiLine = isMultiLine || lineLength >= rightMargin_; + } + return isMultiLine; +} + +void StyledWriter::pushValue(const std::string& value) { + if (addChildValues_) + childValues_.push_back(value); + else + document_ += value; +} + +void StyledWriter::writeIndent() { + if (!document_.empty()) { + char last = document_[document_.length() - 1]; + if (last == ' ') // already indented + return; + if (last != '\n') // Comments may add new-line + document_ += '\n'; + } + document_ += indentString_; +} + +void StyledWriter::writeWithIndent(const std::string& value) { + writeIndent(); + document_ += value; +} + +void StyledWriter::indent() { indentString_ += std::string(indentSize_, ' '); } + +void StyledWriter::unindent() { + assert(int(indentString_.size()) >= indentSize_); + indentString_.resize(indentString_.size() - indentSize_); +} + +void StyledWriter::writeCommentBeforeValue(const Value& root) { + if (!root.hasComment(commentBefore)) + return; + + document_ += "\n"; + writeIndent(); + const std::string& comment = root.getComment(commentBefore); + std::string::const_iterator iter = comment.begin(); + while (iter != comment.end()) { + document_ += *iter; + if (*iter == '\n' && + (iter != comment.end() && *(iter + 1) == '/')) + writeIndent(); + ++iter; + } + + // Comments are stripped of trailing newlines, so add one here + document_ += "\n"; +} + +void StyledWriter::writeCommentAfterValueOnSameLine(const Value& root) { + if (root.hasComment(commentAfterOnSameLine)) + document_ += " " + root.getComment(commentAfterOnSameLine); + + if (root.hasComment(commentAfter)) { + document_ += "\n"; + document_ += root.getComment(commentAfter); + document_ += "\n"; + } +} + +bool StyledWriter::hasCommentForValue(const Value& value) { + return value.hasComment(commentBefore) || + value.hasComment(commentAfterOnSameLine) || + value.hasComment(commentAfter); +} + +// Class StyledStreamWriter +// ////////////////////////////////////////////////////////////////// + +StyledStreamWriter::StyledStreamWriter(std::string indentation) + : document_(NULL), rightMargin_(74), indentation_(indentation), + addChildValues_() {} + +void StyledStreamWriter::write(std::ostream& out, const Value& root) { + document_ = &out; + addChildValues_ = false; + indentString_ = ""; + indented_ = true; + writeCommentBeforeValue(root); + if (!indented_) writeIndent(); + indented_ = true; + writeValue(root); + writeCommentAfterValueOnSameLine(root); + *document_ << "\n"; + document_ = NULL; // Forget the stream, for safety. +} + +void StyledStreamWriter::writeValue(const Value& value) { + switch (value.type()) { + case nullValue: + pushValue("null"); + break; + case intValue: + pushValue(valueToString(value.asLargestInt())); + break; + case uintValue: + pushValue(valueToString(value.asLargestUInt())); + break; + case realValue: + pushValue(valueToString(value.asDouble())); + break; + case stringValue: + { + // Is NULL possible for value.string_? + char const* str; + char const* end; + bool ok = value.getString(&str, &end); + if (ok) pushValue(valueToQuotedStringN(str, static_cast<unsigned>(end-str))); + else pushValue(""); + break; + } + case booleanValue: + pushValue(valueToString(value.asBool())); + break; + case arrayValue: + writeArrayValue(value); + break; + case objectValue: { + Value::Members members(value.getMemberNames()); + if (members.empty()) + pushValue("{}"); + else { + writeWithIndent("{"); + indent(); + Value::Members::iterator it = members.begin(); + for (;;) { + const std::string& name = *it; + const Value& childValue = value[name]; + writeCommentBeforeValue(childValue); + writeWithIndent(valueToQuotedString(name.c_str())); + *document_ << " : "; + writeValue(childValue); + if (++it == members.end()) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + *document_ << ","; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("}"); + } + } break; + } +} + +void StyledStreamWriter::writeArrayValue(const Value& value) { + unsigned size = value.size(); + if (size == 0) + pushValue("[]"); + else { + bool isArrayMultiLine = isMultineArray(value); + if (isArrayMultiLine) { + writeWithIndent("["); + indent(); + bool hasChildValue = !childValues_.empty(); + unsigned index = 0; + for (;;) { + const Value& childValue = value[index]; + writeCommentBeforeValue(childValue); + if (hasChildValue) + writeWithIndent(childValues_[index]); + else { + if (!indented_) writeIndent(); + indented_ = true; + writeValue(childValue); + indented_ = false; + } + if (++index == size) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + *document_ << ","; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("]"); + } else // output on a single line + { + assert(childValues_.size() == size); + *document_ << "[ "; + for (unsigned index = 0; index < size; ++index) { + if (index > 0) + *document_ << ", "; + *document_ << childValues_[index]; + } + *document_ << " ]"; + } + } +} + +bool StyledStreamWriter::isMultineArray(const Value& value) { + int size = value.size(); + bool isMultiLine = size * 3 >= rightMargin_; + childValues_.clear(); + for (int index = 0; index < size && !isMultiLine; ++index) { + const Value& childValue = value[index]; + isMultiLine = + isMultiLine || ((childValue.isArray() || childValue.isObject()) && + childValue.size() > 0); + } + if (!isMultiLine) // check if line length > max line length + { + childValues_.reserve(size); + addChildValues_ = true; + int lineLength = 4 + (size - 1) * 2; // '[ ' + ', '*n + ' ]' + for (int index = 0; index < size; ++index) { + if (hasCommentForValue(value[index])) { + isMultiLine = true; + } + writeValue(value[index]); + lineLength += int(childValues_[index].length()); + } + addChildValues_ = false; + isMultiLine = isMultiLine || lineLength >= rightMargin_; + } + return isMultiLine; +} + +void StyledStreamWriter::pushValue(const std::string& value) { + if (addChildValues_) + childValues_.push_back(value); + else + *document_ << value; +} + +void StyledStreamWriter::writeIndent() { + // blep intended this to look at the so-far-written string + // to determine whether we are already indented, but + // with a stream we cannot do that. So we rely on some saved state. + // The caller checks indented_. + *document_ << '\n' << indentString_; +} + +void StyledStreamWriter::writeWithIndent(const std::string& value) { + if (!indented_) writeIndent(); + *document_ << value; + indented_ = false; +} + +void StyledStreamWriter::indent() { indentString_ += indentation_; } + +void StyledStreamWriter::unindent() { + assert(indentString_.size() >= indentation_.size()); + indentString_.resize(indentString_.size() - indentation_.size()); +} + +void StyledStreamWriter::writeCommentBeforeValue(const Value& root) { + if (!root.hasComment(commentBefore)) + return; + + if (!indented_) writeIndent(); + const std::string& comment = root.getComment(commentBefore); + std::string::const_iterator iter = comment.begin(); + while (iter != comment.end()) { + *document_ << *iter; + if (*iter == '\n' && + (iter != comment.end() && *(iter + 1) == '/')) + // writeIndent(); // would include newline + *document_ << indentString_; + ++iter; + } + indented_ = false; +} + +void StyledStreamWriter::writeCommentAfterValueOnSameLine(const Value& root) { + if (root.hasComment(commentAfterOnSameLine)) + *document_ << ' ' << root.getComment(commentAfterOnSameLine); + + if (root.hasComment(commentAfter)) { + writeIndent(); + *document_ << root.getComment(commentAfter); + } + indented_ = false; +} + +bool StyledStreamWriter::hasCommentForValue(const Value& value) { + return value.hasComment(commentBefore) || + value.hasComment(commentAfterOnSameLine) || + value.hasComment(commentAfter); +} + +////////////////////////// +// BuiltStyledStreamWriter + +/// Scoped enums are not available until C++11. +struct CommentStyle { + /// Decide whether to write comments. + enum Enum { + None, ///< Drop all comments. + Most, ///< Recover odd behavior of previous versions (not implemented yet). + All ///< Keep all comments. + }; +}; + +struct BuiltStyledStreamWriter : public StreamWriter +{ + BuiltStyledStreamWriter( + std::string const& indentation, + CommentStyle::Enum cs, + std::string const& colonSymbol, + std::string const& nullSymbol, + std::string const& endingLineFeedSymbol); + virtual int write(Value const& root, std::ostream* sout); +private: + void writeValue(Value const& value); + void writeArrayValue(Value const& value); + bool isMultineArray(Value const& value); + void pushValue(std::string const& value); + void writeIndent(); + void writeWithIndent(std::string const& value); + void indent(); + void unindent(); + void writeCommentBeforeValue(Value const& root); + void writeCommentAfterValueOnSameLine(Value const& root); + static bool hasCommentForValue(const Value& value); + + typedef std::vector<std::string> ChildValues; + + ChildValues childValues_; + std::string indentString_; + int rightMargin_; + std::string indentation_; + CommentStyle::Enum cs_; + std::string colonSymbol_; + std::string nullSymbol_; + std::string endingLineFeedSymbol_; + bool addChildValues_ : 1; + bool indented_ : 1; +}; +BuiltStyledStreamWriter::BuiltStyledStreamWriter( + std::string const& indentation, + CommentStyle::Enum cs, + std::string const& colonSymbol, + std::string const& nullSymbol, + std::string const& endingLineFeedSymbol) + : rightMargin_(74) + , indentation_(indentation) + , cs_(cs) + , colonSymbol_(colonSymbol) + , nullSymbol_(nullSymbol) + , endingLineFeedSymbol_(endingLineFeedSymbol) + , addChildValues_(false) + , indented_(false) +{ +} +int BuiltStyledStreamWriter::write(Value const& root, std::ostream* sout) +{ + sout_ = sout; + addChildValues_ = false; + indented_ = true; + indentString_ = ""; + writeCommentBeforeValue(root); + if (!indented_) writeIndent(); + indented_ = true; + writeValue(root); + writeCommentAfterValueOnSameLine(root); + *sout_ << endingLineFeedSymbol_; + sout_ = NULL; + return 0; +} +void BuiltStyledStreamWriter::writeValue(Value const& value) { + switch (value.type()) { + case nullValue: + pushValue(nullSymbol_); + break; + case intValue: + pushValue(valueToString(value.asLargestInt())); + break; + case uintValue: + pushValue(valueToString(value.asLargestUInt())); + break; + case realValue: + pushValue(valueToString(value.asDouble())); + break; + case stringValue: + { + // Is NULL is possible for value.string_? + char const* str; + char const* end; + bool ok = value.getString(&str, &end); + if (ok) pushValue(valueToQuotedStringN(str, static_cast<unsigned>(end-str))); + else pushValue(""); + break; + } + case booleanValue: + pushValue(valueToString(value.asBool())); + break; + case arrayValue: + writeArrayValue(value); + break; + case objectValue: { + Value::Members members(value.getMemberNames()); + if (members.empty()) + pushValue("{}"); + else { + writeWithIndent("{"); + indent(); + Value::Members::iterator it = members.begin(); + for (;;) { + std::string const& name = *it; + Value const& childValue = value[name]; + writeCommentBeforeValue(childValue); + writeWithIndent(valueToQuotedStringN(name.data(), static_cast<unsigned>(name.length()))); + *sout_ << colonSymbol_; + writeValue(childValue); + if (++it == members.end()) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + *sout_ << ","; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("}"); + } + } break; + } +} + +void BuiltStyledStreamWriter::writeArrayValue(Value const& value) { + unsigned size = value.size(); + if (size == 0) + pushValue("[]"); + else { + bool isMultiLine = (cs_ == CommentStyle::All) || isMultineArray(value); + if (isMultiLine) { + writeWithIndent("["); + indent(); + bool hasChildValue = !childValues_.empty(); + unsigned index = 0; + for (;;) { + Value const& childValue = value[index]; + writeCommentBeforeValue(childValue); + if (hasChildValue) + writeWithIndent(childValues_[index]); + else { + if (!indented_) writeIndent(); + indented_ = true; + writeValue(childValue); + indented_ = false; + } + if (++index == size) { + writeCommentAfterValueOnSameLine(childValue); + break; + } + *sout_ << ","; + writeCommentAfterValueOnSameLine(childValue); + } + unindent(); + writeWithIndent("]"); + } else // output on a single line + { + assert(childValues_.size() == size); + *sout_ << "["; + if (!indentation_.empty()) *sout_ << " "; + for (unsigned index = 0; index < size; ++index) { + if (index > 0) + *sout_ << ", "; + *sout_ << childValues_[index]; + } + if (!indentation_.empty()) *sout_ << " "; + *sout_ << "]"; + } + } +} + +bool BuiltStyledStreamWriter::isMultineArray(Value const& value) { + int size = value.size(); + bool isMultiLine = size * 3 >= rightMargin_; + childValues_.clear(); + for (int index = 0; index < size && !isMultiLine; ++index) { + Value const& childValue = value[index]; + isMultiLine = + isMultiLine || ((childValue.isArray() || childValue.isObject()) && + childValue.size() > 0); + } + if (!isMultiLine) // check if line length > max line length + { + childValues_.reserve(size); + addChildValues_ = true; + int lineLength = 4 + (size - 1) * 2; // '[ ' + ', '*n + ' ]' + for (int index = 0; index < size; ++index) { + if (hasCommentForValue(value[index])) { + isMultiLine = true; + } + writeValue(value[index]); + lineLength += int(childValues_[index].length()); + } + addChildValues_ = false; + isMultiLine = isMultiLine || lineLength >= rightMargin_; + } + return isMultiLine; +} + +void BuiltStyledStreamWriter::pushValue(std::string const& value) { + if (addChildValues_) + childValues_.push_back(value); + else + *sout_ << value; +} + +void BuiltStyledStreamWriter::writeIndent() { + // blep intended this to look at the so-far-written string + // to determine whether we are already indented, but + // with a stream we cannot do that. So we rely on some saved state. + // The caller checks indented_. + + if (!indentation_.empty()) { + // In this case, drop newlines too. + *sout_ << '\n' << indentString_; + } +} + +void BuiltStyledStreamWriter::writeWithIndent(std::string const& value) { + if (!indented_) writeIndent(); + *sout_ << value; + indented_ = false; +} + +void BuiltStyledStreamWriter::indent() { indentString_ += indentation_; } + +void BuiltStyledStreamWriter::unindent() { + assert(indentString_.size() >= indentation_.size()); + indentString_.resize(indentString_.size() - indentation_.size()); +} + +void BuiltStyledStreamWriter::writeCommentBeforeValue(Value const& root) { + if (cs_ == CommentStyle::None) return; + if (!root.hasComment(commentBefore)) + return; + + if (!indented_) writeIndent(); + const std::string& comment = root.getComment(commentBefore); + std::string::const_iterator iter = comment.begin(); + while (iter != comment.end()) { + *sout_ << *iter; + if (*iter == '\n' && + (iter != comment.end() && *(iter + 1) == '/')) + // writeIndent(); // would write extra newline + *sout_ << indentString_; + ++iter; + } + indented_ = false; +} + +void BuiltStyledStreamWriter::writeCommentAfterValueOnSameLine(Value const& root) { + if (cs_ == CommentStyle::None) return; + if (root.hasComment(commentAfterOnSameLine)) + *sout_ << " " + root.getComment(commentAfterOnSameLine); + + if (root.hasComment(commentAfter)) { + writeIndent(); + *sout_ << root.getComment(commentAfter); + } +} + +// static +bool BuiltStyledStreamWriter::hasCommentForValue(const Value& value) { + return value.hasComment(commentBefore) || + value.hasComment(commentAfterOnSameLine) || + value.hasComment(commentAfter); +} + +/////////////// +// StreamWriter + +StreamWriter::StreamWriter() + : sout_(NULL) +{ +} +StreamWriter::~StreamWriter() +{ +} +StreamWriter::Factory::~Factory() +{} +StreamWriterBuilder::StreamWriterBuilder() +{ + setDefaults(&settings_); +} +StreamWriterBuilder::~StreamWriterBuilder() +{} +StreamWriter* StreamWriterBuilder::newStreamWriter() const +{ + std::string indentation = settings_["indentation"].asString(); + std::string cs_str = settings_["commentStyle"].asString(); + bool eyc = settings_["enableYAMLCompatibility"].asBool(); + bool dnp = settings_["dropNullPlaceholders"].asBool(); + CommentStyle::Enum cs = CommentStyle::All; + if (cs_str == "All") { + cs = CommentStyle::All; + } else if (cs_str == "None") { + cs = CommentStyle::None; + } else { + throwRuntimeError("commentStyle must be 'All' or 'None'"); + } + std::string colonSymbol = " : "; + if (eyc) { + colonSymbol = ": "; + } else if (indentation.empty()) { + colonSymbol = ":"; + } + std::string nullSymbol = "null"; + if (dnp) { + nullSymbol = ""; + } + std::string endingLineFeedSymbol = ""; + return new BuiltStyledStreamWriter( + indentation, cs, + colonSymbol, nullSymbol, endingLineFeedSymbol); +} +static void getValidWriterKeys(std::set<std::string>* valid_keys) +{ + valid_keys->clear(); + valid_keys->insert("indentation"); + valid_keys->insert("commentStyle"); + valid_keys->insert("enableYAMLCompatibility"); + valid_keys->insert("dropNullPlaceholders"); +} +bool StreamWriterBuilder::validate(Json::Value* invalid) const +{ + Json::Value my_invalid; + if (!invalid) invalid = &my_invalid; // so we do not need to test for NULL + Json::Value& inv = *invalid; + std::set<std::string> valid_keys; + getValidWriterKeys(&valid_keys); + Value::Members keys = settings_.getMemberNames(); + size_t n = keys.size(); + for (size_t i = 0; i < n; ++i) { + std::string const& key = keys[i]; + if (valid_keys.find(key) == valid_keys.end()) { + inv[key] = settings_[key]; + } + } + return 0u == inv.size(); +} +Value& StreamWriterBuilder::operator[](std::string key) +{ + return settings_[key]; +} +// static +void StreamWriterBuilder::setDefaults(Json::Value* settings) +{ + //! [StreamWriterBuilderDefaults] + (*settings)["commentStyle"] = "All"; + (*settings)["indentation"] = "\t"; + (*settings)["enableYAMLCompatibility"] = false; + (*settings)["dropNullPlaceholders"] = false; + //! [StreamWriterBuilderDefaults] +} + +std::string writeString(StreamWriter::Factory const& builder, Value const& root) { + std::ostringstream sout; + StreamWriterPtr const writer(builder.newStreamWriter()); + writer->write(root, &sout); + return sout.str(); +} + +std::ostream& operator<<(std::ostream& sout, Value const& root) { + StreamWriterBuilder builder; + StreamWriterPtr const writer(builder.newStreamWriter()); + writer->write(root, &sout); + return sout; +} + +} // namespace Json + +// ////////////////////////////////////////////////////////////////////// +// End of content of file: src/lib_json/json_writer.cpp +// ////////////////////////////////////////////////////////////////////// + + + + + @@ -1,5 +1,5 @@ #! /bin/bash -/router/bin/python-2.7.1 waf-1.6.8 $@ +python2.7 waf-1.6.8 $@ sts=$? exit $sts diff --git a/linux/ws_main.py b/linux/ws_main.py index 93ed02b8..6a8967bf 100755 --- a/linux/ws_main.py +++ b/linux/ws_main.py @@ -7,15 +7,18 @@ VERSION='0.0.1' APPNAME='cxx_test' + import os; import commands; import shutil; import copy; +from distutils.version import StrictVersion top = '../' out = 'build' b_path ="./build/linux/" +REQUIRED_CC_VERSION = "4.7.0" class SrcGroup: ' group of source by directory ' @@ -68,8 +71,25 @@ class SrcGroups: def options(opt): opt.load('compiler_cxx') + +def verify_cc_version (env): + ver = '.'.join(env['CC_VERSION']) + + if StrictVersion(ver) < REQUIRED_CC_VERSION: + print "\nMachine GCC version too low '{0}' - required at least '{1}'".format(ver, REQUIRED_CC_VERSION) + print "\n*** please set a compiler using CXX / AR enviorment variables ***\n" + exit(-1) + + def configure(conf): + # start from clean + if 'RPATH' in os.environ: + conf.env.RPATH = os.environ['RPATH'].split(':') + else: + conf.env.RPATH = [] + conf.load('g++') + verify_cc_version(conf.env) main_src = SrcGroup(dir='src', @@ -118,6 +138,38 @@ net_src = SrcGroup(dir='src/common/Network/Packet', 'MacAddress.cpp', 'VLANHeader.cpp']); +# RPC code +rpc_server_src = SrcGroup(dir='src/rpc-server/src', + src_list=[ + 'trex_rpc_server.cpp', + 'trex_rpc_req_resp_server.cpp', + 'trex_rpc_jsonrpc_v2_parser.cpp', + 'trex_rpc_cmds_table.cpp', + + 'commands/trex_rpc_cmd_test.cpp', + 'commands/trex_rpc_cmd_general.cpp', + + ]) + +# RPC mock server (test) +rpc_server_mock_src = SrcGroup(dir='src/rpc-server/src', + src_list=[ + 'trex_rpc_server_mock.cpp', + '../../gtest/rpc_test.cpp', + ]) + +# JSON package +json_src = SrcGroup(dir='external_libs/json', + src_list=[ + 'jsoncpp.cpp' + ]) + +rpc_server_mock = SrcGroups([cmn_src, + rpc_server_src, + rpc_server_mock_src, + json_src + ]) + yaml_src = SrcGroup(dir='yaml-cpp/src/', src_list=[ 'aliasmanager.cpp', @@ -152,22 +204,27 @@ bp =SrcGroups([ main_src, cmn_src , net_src , - yaml_src + yaml_src, ]); cxxflags_base =['-DWIN_UCODE_SIM', - '-D_BYTE_ORDER', - '-D_LITTLE_ENDIAN', - '-DLINUX', - '-g', + '-D_BYTE_ORDER', + '-D_LITTLE_ENDIAN', + '-DLINUX', + '-g', + '-Wno-deprecated-declarations', + '-std=c++0x', ]; includes_path =''' ../src/pal/linux/ + ../src/zmq/include/ ../src/ + ../src/rpc-server/include + ../external_libs/json/ ../yaml-cpp/include/ '''; @@ -183,10 +240,13 @@ PLATFORM_32 = "32" class build_option: - def __init__(self,platform,debug_mode,is_pie): + def __init__(self, name, src, platform, debug_mode, is_pie, use = []): self.mode = debug_mode; ##debug,release self.platform = platform; #['32','64'] self.is_pie = is_pie + self.name = name + self.src = src + self.use = use def __str__(self): s=self.mode+","+self.platform; @@ -253,12 +313,18 @@ class build_option: return result; + def get_use_libs (self): + return self.use + def get_target (self): - return self.update_executable_name("bp-sim"); + return self.update_executable_name(self.name); def get_flags (self): return self.cxxcomp_flags(cxxflags_base); + def get_src (self): + return self.src.file_list(top) + def get_link_flags(self): # add here basic flags base_flags = ['-lpthread']; @@ -268,34 +334,39 @@ class build_option: #platform depended flags if self.is64Platform(): - base_flags += ['-m64']; + base_flags += ['-m64'] else: - base_flags += ['-lrt']; + base_flags += ['-m32'] + base_flags += ['-lrt'] + if self.isPIE(): base_flags += ['-pie', '-DPATCH_FOR_PIE'] return base_flags; - def get_exe (self,full_path = True): - return self.toExe(self.get_target(),full_path); - build_types = [ - build_option(debug_mode= DEBUG_, platform = PLATFORM_32, is_pie = False), - build_option(debug_mode= DEBUG_, platform = PLATFORM_64, is_pie = False), - build_option(debug_mode= RELEASE_,platform = PLATFORM_32, is_pie = False), - build_option(debug_mode= RELEASE_,platform = PLATFORM_64, is_pie = False), + build_option(name = "bp-sim", src = bp, debug_mode= DEBUG_, platform = PLATFORM_32, is_pie = False), + build_option(name = "bp-sim", src = bp, debug_mode= DEBUG_, platform = PLATFORM_64, is_pie = False), + build_option(name = "bp-sim", src = bp, debug_mode= RELEASE_,platform = PLATFORM_32, is_pie = False), + build_option(name = "bp-sim", src = bp, debug_mode= RELEASE_,platform = PLATFORM_64, is_pie = False) + + #build_option(name = "mock-rpc-server", use = ['zmq'], src = rpc_server_mock, debug_mode= DEBUG_,platform = PLATFORM_64, is_pie = False), ] def build_prog (bld, build_obj): + zmq_lib_path='src/zmq/' + bld.read_shlib( name='zmq' , paths=[top + zmq_lib_path] ) + bld.program(features='cxx cxxprogram', includes =includes_path, cxxflags =build_obj.get_flags(), - stlib = 'stdc++', linkflags = build_obj.get_link_flags(), - source = bp.file_list(top), + source = build_obj.get_src(), + use = build_obj.get_use_libs(), + rpath = bld.env.RPATH, target = build_obj.get_target()) @@ -305,15 +376,19 @@ def build_type(bld,build_obj): def post_build(bld): print "copy objects" + exec_p ="../scripts/" + for obj in build_types: install_single_system(bld, exec_p, obj); def build(bld): + bld.add_post_fun(post_build); for obj in build_types: build_type(bld,obj); + def build_info(bld): pass; diff --git a/linux_dpdk/ws_main.py b/linux_dpdk/ws_main.py index 73a1982e..997c80fb 100755 --- a/linux_dpdk/ws_main.py +++ b/linux_dpdk/ws_main.py @@ -130,6 +130,30 @@ net_src = SrcGroup(dir='src/common/Network/Packet', 'MacAddress.cpp', 'VLANHeader.cpp']); +# JSON package +json_src = SrcGroup(dir='external_libs/json', + src_list=[ + 'jsoncpp.cpp' + ]) + +# RPC code +rpc_server_src = SrcGroup(dir='src/rpc-server/src', + src_list=[ + 'trex_rpc_server.cpp', + 'trex_rpc_req_resp_server.cpp', + 'trex_rpc_jsonrpc_v2_parser.cpp', + 'trex_rpc_cmds_table.cpp', + + 'commands/trex_rpc_cmd_test.cpp', + 'commands/trex_rpc_cmd_general.cpp', + ]) + +# JSON package +json_src = SrcGroup(dir='external_libs/json', + src_list=[ + 'jsoncpp.cpp' + ]) + yaml_src = SrcGroup(dir='yaml-cpp/src/', src_list=[ 'aliasmanager.cpp', @@ -319,6 +343,8 @@ bp =SrcGroups([ main_src, cmn_src , net_src , + rpc_server_src, + json_src, yaml_src, version_src ]); @@ -333,48 +359,47 @@ l2fwd =SrcGroups([ l2fwd_main_src]); - -cxxflags_base =['-DWIN_UCODE_SIM', - '-D_BYTE_ORDER', - '-Wno-format', - '-D_LITTLE_ENDIAN', - '-DLINUX', - '-g', - '-DRTE_DPDK', - '-march=native', - '-DRTE_MACHINE_CPUFLAG_SSE', - '-DRTE_MACHINE_CPUFLAG_SSE2', - '-DRTE_MACHINE_CPUFLAG_SSE3', - '-DRTE_MACHINE_CPUFLAG_SSSE3', - '-DRTE_MACHINE_CPUFLAG_SSE4_1', - '-DRTE_MACHINE_CPUFLAG_SSE4_2', - '-DRTE_MACHINE_CPUFLAG_AES', - '-DRTE_MACHINE_CPUFLAG_PCLMULQDQ', - '-DRTE_MACHINE_CPUFLAG_AVX', - '-DRTE_COMPILE_TIME_CPUFLAGS=RTE_CPUFLAG_SSE3,RTE_CPUFLAG_SSE,RTE_CPUFLAG_SSE2,RTE_CPUFLAG_SSSE3,RTE_CPUFLAG_SSE4_1,RTE_CPUFLAG_SSE4_2,RTE_CPUFLAG_AES,RTE_CPUFLAG_PCLMULQDQ,RTE_CPUFLAG_AVX', - '-include','../src/pal/linux_dpdk/dpdk180/rte_config.h' - ]; - -cxxflags_base_old =['-DWIN_UCODE_SIM', - '-D_BYTE_ORDER', - '-D_LITTLE_ENDIAN', - '-DLINUX', - '-Wno-format', - '-DUCS_210', - '-g', - '-DRTE_DPDK', - '-march=corei7', - '-mtune=generic', - '-DRTE_MACHINE_CPUFLAG_SSE', - '-DRTE_COMPILE_TIME_CPUFLAGS=RTE_CPUFLAG_SSE', - '-include','../src/pal/linux_dpdk/dpdk180/rte_config.h' - ]; +# common flags for both new and old configurations +common_flags = ['-DWIN_UCODE_SIM', + '-D_BYTE_ORDER', + '-D_LITTLE_ENDIAN', + '-DLINUX', + '-g', + '-Wno-format', + '-Wno-deprecated-declarations', + '-DRTE_DPDK', + '-include','../src/pal/linux_dpdk/dpdk180/rte_config.h' + ] + +common_flags_new = common_flags + [ + '-march=native', + '-DRTE_MACHINE_CPUFLAG_SSE', + '-DRTE_MACHINE_CPUFLAG_SSE2', + '-DRTE_MACHINE_CPUFLAG_SSE3', + '-DRTE_MACHINE_CPUFLAG_SSSE3', + '-DRTE_MACHINE_CPUFLAG_SSE4_1', + '-DRTE_MACHINE_CPUFLAG_SSE4_2', + '-DRTE_MACHINE_CPUFLAG_AES', + '-DRTE_MACHINE_CPUFLAG_PCLMULQDQ', + '-DRTE_MACHINE_CPUFLAG_AVX', + '-DRTE_COMPILE_TIME_CPUFLAGS=RTE_CPUFLAG_SSE3,RTE_CPUFLAG_SSE,RTE_CPUFLAG_SSE2,RTE_CPUFLAG_SSSE3,RTE_CPUFLAG_SSE4_1,RTE_CPUFLAG_SSE4_2,RTE_CPUFLAG_AES,RTE_CPUFLAG_PCLMULQDQ,RTE_CPUFLAG_AVX', + ] + +common_flags_old = common_flags + [ + '-march=corei7', + '-DUCS_210', + '-mtune=generic', + '-DRTE_MACHINE_CPUFLAG_SSE', + '-DRTE_COMPILE_TIME_CPUFLAGS=RTE_CPUFLAG_SSE', + ]; includes_path =''' ../src/pal/linux_dpdk/ ../src/ + ../external_libs/json/ + ../src/rpc-server/include ../yaml-cpp/include/ ../src/zmq/include/ ../src/dpdk_lib18/librte_eal/linuxapp/eal/include/ @@ -503,20 +528,6 @@ class build_option: trg += delimiter + "o" return trg; - def cxxcomp_flags (self,flags): - result = copy.copy(flags); - if self.is64Platform () : - result+=['-m64']; - else: - result+=['-m32']; - - if self.isRelease () : - result+=['-O3']; - else: - result+=['-O0'];#'-DDEBUG','-D_DEBUG','-DSTILE_CPP_ASSERT','-DSTILE_SHIM_ASSERT' - - return result; - def get_target (self): return self.update_executable_name("_t-rex"); @@ -526,12 +537,37 @@ class build_option: def get_dpdk_target (self): return self.update_executable_name("dpdk"); - def get_flags (self): + def get_common_flags (self): if self.isPIE(): - return self.cxxcomp_flags(cxxflags_base_old); + flags = copy.copy(common_flags_old) else: - return self.cxxcomp_flags(cxxflags_base); + flags = copy.copy(common_flags_new); + if self.is64Platform () : + flags += ['-m64']; + else: + flags += ['-m32']; + + if self.isRelease () : + flags += ['-O3']; + else: + flags += ['-O0']; + + return (flags) + + def get_cxx_flags (self): + flags = self.get_common_flags() + + # support c++ 2011 + flags += ['-std=c++0x'] + + return (flags) + + def get_c_flags (self): + flags = self.get_common_flags() + + # for C no special flags yet + return (flags) def get_link_flags(self): base_flags = []; @@ -569,14 +605,14 @@ def build_prog (bld, build_obj): features='c ', includes = dpdk_includes_path, - cflags = (build_obj.get_flags()+DPDK_WARNING ), + cflags = (build_obj.get_c_flags()+DPDK_WARNING ), source = bp_dpdk.file_list(top), target=build_obj.get_dpdk_target() ); bld.program(features='cxx cxxprogram', includes =includes_path, - cxxflags =build_obj.get_flags(), + cxxflags =build_obj.get_cxx_flags(), linkflags = build_obj.get_link_flags() , lib=['pthread','dl'], use =[build_obj.get_dpdk_target(),'zmq'], @@ -643,8 +679,8 @@ def create_version_files (): s +=" extern \"C\" { \n" s +=" #endif \n"; s +='#define VERSION_USER "%s" \n' % os.environ.get('USER', 'unknown') - s +='extern char * get_build_date(void); \n' - s +='extern char * get_build_time(void); \n' + s +='extern const char * get_build_date(void); \n' + s +='extern const char * get_build_time(void); \n' s +='#define VERSION_UIID "%s" \n' % uuid.uuid1() s +='#define VERSION_BUILD_NUM "%s" \n' % get_build_num() s +="#ifdef __cplusplus \n" @@ -656,20 +692,20 @@ def create_version_files (): s ='#include "version.h" \n' s +='#define VERSION_UIID1 "%s" \n' % uuid.uuid1() - s +="char * get_build_date(void){ \n" + s +="const char * get_build_date(void){ \n" s +=" return (__DATE__); \n" - s +=" } \n" + s +="} \n" s +=" \n" - s +="char * get_build_time(void){ \n" + s +="const char * get_build_time(void){ \n" s +=" return (__TIME__ ); \n" - s +=" } \n" + s +="} \n" write_file (C_VER_FILE,s) def build_test(bld): create_version_files () -def _copy_single_system (bld, exec_p, build_obj,o): +def _copy_single_system (bld, exec_p, build_obj): o='build_dpdk/linux_dpdk/'; src_file = os.path.realpath(o+build_obj.get_target()) print src_file; @@ -679,11 +715,21 @@ def _copy_single_system (bld, exec_p, build_obj,o): os.system("cp %s %s " %(src_file,dest_file)); os.system("chmod +x %s " %(dest_file)); +def _copy_single_system1 (bld, exec_p, build_obj): + o='../scripts/'; + src_file = os.path.realpath(o+build_obj.get_target()[1:]) + print src_file; + if os.path.exists(src_file): + dest_file = exec_p +build_obj.get_target()[1:] + os.system("cp %s %s " %(src_file,dest_file)); + os.system("chmod +x %s " %(dest_file)); + + def copy_single_system (bld, exec_p, build_obj): - _copy_single_system (bld, exec_p, build_obj,'build_dpdk/linux_dpdk/') + _copy_single_system (bld, exec_p, build_obj) def copy_single_system1 (bld, exec_p, build_obj): - _copy_single_system (bld, exec_p, build_obj,'../scripts/') + _copy_single_system1 (bld, exec_p, build_obj) files_list=[ diff --git a/scripts/trex-console b/scripts/trex-console new file mode 100755 index 00000000..50e097e7 --- /dev/null +++ b/scripts/trex-console @@ -0,0 +1,2 @@ +#!/bin/bash +../src/console/trex_console.py $@ diff --git a/src/console/trex_console.py b/src/console/trex_console.py new file mode 100755 index 00000000..1cb8194d --- /dev/null +++ b/src/console/trex_console.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import cmd +import json +import ast +import argparse +import sys + +from trex_rpc_client import RpcClient +import trex_status + +class TrexConsole(cmd.Cmd): + """Trex Console""" + + def __init__(self, rpc_client): + cmd.Cmd.__init__(self) + + self.rpc_client = rpc_client + + self.do_connect("") + + self.intro = "\n-=TRex Console V1.0=-\n" + self.intro += "\nType 'help' or '?' for supported actions\n" + + self.verbose = False + + self.postcmd(False, "") + + + # a cool hack - i stole this function and added space + def completenames(self, text, *ignored): + dotext = 'do_'+text + return [a[3:]+' ' for a in self.get_names() if a.startswith(dotext)] + + # set verbose on / off + def do_verbose (self, line): + '''shows or set verbose mode\n''' + if line == "": + print "\nverbose is " + ("on\n" if self.verbose else "off\n") + + elif line == "on": + self.verbose = True + self.rpc_client.set_verbose(True) + print "\nverbose set to on\n" + + elif line == "off": + self.verbose = False + self.rpc_client.set_verbose(False) + print "\nverbose set to off\n" + + else: + print "\nplease specify 'on' or 'off'\n" + + # query the server for registered commands + def do_query_server(self, line): + '''query the RPC server for supported remote commands\n''' + + rc, msg = self.rpc_client.query_rpc_server() + if not rc: + print "\n*** " + msg + "\n" + return + + print "\nRPC server supports the following commands: \n\n" + for func in msg: + if func: + print func + print "\n" + + def do_ping (self, line): + '''Pings the RPC server\n''' + + print "\n-> Pinging RPC server" + + rc, msg = self.rpc_client.ping_rpc_server() + if rc: + print "[SUCCESS]\n" + else: + print "\n*** " + msg + "\n" + return + + def do_connect (self, line): + '''Connects to the server\n''' + + if line == "": + rc, msg = self.rpc_client.connect() + else: + sp = line.split() + if (len(sp) != 2): + print "\n[usage] connect [server] [port] or without parameters\n" + return + + rc, msg = self.rpc_client.connect(sp[0], sp[1]) + + if rc: + print "[SUCCESS]\n" + else: + print "\n*** " + msg + "\n" + return + + rc, msg = self.rpc_client.query_rpc_server() + + if rc: + self.supported_rpc = [str(x) for x in msg if x] + + def do_rpc (self, line): + '''Launches a RPC on the server\n''' + + if line == "": + print "\nUsage: [method name] [param dict as string]\n" + print "Example: rpc test_add {'x': 12, 'y': 17}\n" + return + + sp = line.split(' ', 1) + method = sp[0] + + params = None + bad_parse = False + if len(sp) > 1: + + try: + params = ast.literal_eval(sp[1]) + if not isinstance(params, dict): + bad_parse = True + + except ValueError as e1: + bad_parse = True + except SyntaxError as e2: + bad_parse = True + + if bad_parse: + print "\nValue should be a valid dict: '{0}'".format(sp[1]) + print "\nUsage: [method name] [param dict as string]\n" + print "Example: rpc test_add {'x': 12, 'y': 17}\n" + return + + rc, msg = self.rpc_client.invoke_rpc_method(method, params) + if rc: + print "\nServer Response:\n\n" + json.dumps(msg) + "\n" + else: + print "\n*** " + msg + "\n" + #print "Please try 'reconnect' to reconnect to server" + + + def complete_rpc (self, text, line, begidx, endidx): + return [x for x in self.supported_rpc if x.startswith(text)] + + def do_status (self, line): + '''Shows a graphical console\n''' + + self.do_verbose('off') + trex_status.show_trex_status(self.rpc_client) + + def do_quit(self, line): + '''exit the client\n''' + return True + + def do_disconnect (self, line): + '''Disconnect from the server\n''' + if not self.rpc_client.is_connected(): + print "Not connected to server\n" + return + + rc, msg = self.rpc_client.disconnect() + if rc: + print "[SUCCESS]\n" + else: + print msg + "\n" + + def postcmd(self, stop, line): + if self.rpc_client.is_connected(): + self.prompt = "TRex > " + else: + self.supported_rpc = None + self.prompt = "TRex (offline) > " + + return stop + + def default(self, line): + print "'{0}' is an unrecognized command. type 'help' or '?' for a list\n".format(line) + + def do_help (self, line): + '''Shows This Help Screen\n''' + if line: + try: + func = getattr(self, 'help_' + line) + except AttributeError: + try: + doc = getattr(self, 'do_' + line).__doc__ + if doc: + self.stdout.write("%s\n"%str(doc)) + return + except AttributeError: + pass + self.stdout.write("%s\n"%str(self.nohelp % (line,))) + return + func() + return + + print "\nSupported Console Commands:" + print "----------------------------\n" + + cmds = [x[3:] for x in self.get_names() if x.startswith("do_")] + for cmd in cmds: + if cmd == "EOF": + continue + + try: + doc = getattr(self, 'do_' + cmd).__doc__ + if doc: + help = str(doc) + else: + help = "*** Undocumented Function ***\n" + except AttributeError: + help = "*** Undocumented Function ***\n" + + print "{:<30} {:<30}".format(cmd + " - ", help) + + + # aliasing + do_exit = do_EOF = do_q = do_quit + +def setParserOptions (): + parser = argparse.ArgumentParser(prog="trex_console.py") + + parser.add_argument("-s", "--server", help = "T-Rex Server [default is localhost]", + default = "localhost", + type = str) + + parser.add_argument("-p", "--port", help = "T-Rex Server Port [default is 5050]\n", + default = 5050, + type = int) + + return parser + +def main (): + parser = setParserOptions() + options = parser.parse_args(sys.argv[1:]) + + # RPC client + rpc_client = RpcClient(options.server, options.port) + + # console + try: + console = TrexConsole(rpc_client) + console.cmdloop() + except KeyboardInterrupt as e: + print "\n\n*** Caught Ctrl + C... Exiting...\n\n" + return + +if __name__ == '__main__': + main() + diff --git a/src/console/trex_rpc_client.py b/src/console/trex_rpc_client.py new file mode 100644 index 00000000..77d5fe1c --- /dev/null +++ b/src/console/trex_rpc_client.py @@ -0,0 +1,180 @@ +
+import zmq
+import json
+from time import sleep
+import random
+
+class RpcClient():
+
+ def __init__ (self, default_server, default_port):
+ self.verbose = False
+ self.connected = False
+
+ # default values
+ self.port = default_port
+ self.server = default_server
+
+ def get_connection_details (self):
+ rc = {}
+ rc['server'] = self.server
+ rc['port'] = self.port
+
+ return rc
+
+ def pretty_json (self, json_str):
+ return json.dumps(json.loads(json_str), indent = 4, separators=(',', ': '), sort_keys = True)
+
+ def verbose_msg (self, msg):
+ if not self.verbose:
+ return
+
+ print "[verbose] " + msg
+
+
+ def create_jsonrpc_v2 (self, method_name, params = {}, id = None):
+ msg = {}
+ msg["jsonrpc"] = "2.0"
+ msg["method"] = method_name
+
+ msg["params"] = params
+
+ msg["id"] = id
+
+ return json.dumps(msg)
+
+ def invoke_rpc_method (self, method_name, params = {}, block = False):
+ rc, msg = self._invoke_rpc_method(method_name, params, block)
+ if not rc:
+ self.disconnect()
+
+ return rc, msg
+
+ def _invoke_rpc_method (self, method_name, params = {}, block = False):
+ if not self.connected:
+ return False, "Not connected to server"
+
+ id = random.randint(1, 1000)
+ msg = self.create_jsonrpc_v2(method_name, params, id = id)
+
+ self.verbose_msg("Sending Request To Server:\n\n" + self.pretty_json(msg) + "\n")
+
+ if block:
+ self.socket.send(msg)
+ else:
+ try:
+ self.socket.send(msg, flags = zmq.NOBLOCK)
+ except zmq.error.ZMQError:
+ return False, "Failed To Get Send Message"
+
+ got_response = False
+
+ if block:
+ response = self.socket.recv()
+ got_response = True
+ else:
+ for i in xrange(0 ,10):
+ try:
+ response = self.socket.recv(flags = zmq.NOBLOCK)
+ got_response = True
+ break
+ except zmq.error.Again:
+ sleep(0.2)
+
+ if not got_response:
+ return False, "Failed To Get Server Response"
+
+ self.verbose_msg("Server Response:\n\n" + self.pretty_json(response) + "\n")
+
+ # decode
+ response_json = json.loads(response)
+
+ if (response_json.get("jsonrpc") != "2.0"):
+ return False, "Malfromed Response ({0})".format(str(response))
+
+ if (response_json.get("id") != id):
+ return False, "Server Replied With Bad ID ({0})".format(str(response))
+
+ # error reported by server
+ if ("error" in response_json):
+ return True, response_json["error"]["message"]
+
+ # if no error there should be a result
+ if ("result" not in response_json):
+ return False, "Malfromed Response ({0})".format(str(response))
+
+ return True, response_json["result"]
+
+
+ def ping_rpc_server (self):
+
+ return self.invoke_rpc_method("ping", block = False)
+
+ def get_rpc_server_status (self):
+ return self.invoke_rpc_method("get_status")
+
+ def query_rpc_server (self):
+ return self.invoke_rpc_method("get_reg_cmds")
+
+
+ def set_verbose (self, mode):
+ self.verbose = mode
+
+ def disconnect (self):
+ if self.connected:
+ self.socket.close(linger = 0)
+ self.context.destroy(linger = 0)
+ self.connected = False
+ return True, ""
+ else:
+ return False, "Not connected to server"
+
+ def connect (self, server = None, port = None):
+ if self.connected:
+ self.disconnect()
+
+ self.context = zmq.Context()
+
+ self.server = (server if server else self.server)
+ self.port = (port if port else self.port)
+
+ # Socket to talk to server
+ self.transport = "tcp://{0}:{1}".format(self.server, self.port)
+
+ print "\nConnecting To RPC Server On {0}".format(self.transport)
+
+ self.socket = self.context.socket(zmq.REQ)
+ try:
+ self.socket.connect(self.transport)
+ except zmq.error.ZMQError as e:
+ return False, "ZMQ Error: Bad server or port name: " + str(e)
+
+
+ self.connected = True
+
+ # ping the server
+ rc, err = self.ping_rpc_server()
+ if not rc:
+ self.disconnect()
+ return rc, err
+
+ return True, ""
+
+ def reconnect (self):
+ # connect using current values
+ return self.connect()
+
+ if not self.connected:
+ return False, "Not connected to server"
+
+ # reconnect
+ return self.connect(self.server, self.port)
+
+ def is_connected (self):
+ return self.connected
+
+ def __del__ (self):
+ print "Shutting down RPC client\n"
+ self.context.destroy(linger = 0)
+
+
+
diff --git a/src/console/trex_status.py b/src/console/trex_status.py new file mode 100755 index 00000000..8ee669b5 --- /dev/null +++ b/src/console/trex_status.py @@ -0,0 +1,212 @@ +from time import sleep + +import os + +import curses +from curses import panel +import random +import collections +import operator +import datetime + +g_curses_active = False + +# +def percentage (a, total): + x = int ((float(a) / total) * 100) + return str(x) + "%" + +# panel object +class TrexStatusPanel(): + def __init__ (self, h, l, y, x, headline): + self.h = h + self.l = l + self.y = y + self.x = x + self.headline = headline + + self.win = curses.newwin(h, l, y, x) + self.win.erase() + self.win.box() + + self.win.addstr(1, 2, headline, curses.A_UNDERLINE) + self.win.refresh() + + panel.new_panel(self.win) + self.panel = panel.new_panel(self.win) + self.panel.top() + + def clear (self): + self.win.erase() + self.win.box() + self.win.addstr(1, 2, self.headline, curses.A_UNDERLINE) + + def getwin (self): + return self.win + +def float_to_human_readable (size, suffix = "bps"): + for unit in ['','K','M','G']: + if abs(size) < 1024.0: + return "%3.1f %s%s" % (size, unit, suffix) + size /= 1024.0 + return "NaN" + +# status object +class TrexStatus(): + def __init__ (self, stdscr, rpc_client): + self.stdscr = stdscr + self.log = [] + self.rpc_client = rpc_client + + self.get_server_info() + + def get_server_info (self): + rc, msg = self.rpc_client.get_rpc_server_status() + + if rc: + self.server_status = msg + else: + self.server_status = None + + def add_log_event (self, msg): + self.log.append("[{0}] {1}".format(str(datetime.datetime.now().time()), msg)) + + def add_panel (self, h, l, y, x, headline): + win = curses.newwin(h, l, y, x) + win.erase() + win.box() + + win.addstr(1, 2, headline) + win.refresh() + + panel.new_panel(win) + panel1 = panel.new_panel(win) + panel1.top() + + return win, panel1 + + # static info panel + def update_info (self): + if self.server_status == None: + return + + self.info_panel.clear() + + connection_details = self.rpc_client.get_connection_details() + + self.info_panel.getwin().addstr(3, 2, "{:<30} {:30}".format("Server:", connection_details['server'] + ":" + str(connection_details['port']))) + self.info_panel.getwin().addstr(4, 2, "{:<30} {:30}".format("Version:", self.server_status["general"]["version"])) + self.info_panel.getwin().addstr(5, 2, "{:<30} {:30}".format("Build:", + self.server_status["general"]["build_date"] + " @ " + self.server_status["general"]["build_time"] + " by " + self.server_status["general"]["version_user"])) + + self.info_panel.getwin().addstr(6, 2, "{:<30} {:30}".format("Server Uptime:", self.server_status["general"]["uptime"])) + + # general stats + def update_general (self, gen_stats): + pass + + # control panel + def update_control (self): + self.control_panel.clear() + + self.control_panel.getwin().addstr(1, 2, "'f' - freeze, 'c' - clear stats, 'p' - ping server, 'q' - quit") + + index = 3 + + cut = len(self.log) - 4 + if cut < 0: + cut = 0 + + for l in self.log[cut:]: + self.control_panel.getwin().addstr(index, 2, l) + index += 1 + + def generate_layout (self): + self.max_y = self.stdscr.getmaxyx()[0] + self.max_x = self.stdscr.getmaxyx()[1] + + # create cls panel + self.main_panel = TrexStatusPanel(int(self.max_y * 0.8), self.max_x / 2, 0,0, "Trex Activity:") + + self.general_panel = TrexStatusPanel(int(self.max_y * 0.6), self.max_x / 2, 0, self.max_x /2, "General Statistics:") + + self.info_panel = TrexStatusPanel(int(self.max_y * 0.2), self.max_x / 2, int(self.max_y * 0.6), self.max_x /2, "Server Info:") + + self.control_panel = TrexStatusPanel(int(self.max_y * 0.2), self.max_x , int(self.max_y * 0.8), 0, "") + + panel.update_panels(); self.stdscr.refresh() + + def wait_for_key_input (self): + ch = self.stdscr.getch() + + if (ch != curses.ERR): + # stop/start status + if (ch == ord('f')): + self.update_active = not self.update_active + self.add_log_event("Update continued" if self.update_active else "Update stopped") + + elif (ch == ord('p')): + self.add_log_event("Pinging RPC server") + rc, msg = self.rpc_client.ping_rpc_server() + if rc: + self.add_log_event("Server replied: '{0}'".format(msg)) + else: + self.add_log_event("Failed to get reply") + + # c - clear stats + elif (ch == ord('c')): + self.add_log_event("Statistics cleared") + + elif (ch == ord('q')): + return False + else: + self.add_log_event("Unknown key pressed {0}".format("'" + chr(ch) + "'" if chr(ch).isalpha() else "")) + + return True + + # main run entry point + def run (self): + try: + curses.curs_set(0) + except: + pass + + curses.use_default_colors() + self.stdscr.nodelay(1) + curses.nonl() + curses.noecho() + + self.generate_layout() + + self.update_active = True + while (True): + + rc = self.wait_for_key_input() + if not rc: + break + + self.update_control() + self.update_info() + + panel.update_panels(); + self.stdscr.refresh() + sleep(0.1) + + +def show_trex_status_internal (stdscr, rpc_client): + trex_status = TrexStatus(stdscr, rpc_client) + trex_status.run() + +def show_trex_status (rpc_client): + + try: + curses.wrapper(show_trex_status_internal, rpc_client) + except KeyboardInterrupt: + curses.endwin() + +def cleanup (): + try: + curses.endwin() + except: + pass + diff --git a/src/console/zmq/__init__.py b/src/console/zmq/__init__.py new file mode 100755 index 00000000..3408b3ba --- /dev/null +++ b/src/console/zmq/__init__.py @@ -0,0 +1,64 @@ +"""Python bindings for 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import os +import sys +import glob + +# load bundled libzmq, if there is one: + +here = os.path.dirname(__file__) + +bundled = [] +bundled_sodium = [] +for ext in ('pyd', 'so', 'dll', 'dylib'): + bundled_sodium.extend(glob.glob(os.path.join(here, 'libsodium*.%s*' % ext))) + bundled.extend(glob.glob(os.path.join(here, 'libzmq*.%s*' % ext))) + +if bundled: + import ctypes + if bundled_sodium: + if bundled[0].endswith('.pyd'): + # a Windows Extension + _libsodium = ctypes.cdll.LoadLibrary(bundled_sodium[0]) + else: + _libsodium = ctypes.CDLL(bundled_sodium[0], mode=ctypes.RTLD_GLOBAL) + if bundled[0].endswith('.pyd'): + # a Windows Extension + _libzmq = ctypes.cdll.LoadLibrary(bundled[0]) + else: + _libzmq = ctypes.CDLL(bundled[0], mode=ctypes.RTLD_GLOBAL) + del ctypes +else: + import zipimport + try: + if isinstance(__loader__, zipimport.zipimporter): + # a zipped pyzmq egg + from zmq import libzmq as _libzmq + except (NameError, ImportError): + pass + finally: + del zipimport + +del os, sys, glob, here, bundled, bundled_sodium, ext + +# zmq top-level imports + +from zmq import backend +from zmq.backend import * +from zmq import sugar +from zmq.sugar import * +from zmq import devices + +def get_includes(): + """Return a list of directories to include for linking against pyzmq with cython.""" + from os.path import join, dirname, abspath, pardir + base = dirname(__file__) + parent = abspath(join(base, pardir)) + return [ parent ] + [ join(parent, base, subdir) for subdir in ('utils',) ] + + +__all__ = ['get_includes'] + sugar.__all__ + backend.__all__ + diff --git a/src/console/zmq/auth/__init__.py b/src/console/zmq/auth/__init__.py new file mode 100755 index 00000000..11d3ad6b --- /dev/null +++ b/src/console/zmq/auth/__init__.py @@ -0,0 +1,10 @@ +"""Utilities for ZAP authentication. + +To run authentication in a background thread, see :mod:`zmq.auth.thread`. +For integration with the tornado eventloop, see :mod:`zmq.auth.ioloop`. + +.. versionadded:: 14.1 +""" + +from .base import * +from .certs import * diff --git a/src/console/zmq/auth/base.py b/src/console/zmq/auth/base.py new file mode 100755 index 00000000..9b4aaed7 --- /dev/null +++ b/src/console/zmq/auth/base.py @@ -0,0 +1,272 @@ +"""Base implementation of 0MQ authentication.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging + +import zmq +from zmq.utils import z85 +from zmq.utils.strtypes import bytes, unicode, b, u +from zmq.error import _check_version + +from .certs import load_certificates + + +CURVE_ALLOW_ANY = '*' +VERSION = b'1.0' + +class Authenticator(object): + """Implementation of ZAP authentication for zmq connections. + + Note: + - libzmq provides four levels of security: default NULL (which the Authenticator does + not see), and authenticated NULL, PLAIN, and CURVE, which the Authenticator can see. + - until you add policies, all incoming NULL connections are allowed + (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied. + """ + + def __init__(self, context=None, encoding='utf-8', log=None): + _check_version((4,0), "security") + self.context = context or zmq.Context.instance() + self.encoding = encoding + self.allow_any = False + self.zap_socket = None + self.whitelist = set() + self.blacklist = set() + # passwords is a dict keyed by domain and contains values + # of dicts with username:password pairs. + self.passwords = {} + # certs is dict keyed by domain and contains values + # of dicts keyed by the public keys from the specified location. + self.certs = {} + self.log = log or logging.getLogger('zmq.auth') + + def start(self): + """Create and bind the ZAP socket""" + self.zap_socket = self.context.socket(zmq.REP) + self.zap_socket.linger = 1 + self.zap_socket.bind("inproc://zeromq.zap.01") + + def stop(self): + """Close the ZAP socket""" + if self.zap_socket: + self.zap_socket.close() + self.zap_socket = None + + def allow(self, *addresses): + """Allow (whitelist) IP address(es). + + Connections from addresses not in the whitelist will be rejected. + + - For NULL, all clients from this address will be accepted. + - For PLAIN and CURVE, they will be allowed to continue with authentication. + + whitelist is mutually exclusive with blacklist. + """ + if self.blacklist: + raise ValueError("Only use a whitelist or a blacklist, not both") + self.whitelist.update(addresses) + + def deny(self, *addresses): + """Deny (blacklist) IP address(es). + + Addresses not in the blacklist will be allowed to continue with authentication. + + Blacklist is mutually exclusive with whitelist. + """ + if self.whitelist: + raise ValueError("Only use a whitelist or a blacklist, not both") + self.blacklist.update(addresses) + + def configure_plain(self, domain='*', passwords=None): + """Configure PLAIN authentication for a given domain. + + PLAIN authentication uses a plain-text password file. + To cover all domains, use "*". + You can modify the password file at any time; it is reloaded automatically. + """ + if passwords: + self.passwords[domain] = passwords + + def configure_curve(self, domain='*', location=None): + """Configure CURVE authentication for a given domain. + + CURVE authentication uses a directory that holds all public client certificates, + i.e. their public keys. + + To cover all domains, use "*". + + You can add and remove certificates in that directory at any time. + + To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location. + """ + # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise + # treat location as a directory that holds the certificates. + if location == CURVE_ALLOW_ANY: + self.allow_any = True + else: + self.allow_any = False + try: + self.certs[domain] = load_certificates(location) + except Exception as e: + self.log.error("Failed to load CURVE certs from %s: %s", location, e) + + def handle_zap_message(self, msg): + """Perform ZAP authentication""" + if len(msg) < 6: + self.log.error("Invalid ZAP message, not enough frames: %r", msg) + if len(msg) < 2: + self.log.error("Not enough information to reply") + else: + self._send_zap_reply(msg[1], b"400", b"Not enough frames") + return + + version, request_id, domain, address, identity, mechanism = msg[:6] + credentials = msg[6:] + + domain = u(domain, self.encoding, 'replace') + address = u(address, self.encoding, 'replace') + + if (version != VERSION): + self.log.error("Invalid ZAP version: %r", msg) + self._send_zap_reply(request_id, b"400", b"Invalid version") + return + + self.log.debug("version: %r, request_id: %r, domain: %r," + " address: %r, identity: %r, mechanism: %r", + version, request_id, domain, + address, identity, mechanism, + ) + + + # Is address is explicitly whitelisted or blacklisted? + allowed = False + denied = False + reason = b"NO ACCESS" + + if self.whitelist: + if address in self.whitelist: + allowed = True + self.log.debug("PASSED (whitelist) address=%s", address) + else: + denied = True + reason = b"Address not in whitelist" + self.log.debug("DENIED (not in whitelist) address=%s", address) + + elif self.blacklist: + if address in self.blacklist: + denied = True + reason = b"Address is blacklisted" + self.log.debug("DENIED (blacklist) address=%s", address) + else: + allowed = True + self.log.debug("PASSED (not in blacklist) address=%s", address) + + # Perform authentication mechanism-specific checks if necessary + username = u("user") + if not denied: + + if mechanism == b'NULL' and not allowed: + # For NULL, we allow if the address wasn't blacklisted + self.log.debug("ALLOWED (NULL)") + allowed = True + + elif mechanism == b'PLAIN': + # For PLAIN, even a whitelisted address must authenticate + if len(credentials) != 2: + self.log.error("Invalid PLAIN credentials: %r", credentials) + self._send_zap_reply(request_id, b"400", b"Invalid credentials") + return + username, password = [ u(c, self.encoding, 'replace') for c in credentials ] + allowed, reason = self._authenticate_plain(domain, username, password) + + elif mechanism == b'CURVE': + # For CURVE, even a whitelisted address must authenticate + if len(credentials) != 1: + self.log.error("Invalid CURVE credentials: %r", credentials) + self._send_zap_reply(request_id, b"400", b"Invalid credentials") + return + key = credentials[0] + allowed, reason = self._authenticate_curve(domain, key) + + if allowed: + self._send_zap_reply(request_id, b"200", b"OK", username) + else: + self._send_zap_reply(request_id, b"400", reason) + + def _authenticate_plain(self, domain, username, password): + """PLAIN ZAP authentication""" + allowed = False + reason = b"" + if self.passwords: + # If no domain is not specified then use the default domain + if not domain: + domain = '*' + + if domain in self.passwords: + if username in self.passwords[domain]: + if password == self.passwords[domain][username]: + allowed = True + else: + reason = b"Invalid password" + else: + reason = b"Invalid username" + else: + reason = b"Invalid domain" + + if allowed: + self.log.debug("ALLOWED (PLAIN) domain=%s username=%s password=%s", + domain, username, password, + ) + else: + self.log.debug("DENIED %s", reason) + + else: + reason = b"No passwords defined" + self.log.debug("DENIED (PLAIN) %s", reason) + + return allowed, reason + + def _authenticate_curve(self, domain, client_key): + """CURVE ZAP authentication""" + allowed = False + reason = b"" + if self.allow_any: + allowed = True + reason = b"OK" + self.log.debug("ALLOWED (CURVE allow any client)") + else: + # If no explicit domain is specified then use the default domain + if not domain: + domain = '*' + + if domain in self.certs: + # The certs dict stores keys in z85 format, convert binary key to z85 bytes + z85_client_key = z85.encode(client_key) + if z85_client_key in self.certs[domain] or self.certs[domain] == b'OK': + allowed = True + reason = b"OK" + else: + reason = b"Unknown key" + + status = "ALLOWED" if allowed else "DENIED" + self.log.debug("%s (CURVE) domain=%s client_key=%s", + status, domain, z85_client_key, + ) + else: + reason = b"Unknown domain" + + return allowed, reason + + def _send_zap_reply(self, request_id, status_code, status_text, user_id='user'): + """Send a ZAP reply to finish the authentication.""" + user_id = user_id if status_code == b'200' else b'' + if isinstance(user_id, unicode): + user_id = user_id.encode(self.encoding, 'replace') + metadata = b'' # not currently used + self.log.debug("ZAP reply code=%s text=%s", status_code, status_text) + reply = [VERSION, request_id, status_code, status_text, user_id, metadata] + self.zap_socket.send_multipart(reply) + +__all__ = ['Authenticator', 'CURVE_ALLOW_ANY'] diff --git a/src/console/zmq/auth/certs.py b/src/console/zmq/auth/certs.py new file mode 100755 index 00000000..4d26ad7b --- /dev/null +++ b/src/console/zmq/auth/certs.py @@ -0,0 +1,119 @@ +"""0MQ authentication related functions and classes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import datetime +import glob +import io +import os +import zmq +from zmq.utils.strtypes import bytes, unicode, b, u + + +_cert_secret_banner = u("""# **** Generated on {0} by pyzmq **** +# ZeroMQ CURVE **Secret** Certificate +# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions. + +""") + +_cert_public_banner = u("""# **** Generated on {0} by pyzmq **** +# ZeroMQ CURVE Public Certificate +# Exchange securely, or use a secure mechanism to verify the contents +# of this file after exchange. Store public certificates in your home +# directory, in the .curve subdirectory. + +""") + +def _write_key_file(key_filename, banner, public_key, secret_key=None, metadata=None, encoding='utf-8'): + """Create a certificate file""" + if isinstance(public_key, bytes): + public_key = public_key.decode(encoding) + if isinstance(secret_key, bytes): + secret_key = secret_key.decode(encoding) + with io.open(key_filename, 'w', encoding='utf8') as f: + f.write(banner.format(datetime.datetime.now())) + + f.write(u('metadata\n')) + if metadata: + for k, v in metadata.items(): + if isinstance(v, bytes): + v = v.decode(encoding) + f.write(u(" {0} = {1}\n").format(k, v)) + + f.write(u('curve\n')) + f.write(u(" public-key = \"{0}\"\n").format(public_key)) + + if secret_key: + f.write(u(" secret-key = \"{0}\"\n").format(secret_key)) + + +def create_certificates(key_dir, name, metadata=None): + """Create zmq certificates. + + Returns the file paths to the public and secret certificate files. + """ + public_key, secret_key = zmq.curve_keypair() + base_filename = os.path.join(key_dir, name) + secret_key_file = "{0}.key_secret".format(base_filename) + public_key_file = "{0}.key".format(base_filename) + now = datetime.datetime.now() + + _write_key_file(public_key_file, + _cert_public_banner.format(now), + public_key) + + _write_key_file(secret_key_file, + _cert_secret_banner.format(now), + public_key, + secret_key=secret_key, + metadata=metadata) + + return public_key_file, secret_key_file + + +def load_certificate(filename): + """Load public and secret key from a zmq certificate. + + Returns (public_key, secret_key) + + If the certificate file only contains the public key, + secret_key will be None. + """ + public_key = None + secret_key = None + if not os.path.exists(filename): + raise IOError("Invalid certificate file: {0}".format(filename)) + + with open(filename, 'rb') as f: + for line in f: + line = line.strip() + if line.startswith(b'#'): + continue + if line.startswith(b'public-key'): + public_key = line.split(b"=", 1)[1].strip(b' \t\'"') + if line.startswith(b'secret-key'): + secret_key = line.split(b"=", 1)[1].strip(b' \t\'"') + if public_key and secret_key: + break + + return public_key, secret_key + + +def load_certificates(directory='.'): + """Load public keys from all certificates in a directory""" + certs = {} + if not os.path.isdir(directory): + raise IOError("Invalid certificate directory: {0}".format(directory)) + # Follow czmq pattern of public keys stored in *.key files. + glob_string = os.path.join(directory, "*.key") + + cert_files = glob.glob(glob_string) + for cert_file in cert_files: + public_key, _ = load_certificate(cert_file) + if public_key: + certs[public_key] = 'OK' + return certs + +__all__ = ['create_certificates', 'load_certificate', 'load_certificates'] diff --git a/src/console/zmq/auth/ioloop.py b/src/console/zmq/auth/ioloop.py new file mode 100755 index 00000000..1f448b47 --- /dev/null +++ b/src/console/zmq/auth/ioloop.py @@ -0,0 +1,34 @@ +"""ZAP Authenticator integrated with the tornado IOLoop. + +.. versionadded:: 14.1 +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.eventloop import ioloop, zmqstream +from .base import Authenticator + + +class IOLoopAuthenticator(Authenticator): + """ZAP authentication for use in the tornado IOLoop""" + + def __init__(self, context=None, encoding='utf-8', log=None, io_loop=None): + super(IOLoopAuthenticator, self).__init__(context) + self.zap_stream = None + self.io_loop = io_loop or ioloop.IOLoop.instance() + + def start(self): + """Start ZAP authentication""" + super(IOLoopAuthenticator, self).start() + self.zap_stream = zmqstream.ZMQStream(self.zap_socket, self.io_loop) + self.zap_stream.on_recv(self.handle_zap_message) + + def stop(self): + """Stop ZAP authentication""" + if self.zap_stream: + self.zap_stream.close() + self.zap_stream = None + super(IOLoopAuthenticator, self).stop() + +__all__ = ['IOLoopAuthenticator'] diff --git a/src/console/zmq/auth/thread.py b/src/console/zmq/auth/thread.py new file mode 100755 index 00000000..8c3355a9 --- /dev/null +++ b/src/console/zmq/auth/thread.py @@ -0,0 +1,184 @@ +"""ZAP Authenticator in a Python Thread. + +.. versionadded:: 14.1 +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging +from threading import Thread + +import zmq +from zmq.utils import jsonapi +from zmq.utils.strtypes import bytes, unicode, b, u + +from .base import Authenticator + +class AuthenticationThread(Thread): + """A Thread for running a zmq Authenticator + + This is run in the background by ThreadedAuthenticator + """ + + def __init__(self, context, endpoint, encoding='utf-8', log=None): + super(AuthenticationThread, self).__init__() + self.context = context or zmq.Context.instance() + self.encoding = encoding + self.log = log = log or logging.getLogger('zmq.auth') + self.authenticator = Authenticator(context, encoding=encoding, log=log) + + # create a socket to communicate back to main thread. + self.pipe = context.socket(zmq.PAIR) + self.pipe.linger = 1 + self.pipe.connect(endpoint) + + def run(self): + """ Start the Authentication Agent thread task """ + self.authenticator.start() + zap = self.authenticator.zap_socket + poller = zmq.Poller() + poller.register(self.pipe, zmq.POLLIN) + poller.register(zap, zmq.POLLIN) + while True: + try: + socks = dict(poller.poll()) + except zmq.ZMQError: + break # interrupted + + if self.pipe in socks and socks[self.pipe] == zmq.POLLIN: + terminate = self._handle_pipe() + if terminate: + break + + if zap in socks and socks[zap] == zmq.POLLIN: + self._handle_zap() + + self.pipe.close() + self.authenticator.stop() + + def _handle_zap(self): + """ + Handle a message from the ZAP socket. + """ + msg = self.authenticator.zap_socket.recv_multipart() + if not msg: return + self.authenticator.handle_zap_message(msg) + + def _handle_pipe(self): + """ + Handle a message from front-end API. + """ + terminate = False + + # Get the whole message off the pipe in one go + msg = self.pipe.recv_multipart() + + if msg is None: + terminate = True + return terminate + + command = msg[0] + self.log.debug("auth received API command %r", command) + + if command == b'ALLOW': + addresses = [u(m, self.encoding) for m in msg[1:]] + try: + self.authenticator.allow(*addresses) + except Exception as e: + self.log.exception("Failed to allow %s", addresses) + + elif command == b'DENY': + addresses = [u(m, self.encoding) for m in msg[1:]] + try: + self.authenticator.deny(*addresses) + except Exception as e: + self.log.exception("Failed to deny %s", addresses) + + elif command == b'PLAIN': + domain = u(msg[1], self.encoding) + json_passwords = msg[2] + self.authenticator.configure_plain(domain, jsonapi.loads(json_passwords)) + + elif command == b'CURVE': + # For now we don't do anything with domains + domain = u(msg[1], self.encoding) + + # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise + # treat location as a directory that holds the certificates. + location = u(msg[2], self.encoding) + self.authenticator.configure_curve(domain, location) + + elif command == b'TERMINATE': + terminate = True + + else: + self.log.error("Invalid auth command from API: %r", command) + + return terminate + +def _inherit_docstrings(cls): + """inherit docstrings from Authenticator, so we don't duplicate them""" + for name, method in cls.__dict__.items(): + if name.startswith('_'): + continue + upstream_method = getattr(Authenticator, name, None) + if not method.__doc__: + method.__doc__ = upstream_method.__doc__ + return cls + +@_inherit_docstrings +class ThreadAuthenticator(object): + """Run ZAP authentication in a background thread""" + + def __init__(self, context=None, encoding='utf-8', log=None): + self.context = context or zmq.Context.instance() + self.log = log + self.encoding = encoding + self.pipe = None + self.pipe_endpoint = "inproc://{0}.inproc".format(id(self)) + self.thread = None + + def allow(self, *addresses): + self.pipe.send_multipart([b'ALLOW'] + [b(a, self.encoding) for a in addresses]) + + def deny(self, *addresses): + self.pipe.send_multipart([b'DENY'] + [b(a, self.encoding) for a in addresses]) + + def configure_plain(self, domain='*', passwords=None): + self.pipe.send_multipart([b'PLAIN', b(domain, self.encoding), jsonapi.dumps(passwords or {})]) + + def configure_curve(self, domain='*', location=''): + domain = b(domain, self.encoding) + location = b(location, self.encoding) + self.pipe.send_multipart([b'CURVE', domain, location]) + + def start(self): + """Start the authentication thread""" + # create a socket to communicate with auth thread. + self.pipe = self.context.socket(zmq.PAIR) + self.pipe.linger = 1 + self.pipe.bind(self.pipe_endpoint) + self.thread = AuthenticationThread(self.context, self.pipe_endpoint, encoding=self.encoding, log=self.log) + self.thread.start() + + def stop(self): + """Stop the authentication thread""" + if self.pipe: + self.pipe.send(b'TERMINATE') + if self.is_alive(): + self.thread.join() + self.thread = None + self.pipe.close() + self.pipe = None + + def is_alive(self): + """Is the ZAP thread currently running?""" + if self.thread and self.thread.is_alive(): + return True + return False + + def __del__(self): + self.stop() + +__all__ = ['ThreadAuthenticator'] diff --git a/src/console/zmq/backend/__init__.py b/src/console/zmq/backend/__init__.py new file mode 100755 index 00000000..7cac725c --- /dev/null +++ b/src/console/zmq/backend/__init__.py @@ -0,0 +1,45 @@ +"""Import basic exposure of libzmq C API as a backend""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import os +import platform +import sys + +from zmq.utils.sixcerpt import reraise + +from .select import public_api, select_backend + +if 'PYZMQ_BACKEND' in os.environ: + backend = os.environ['PYZMQ_BACKEND'] + if backend in ('cython', 'cffi'): + backend = 'zmq.backend.%s' % backend + _ns = select_backend(backend) +else: + # default to cython, fallback to cffi + # (reverse on PyPy) + if platform.python_implementation() == 'PyPy': + first, second = ('zmq.backend.cffi', 'zmq.backend.cython') + else: + first, second = ('zmq.backend.cython', 'zmq.backend.cffi') + + try: + _ns = select_backend(first) + except Exception: + exc_info = sys.exc_info() + exc = exc_info[1] + try: + _ns = select_backend(second) + except ImportError: + # prevent 'During handling of the above exception...' on py3 + # can't use `raise ... from` on Python 2 + if hasattr(exc, '__cause__'): + exc.__cause__ = None + # raise the *first* error, not the fallback + reraise(*exc_info) + +globals().update(_ns) + +__all__ = public_api diff --git a/src/console/zmq/backend/cffi/__init__.py b/src/console/zmq/backend/cffi/__init__.py new file mode 100755 index 00000000..ca3164d3 --- /dev/null +++ b/src/console/zmq/backend/cffi/__init__.py @@ -0,0 +1,22 @@ +"""CFFI backend (for PyPY)""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.backend.cffi import (constants, error, message, context, socket, + _poll, devices, utils) + +__all__ = [] +for submod in (constants, error, message, context, socket, + _poll, devices, utils): + __all__.extend(submod.__all__) + +from .constants import * +from .error import * +from .message import * +from .context import * +from .socket import * +from .devices import * +from ._poll import * +from ._cffi import zmq_version_info, ffi +from .utils import * diff --git a/src/console/zmq/backend/cffi/_cdefs.h b/src/console/zmq/backend/cffi/_cdefs.h new file mode 100644 index 00000000..d3300575 --- /dev/null +++ b/src/console/zmq/backend/cffi/_cdefs.h @@ -0,0 +1,68 @@ +void zmq_version(int *major, int *minor, int *patch); + +void* zmq_socket(void *context, int type); +int zmq_close(void *socket); + +int zmq_bind(void *socket, const char *endpoint); +int zmq_connect(void *socket, const char *endpoint); + +int zmq_errno(void); +const char * zmq_strerror(int errnum); + +void* zmq_stopwatch_start(void); +unsigned long zmq_stopwatch_stop(void *watch); +void zmq_sleep(int seconds_); +int zmq_device(int device, void *frontend, void *backend); + +int zmq_unbind(void *socket, const char *endpoint); +int zmq_disconnect(void *socket, const char *endpoint); +void* zmq_ctx_new(); +int zmq_ctx_destroy(void *context); +int zmq_ctx_get(void *context, int opt); +int zmq_ctx_set(void *context, int opt, int optval); +int zmq_proxy(void *frontend, void *backend, void *capture); +int zmq_socket_monitor(void *socket, const char *addr, int events); + +int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key); +int zmq_has (const char *capability); + +typedef struct { ...; } zmq_msg_t; +typedef ... zmq_free_fn; + +int zmq_msg_init(zmq_msg_t *msg); +int zmq_msg_init_size(zmq_msg_t *msg, size_t size); +int zmq_msg_init_data(zmq_msg_t *msg, + void *data, + size_t size, + zmq_free_fn *ffn, + void *hint); + +size_t zmq_msg_size(zmq_msg_t *msg); +void *zmq_msg_data(zmq_msg_t *msg); +int zmq_msg_close(zmq_msg_t *msg); + +int zmq_msg_send(zmq_msg_t *msg, void *socket, int flags); +int zmq_msg_recv(zmq_msg_t *msg, void *socket, int flags); + +int zmq_getsockopt(void *socket, + int option_name, + void *option_value, + size_t *option_len); + +int zmq_setsockopt(void *socket, + int option_name, + const void *option_value, + size_t option_len); +typedef struct +{ + void *socket; + int fd; + short events; + short revents; +} zmq_pollitem_t; + +int zmq_poll(zmq_pollitem_t *items, int nitems, long timeout); + +// miscellany +void * memcpy(void *restrict s1, const void *restrict s2, size_t n); +int get_ipc_path_max_len(void); diff --git a/src/console/zmq/backend/cffi/_cffi.py b/src/console/zmq/backend/cffi/_cffi.py new file mode 100755 index 00000000..c73ebf83 --- /dev/null +++ b/src/console/zmq/backend/cffi/_cffi.py @@ -0,0 +1,127 @@ +# coding: utf-8 +"""The main CFFI wrapping of libzmq""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import json +import os +from os.path import dirname, join +from cffi import FFI + +from zmq.utils.constant_names import all_names, no_prefix + + +base_zmq_version = (3,2,2) + +def load_compiler_config(): + """load pyzmq compiler arguments""" + import zmq + zmq_dir = dirname(zmq.__file__) + zmq_parent = dirname(zmq_dir) + + fname = join(zmq_dir, 'utils', 'compiler.json') + if os.path.exists(fname): + with open(fname) as f: + cfg = json.load(f) + else: + cfg = {} + + cfg.setdefault("include_dirs", []) + cfg.setdefault("library_dirs", []) + cfg.setdefault("runtime_library_dirs", []) + cfg.setdefault("libraries", ["zmq"]) + + # cast to str, because cffi can't handle unicode paths (?!) + cfg['libraries'] = [str(lib) for lib in cfg['libraries']] + for key in ("include_dirs", "library_dirs", "runtime_library_dirs"): + # interpret paths relative to parent of zmq (like source tree) + abs_paths = [] + for p in cfg[key]: + if p.startswith('zmq'): + p = join(zmq_parent, p) + abs_paths.append(str(p)) + cfg[key] = abs_paths + return cfg + + +def zmq_version_info(): + """Get libzmq version as tuple of ints""" + major = ffi.new('int*') + minor = ffi.new('int*') + patch = ffi.new('int*') + + C.zmq_version(major, minor, patch) + + return (int(major[0]), int(minor[0]), int(patch[0])) + + +cfg = load_compiler_config() +ffi = FFI() + +def _make_defines(names): + _names = [] + for name in names: + define_line = "#define %s ..." % (name) + _names.append(define_line) + + return "\n".join(_names) + +c_constant_names = [] +for name in all_names: + if no_prefix(name): + c_constant_names.append(name) + else: + c_constant_names.append("ZMQ_" + name) + +# load ffi definitions +here = os.path.dirname(__file__) +with open(os.path.join(here, '_cdefs.h')) as f: + _cdefs = f.read() + +with open(os.path.join(here, '_verify.c')) as f: + _verify = f.read() + +ffi.cdef(_cdefs) +ffi.cdef(_make_defines(c_constant_names)) + +try: + C = ffi.verify(_verify, + modulename='_cffi_ext', + libraries=cfg['libraries'], + include_dirs=cfg['include_dirs'], + library_dirs=cfg['library_dirs'], + runtime_library_dirs=cfg['runtime_library_dirs'], + ) + _version_info = zmq_version_info() +except Exception as e: + raise ImportError("PyZMQ CFFI backend couldn't find zeromq: %s\n" + "Please check that you have zeromq headers and libraries." % e) + +if _version_info < (3,2,2): + raise ImportError("PyZMQ CFFI backend requires zeromq >= 3.2.2," + " but found %i.%i.%i" % _version_info + ) + +nsp = new_sizet_pointer = lambda length: ffi.new('size_t*', length) + +new_uint64_pointer = lambda: (ffi.new('uint64_t*'), + nsp(ffi.sizeof('uint64_t'))) +new_int64_pointer = lambda: (ffi.new('int64_t*'), + nsp(ffi.sizeof('int64_t'))) +new_int_pointer = lambda: (ffi.new('int*'), + nsp(ffi.sizeof('int'))) +new_binary_data = lambda length: (ffi.new('char[%d]' % (length)), + nsp(ffi.sizeof('char') * length)) + +value_uint64_pointer = lambda val : (ffi.new('uint64_t*', val), + ffi.sizeof('uint64_t')) +value_int64_pointer = lambda val: (ffi.new('int64_t*', val), + ffi.sizeof('int64_t')) +value_int_pointer = lambda val: (ffi.new('int*', val), + ffi.sizeof('int')) +value_binary_data = lambda val, length: (ffi.new('char[%d]' % (length + 1), val), + ffi.sizeof('char') * length) + +IPC_PATH_MAX_LEN = C.get_ipc_path_max_len() diff --git a/src/console/zmq/backend/cffi/_poll.py b/src/console/zmq/backend/cffi/_poll.py new file mode 100755 index 00000000..9bca34ca --- /dev/null +++ b/src/console/zmq/backend/cffi/_poll.py @@ -0,0 +1,56 @@ +# coding: utf-8 +"""zmq poll function""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import C, ffi, zmq_version_info + +from .constants import * + +from zmq.error import _check_rc + + +def _make_zmq_pollitem(socket, flags): + zmq_socket = socket._zmq_socket + zmq_pollitem = ffi.new('zmq_pollitem_t*') + zmq_pollitem.socket = zmq_socket + zmq_pollitem.fd = 0 + zmq_pollitem.events = flags + zmq_pollitem.revents = 0 + return zmq_pollitem[0] + +def _make_zmq_pollitem_fromfd(socket_fd, flags): + zmq_pollitem = ffi.new('zmq_pollitem_t*') + zmq_pollitem.socket = ffi.NULL + zmq_pollitem.fd = socket_fd + zmq_pollitem.events = flags + zmq_pollitem.revents = 0 + return zmq_pollitem[0] + +def zmq_poll(sockets, timeout): + cffi_pollitem_list = [] + low_level_to_socket_obj = {} + for item in sockets: + if isinstance(item[0], int): + low_level_to_socket_obj[item[0]] = item + cffi_pollitem_list.append(_make_zmq_pollitem_fromfd(item[0], item[1])) + else: + low_level_to_socket_obj[item[0]._zmq_socket] = item + cffi_pollitem_list.append(_make_zmq_pollitem(item[0], item[1])) + items = ffi.new('zmq_pollitem_t[]', cffi_pollitem_list) + list_length = ffi.cast('int', len(cffi_pollitem_list)) + c_timeout = ffi.cast('long', timeout) + rc = C.zmq_poll(items, list_length, c_timeout) + _check_rc(rc) + result = [] + for index in range(len(items)): + if not items[index].socket == ffi.NULL: + if items[index].revents > 0: + result.append((low_level_to_socket_obj[items[index].socket][0], + items[index].revents)) + else: + result.append((items[index].fd, items[index].revents)) + return result + +__all__ = ['zmq_poll'] diff --git a/src/console/zmq/backend/cffi/_verify.c b/src/console/zmq/backend/cffi/_verify.c new file mode 100644 index 00000000..547840eb --- /dev/null +++ b/src/console/zmq/backend/cffi/_verify.c @@ -0,0 +1,12 @@ +#include <stdio.h> +#include <sys/un.h> +#include <string.h> + +#include <zmq.h> +#include <zmq_utils.h> +#include "zmq_compat.h" + +int get_ipc_path_max_len(void) { + struct sockaddr_un *dummy; + return sizeof(dummy->sun_path) - 1; +} diff --git a/src/console/zmq/backend/cffi/constants.py b/src/console/zmq/backend/cffi/constants.py new file mode 100755 index 00000000..ee293e74 --- /dev/null +++ b/src/console/zmq/backend/cffi/constants.py @@ -0,0 +1,15 @@ +# coding: utf-8 +"""zmq constants""" + +from ._cffi import C, c_constant_names +from zmq.utils.constant_names import all_names + +g = globals() +for cname in c_constant_names: + if cname.startswith("ZMQ_"): + name = cname[4:] + else: + name = cname + g[name] = getattr(C, cname) + +__all__ = all_names diff --git a/src/console/zmq/backend/cffi/context.py b/src/console/zmq/backend/cffi/context.py new file mode 100755 index 00000000..16a7b257 --- /dev/null +++ b/src/console/zmq/backend/cffi/context.py @@ -0,0 +1,100 @@ +# coding: utf-8 +"""zmq Context class""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import weakref + +from ._cffi import C, ffi + +from .socket import * +from .constants import * + +from zmq.error import ZMQError, _check_rc + +class Context(object): + _zmq_ctx = None + _iothreads = None + _closed = None + _sockets = None + _shadow = False + + def __init__(self, io_threads=1, shadow=None): + + if shadow: + self._zmq_ctx = ffi.cast("void *", shadow) + self._shadow = True + else: + self._shadow = False + if not io_threads >= 0: + raise ZMQError(EINVAL) + + self._zmq_ctx = C.zmq_ctx_new() + if self._zmq_ctx == ffi.NULL: + raise ZMQError(C.zmq_errno()) + if not shadow: + C.zmq_ctx_set(self._zmq_ctx, IO_THREADS, io_threads) + self._closed = False + self._sockets = set() + + @property + def underlying(self): + """The address of the underlying libzmq context""" + return int(ffi.cast('size_t', self._zmq_ctx)) + + @property + def closed(self): + return self._closed + + def _add_socket(self, socket): + ref = weakref.ref(socket) + self._sockets.add(ref) + return ref + + def _rm_socket(self, ref): + if ref in self._sockets: + self._sockets.remove(ref) + + def set(self, option, value): + """set a context option + + see zmq_ctx_set + """ + rc = C.zmq_ctx_set(self._zmq_ctx, option, value) + _check_rc(rc) + + def get(self, option): + """get context option + + see zmq_ctx_get + """ + rc = C.zmq_ctx_get(self._zmq_ctx, option) + _check_rc(rc) + return rc + + def term(self): + if self.closed: + return + + C.zmq_ctx_destroy(self._zmq_ctx) + + self._zmq_ctx = None + self._closed = True + + def destroy(self, linger=None): + if self.closed: + return + + sockets = self._sockets + self._sockets = set() + for s in sockets: + s = s() + if s and not s.closed: + if linger: + s.setsockopt(LINGER, linger) + s.close() + + self.term() + +__all__ = ['Context'] diff --git a/src/console/zmq/backend/cffi/devices.py b/src/console/zmq/backend/cffi/devices.py new file mode 100755 index 00000000..c7a514a8 --- /dev/null +++ b/src/console/zmq/backend/cffi/devices.py @@ -0,0 +1,24 @@ +# coding: utf-8 +"""zmq device functions""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import C, ffi, zmq_version_info +from .socket import Socket +from zmq.error import ZMQError, _check_rc + +def device(device_type, frontend, backend): + rc = C.zmq_proxy(frontend._zmq_socket, backend._zmq_socket, ffi.NULL) + _check_rc(rc) + +def proxy(frontend, backend, capture=None): + if isinstance(capture, Socket): + capture = capture._zmq_socket + else: + capture = ffi.NULL + + rc = C.zmq_proxy(frontend._zmq_socket, backend._zmq_socket, capture) + _check_rc(rc) + +__all__ = ['device', 'proxy'] diff --git a/src/console/zmq/backend/cffi/error.py b/src/console/zmq/backend/cffi/error.py new file mode 100755 index 00000000..3bb64de0 --- /dev/null +++ b/src/console/zmq/backend/cffi/error.py @@ -0,0 +1,13 @@ +"""zmq error functions""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import C, ffi + +def strerror(errno): + return ffi.string(C.zmq_strerror(errno)) + +zmq_errno = C.zmq_errno + +__all__ = ['strerror', 'zmq_errno'] diff --git a/src/console/zmq/backend/cffi/message.py b/src/console/zmq/backend/cffi/message.py new file mode 100755 index 00000000..c35decb6 --- /dev/null +++ b/src/console/zmq/backend/cffi/message.py @@ -0,0 +1,69 @@ +"""Dummy Frame object""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import ffi, C + +import zmq +from zmq.utils.strtypes import unicode + +try: + view = memoryview +except NameError: + view = buffer + +_content = lambda x: x.tobytes() if type(x) == memoryview else x + +class Frame(object): + _data = None + tracker = None + closed = False + more = False + buffer = None + + + def __init__(self, data, track=False): + try: + view(data) + except TypeError: + raise + + self._data = data + + if isinstance(data, unicode): + raise TypeError("Unicode objects not allowed. Only: str/bytes, " + + "buffer interfaces.") + + self.more = False + self.tracker = None + self.closed = False + if track: + self.tracker = zmq.MessageTracker() + + self.buffer = view(self.bytes) + + @property + def bytes(self): + data = _content(self._data) + return data + + def __len__(self): + return len(self.bytes) + + def __eq__(self, other): + return self.bytes == _content(other) + + def __str__(self): + if str is unicode: + return self.bytes.decode() + else: + return self.bytes + + @property + def done(self): + return True + +Message = Frame + +__all__ = ['Frame', 'Message'] diff --git a/src/console/zmq/backend/cffi/socket.py b/src/console/zmq/backend/cffi/socket.py new file mode 100755 index 00000000..3c427739 --- /dev/null +++ b/src/console/zmq/backend/cffi/socket.py @@ -0,0 +1,244 @@ +# coding: utf-8 +"""zmq Socket class""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import random +import codecs + +import errno as errno_mod + +from ._cffi import (C, ffi, new_uint64_pointer, new_int64_pointer, + new_int_pointer, new_binary_data, value_uint64_pointer, + value_int64_pointer, value_int_pointer, value_binary_data, + IPC_PATH_MAX_LEN) + +from .message import Frame +from .constants import * + +import zmq +from zmq.error import ZMQError, _check_rc, _check_version +from zmq.utils.strtypes import unicode + + +def new_pointer_from_opt(option, length=0): + from zmq.sugar.constants import ( + int64_sockopts, bytes_sockopts, + ) + if option in int64_sockopts: + return new_int64_pointer() + elif option in bytes_sockopts: + return new_binary_data(length) + else: + # default + return new_int_pointer() + +def value_from_opt_pointer(option, opt_pointer, length=0): + from zmq.sugar.constants import ( + int64_sockopts, bytes_sockopts, + ) + if option in int64_sockopts: + return int(opt_pointer[0]) + elif option in bytes_sockopts: + return ffi.buffer(opt_pointer, length)[:] + else: + return int(opt_pointer[0]) + +def initialize_opt_pointer(option, value, length=0): + from zmq.sugar.constants import ( + int64_sockopts, bytes_sockopts, + ) + if option in int64_sockopts: + return value_int64_pointer(value) + elif option in bytes_sockopts: + return value_binary_data(value, length) + else: + return value_int_pointer(value) + + +class Socket(object): + context = None + socket_type = None + _zmq_socket = None + _closed = None + _ref = None + _shadow = False + + def __init__(self, context=None, socket_type=None, shadow=None): + self.context = context + if shadow is not None: + self._zmq_socket = ffi.cast("void *", shadow) + self._shadow = True + else: + self._shadow = False + self._zmq_socket = C.zmq_socket(context._zmq_ctx, socket_type) + if self._zmq_socket == ffi.NULL: + raise ZMQError() + self._closed = False + if context: + self._ref = context._add_socket(self) + + @property + def underlying(self): + """The address of the underlying libzmq socket""" + return int(ffi.cast('size_t', self._zmq_socket)) + + @property + def closed(self): + return self._closed + + def close(self, linger=None): + rc = 0 + if not self._closed and hasattr(self, '_zmq_socket'): + if self._zmq_socket is not None: + rc = C.zmq_close(self._zmq_socket) + self._closed = True + if self.context: + self.context._rm_socket(self._ref) + return rc + + def bind(self, address): + if isinstance(address, unicode): + address = address.encode('utf8') + rc = C.zmq_bind(self._zmq_socket, address) + if rc < 0: + if IPC_PATH_MAX_LEN and C.zmq_errno() == errno_mod.ENAMETOOLONG: + # py3compat: address is bytes, but msg wants str + if str is unicode: + address = address.decode('utf-8', 'replace') + path = address.split('://', 1)[-1] + msg = ('ipc path "{0}" is longer than {1} ' + 'characters (sizeof(sockaddr_un.sun_path)).' + .format(path, IPC_PATH_MAX_LEN)) + raise ZMQError(C.zmq_errno(), msg=msg) + else: + _check_rc(rc) + + def unbind(self, address): + _check_version((3,2), "unbind") + if isinstance(address, unicode): + address = address.encode('utf8') + rc = C.zmq_unbind(self._zmq_socket, address) + _check_rc(rc) + + def connect(self, address): + if isinstance(address, unicode): + address = address.encode('utf8') + rc = C.zmq_connect(self._zmq_socket, address) + _check_rc(rc) + + def disconnect(self, address): + _check_version((3,2), "disconnect") + if isinstance(address, unicode): + address = address.encode('utf8') + rc = C.zmq_disconnect(self._zmq_socket, address) + _check_rc(rc) + + def set(self, option, value): + length = None + if isinstance(value, unicode): + raise TypeError("unicode not allowed, use bytes") + + if isinstance(value, bytes): + if option not in zmq.constants.bytes_sockopts: + raise TypeError("not a bytes sockopt: %s" % option) + length = len(value) + + c_data = initialize_opt_pointer(option, value, length) + + c_value_pointer = c_data[0] + c_sizet = c_data[1] + + rc = C.zmq_setsockopt(self._zmq_socket, + option, + ffi.cast('void*', c_value_pointer), + c_sizet) + _check_rc(rc) + + def get(self, option): + c_data = new_pointer_from_opt(option, length=255) + + c_value_pointer = c_data[0] + c_sizet_pointer = c_data[1] + + rc = C.zmq_getsockopt(self._zmq_socket, + option, + c_value_pointer, + c_sizet_pointer) + _check_rc(rc) + + sz = c_sizet_pointer[0] + v = value_from_opt_pointer(option, c_value_pointer, sz) + if option != zmq.IDENTITY and option in zmq.constants.bytes_sockopts and v.endswith(b'\0'): + v = v[:-1] + return v + + def send(self, message, flags=0, copy=False, track=False): + if isinstance(message, unicode): + raise TypeError("Message must be in bytes, not an unicode Object") + + if isinstance(message, Frame): + message = message.bytes + + zmq_msg = ffi.new('zmq_msg_t*') + c_message = ffi.new('char[]', message) + rc = C.zmq_msg_init_size(zmq_msg, len(message)) + C.memcpy(C.zmq_msg_data(zmq_msg), c_message, len(message)) + + rc = C.zmq_msg_send(zmq_msg, self._zmq_socket, flags) + C.zmq_msg_close(zmq_msg) + _check_rc(rc) + + if track: + return zmq.MessageTracker() + + def recv(self, flags=0, copy=True, track=False): + zmq_msg = ffi.new('zmq_msg_t*') + C.zmq_msg_init(zmq_msg) + + rc = C.zmq_msg_recv(zmq_msg, self._zmq_socket, flags) + + if rc < 0: + C.zmq_msg_close(zmq_msg) + _check_rc(rc) + + _buffer = ffi.buffer(C.zmq_msg_data(zmq_msg), C.zmq_msg_size(zmq_msg)) + value = _buffer[:] + C.zmq_msg_close(zmq_msg) + + frame = Frame(value, track=track) + frame.more = self.getsockopt(RCVMORE) + + if copy: + return frame.bytes + else: + return frame + + def monitor(self, addr, events=-1): + """s.monitor(addr, flags) + + Start publishing socket events on inproc. + See libzmq docs for zmq_monitor for details. + + Note: requires libzmq >= 3.2 + + Parameters + ---------- + addr : str + The inproc url used for monitoring. Passing None as + the addr will cause an existing socket monitor to be + deregistered. + events : int [default: zmq.EVENT_ALL] + The zmq event bitmask for which events will be sent to the monitor. + """ + + _check_version((3,2), "monitor") + if events < 0: + events = zmq.EVENT_ALL + if addr is None: + addr = ffi.NULL + rc = C.zmq_socket_monitor(self._zmq_socket, addr, events) + + +__all__ = ['Socket', 'IPC_PATH_MAX_LEN'] diff --git a/src/console/zmq/backend/cffi/utils.py b/src/console/zmq/backend/cffi/utils.py new file mode 100755 index 00000000..fde7827b --- /dev/null +++ b/src/console/zmq/backend/cffi/utils.py @@ -0,0 +1,62 @@ +# coding: utf-8 +"""miscellaneous zmq_utils wrapping""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import ffi, C + +from zmq.error import ZMQError, _check_rc, _check_version +from zmq.utils.strtypes import unicode + +def has(capability): + """Check for zmq capability by name (e.g. 'ipc', 'curve') + + .. versionadded:: libzmq-4.1 + .. versionadded:: 14.1 + """ + _check_version((4,1), 'zmq.has') + if isinstance(capability, unicode): + capability = capability.encode('utf8') + return bool(C.zmq_has(capability)) + +def curve_keypair(): + """generate a Z85 keypair for use with zmq.CURVE security + + Requires libzmq (≥ 4.0) to have been linked with libsodium. + + Returns + ------- + (public, secret) : two bytestrings + The public and private keypair as 40 byte z85-encoded bytestrings. + """ + _check_version((3,2), "monitor") + public = ffi.new('char[64]') + private = ffi.new('char[64]') + rc = C.zmq_curve_keypair(public, private) + _check_rc(rc) + return ffi.buffer(public)[:40], ffi.buffer(private)[:40] + + +class Stopwatch(object): + def __init__(self): + self.watch = ffi.NULL + + def start(self): + if self.watch == ffi.NULL: + self.watch = C.zmq_stopwatch_start() + else: + raise ZMQError('Stopwatch is already runing.') + + def stop(self): + if self.watch == ffi.NULL: + raise ZMQError('Must start the Stopwatch before calling stop.') + else: + time = C.zmq_stopwatch_stop(self.watch) + self.watch = ffi.NULL + return time + + def sleep(self, seconds): + C.zmq_sleep(seconds) + +__all__ = ['has', 'curve_keypair', 'Stopwatch'] diff --git a/src/console/zmq/backend/cython/__init__.py b/src/console/zmq/backend/cython/__init__.py new file mode 100755 index 00000000..e5358185 --- /dev/null +++ b/src/console/zmq/backend/cython/__init__.py @@ -0,0 +1,23 @@ +"""Python bindings for core 0MQ objects.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Lesser GNU Public License (LGPL). + +from . import (constants, error, message, context, + socket, utils, _poll, _version, _device ) + +__all__ = [] +for submod in (constants, error, message, context, + socket, utils, _poll, _version, _device): + __all__.extend(submod.__all__) + +from .constants import * +from .error import * +from .message import * +from .context import * +from .socket import * +from ._poll import * +from .utils import * +from ._device import * +from ._version import * + diff --git a/src/console/zmq/backend/cython/_device.py b/src/console/zmq/backend/cython/_device.py new file mode 100755 index 00000000..3368ca2c --- /dev/null +++ b/src/console/zmq/backend/cython/_device.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'_device.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/_poll.py b/src/console/zmq/backend/cython/_poll.py new file mode 100755 index 00000000..cb1d5d77 --- /dev/null +++ b/src/console/zmq/backend/cython/_poll.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'_poll.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/_version.py b/src/console/zmq/backend/cython/_version.py new file mode 100755 index 00000000..08262706 --- /dev/null +++ b/src/console/zmq/backend/cython/_version.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'_version.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/checkrc.pxd b/src/console/zmq/backend/cython/checkrc.pxd new file mode 100644 index 00000000..3bf69fc3 --- /dev/null +++ b/src/console/zmq/backend/cython/checkrc.pxd @@ -0,0 +1,23 @@ +from libc.errno cimport EINTR, EAGAIN +from cpython cimport PyErr_CheckSignals +from libzmq cimport zmq_errno, ZMQ_ETERM + +cdef inline int _check_rc(int rc) except -1: + """internal utility for checking zmq return condition + + and raising the appropriate Exception class + """ + cdef int errno = zmq_errno() + PyErr_CheckSignals() + if rc < 0: + if errno == EAGAIN: + from zmq.error import Again + raise Again(errno) + elif errno == ZMQ_ETERM: + from zmq.error import ContextTerminated + raise ContextTerminated(errno) + else: + from zmq.error import ZMQError + raise ZMQError(errno) + # return -1 + return 0 diff --git a/src/console/zmq/backend/cython/constants.py b/src/console/zmq/backend/cython/constants.py new file mode 100755 index 00000000..ea772ac0 --- /dev/null +++ b/src/console/zmq/backend/cython/constants.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'constants.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/context.pxd b/src/console/zmq/backend/cython/context.pxd new file mode 100644 index 00000000..9c9267a5 --- /dev/null +++ b/src/console/zmq/backend/cython/context.pxd @@ -0,0 +1,41 @@ +"""0MQ Context class declaration.""" + +# +# Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +cdef class Context: + + cdef object __weakref__ # enable weakref + cdef void *handle # The C handle for the underlying zmq object. + cdef bint _shadow # whether the Context is a shadow wrapper of another + cdef void **_sockets # A C-array containg socket handles + cdef size_t _n_sockets # the number of sockets + cdef size_t _max_sockets # the size of the _sockets array + cdef int _pid # the pid of the process which created me (for fork safety) + + cdef public bint closed # bool property for a closed context. + cdef inline int _term(self) + # helpers for events on _sockets in Socket.__cinit__()/close() + cdef inline void _add_socket(self, void* handle) + cdef inline void _remove_socket(self, void* handle) + diff --git a/src/console/zmq/backend/cython/context.py b/src/console/zmq/backend/cython/context.py new file mode 100755 index 00000000..19f8ec7c --- /dev/null +++ b/src/console/zmq/backend/cython/context.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'context.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/error.py b/src/console/zmq/backend/cython/error.py new file mode 100755 index 00000000..d3a4ea0e --- /dev/null +++ b/src/console/zmq/backend/cython/error.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'error.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/libzmq.pxd b/src/console/zmq/backend/cython/libzmq.pxd new file mode 100644 index 00000000..e42f6d6b --- /dev/null +++ b/src/console/zmq/backend/cython/libzmq.pxd @@ -0,0 +1,110 @@ +"""All the C imports for 0MQ""" + +# +# Copyright (c) 2010 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +#----------------------------------------------------------------------------- +# Import the C header files +#----------------------------------------------------------------------------- + +cdef extern from *: + ctypedef void* const_void_ptr "const void *" + ctypedef char* const_char_ptr "const char *" + +cdef extern from "zmq_compat.h": + ctypedef signed long long int64_t "pyzmq_int64_t" + +include "constant_enums.pxi" + +cdef extern from "zmq.h" nogil: + + void _zmq_version "zmq_version"(int *major, int *minor, int *patch) + + ctypedef int fd_t "ZMQ_FD_T" + + enum: errno + char *zmq_strerror (int errnum) + int zmq_errno() + + void *zmq_ctx_new () + int zmq_ctx_destroy (void *context) + int zmq_ctx_set (void *context, int option, int optval) + int zmq_ctx_get (void *context, int option) + void *zmq_init (int io_threads) + int zmq_term (void *context) + + # blackbox def for zmq_msg_t + ctypedef void * zmq_msg_t "zmq_msg_t" + + ctypedef void zmq_free_fn(void *data, void *hint) + + int zmq_msg_init (zmq_msg_t *msg) + int zmq_msg_init_size (zmq_msg_t *msg, size_t size) + int zmq_msg_init_data (zmq_msg_t *msg, void *data, + size_t size, zmq_free_fn *ffn, void *hint) + int zmq_msg_send (zmq_msg_t *msg, void *s, int flags) + int zmq_msg_recv (zmq_msg_t *msg, void *s, int flags) + int zmq_msg_close (zmq_msg_t *msg) + int zmq_msg_move (zmq_msg_t *dest, zmq_msg_t *src) + int zmq_msg_copy (zmq_msg_t *dest, zmq_msg_t *src) + void *zmq_msg_data (zmq_msg_t *msg) + size_t zmq_msg_size (zmq_msg_t *msg) + int zmq_msg_more (zmq_msg_t *msg) + int zmq_msg_get (zmq_msg_t *msg, int option) + int zmq_msg_set (zmq_msg_t *msg, int option, int optval) + const_char_ptr zmq_msg_gets (zmq_msg_t *msg, const_char_ptr property) + int zmq_has (const_char_ptr capability) + + void *zmq_socket (void *context, int type) + int zmq_close (void *s) + int zmq_setsockopt (void *s, int option, void *optval, size_t optvallen) + int zmq_getsockopt (void *s, int option, void *optval, size_t *optvallen) + int zmq_bind (void *s, char *addr) + int zmq_connect (void *s, char *addr) + int zmq_unbind (void *s, char *addr) + int zmq_disconnect (void *s, char *addr) + + int zmq_socket_monitor (void *s, char *addr, int flags) + + # send/recv + int zmq_sendbuf (void *s, const_void_ptr buf, size_t n, int flags) + int zmq_recvbuf (void *s, void *buf, size_t n, int flags) + + ctypedef struct zmq_pollitem_t: + void *socket + int fd + short events + short revents + + int zmq_poll (zmq_pollitem_t *items, int nitems, long timeout) + + int zmq_device (int device_, void *insocket_, void *outsocket_) + int zmq_proxy (void *frontend, void *backend, void *capture) + +cdef extern from "zmq_utils.h" nogil: + + void *zmq_stopwatch_start () + unsigned long zmq_stopwatch_stop (void *watch_) + void zmq_sleep (int seconds_) + int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key) + diff --git a/src/console/zmq/backend/cython/message.pxd b/src/console/zmq/backend/cython/message.pxd new file mode 100644 index 00000000..4781195f --- /dev/null +++ b/src/console/zmq/backend/cython/message.pxd @@ -0,0 +1,63 @@ +"""0MQ Message related class declarations.""" + +# +# Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from cpython cimport PyBytes_FromStringAndSize + +from libzmq cimport zmq_msg_t, zmq_msg_data, zmq_msg_size + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +cdef class MessageTracker(object): + + cdef set events # Message Event objects to track. + cdef set peers # Other Message or MessageTracker objects. + + +cdef class Frame: + + cdef zmq_msg_t zmq_msg + cdef object _data # The actual message data as a Python object. + cdef object _buffer # A Python Buffer/View of the message contents + cdef object _bytes # A bytes/str copy of the message. + cdef bint _failed_init # Flag to handle failed zmq_msg_init + cdef public object tracker_event # Event for use with zmq_free_fn. + cdef public object tracker # MessageTracker object. + cdef public bint more # whether RCVMORE was set + + cdef Frame fast_copy(self) # Create shallow copy of Message object. + cdef object _getbuffer(self) # Construct self._buffer. + + +cdef inline object copy_zmq_msg_bytes(zmq_msg_t *zmq_msg): + """ Copy the data from a zmq_msg_t """ + cdef char *data_c = NULL + cdef Py_ssize_t data_len_c + data_c = <char *>zmq_msg_data(zmq_msg) + data_len_c = zmq_msg_size(zmq_msg) + return PyBytes_FromStringAndSize(data_c, data_len_c) + + diff --git a/src/console/zmq/backend/cython/message.py b/src/console/zmq/backend/cython/message.py new file mode 100755 index 00000000..5e423b62 --- /dev/null +++ b/src/console/zmq/backend/cython/message.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'message.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/socket.pxd b/src/console/zmq/backend/cython/socket.pxd new file mode 100644 index 00000000..b8a331e2 --- /dev/null +++ b/src/console/zmq/backend/cython/socket.pxd @@ -0,0 +1,47 @@ +"""0MQ Socket class declaration.""" + +# +# Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from context cimport Context + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + + +cdef class Socket: + + cdef object __weakref__ # enable weakref + cdef void *handle # The C handle for the underlying zmq object. + cdef bint _shadow # whether the Socket is a shadow wrapper of another + # Hold on to a reference to the context to make sure it is not garbage + # collected until the socket it done with it. + cdef public Context context # The zmq Context object that owns this. + cdef public bint _closed # bool property for a closed socket. + cdef int _pid # the pid of the process which created me (for fork safety) + + # cpdef methods for direct-cython access: + cpdef object send(self, object data, int flags=*, copy=*, track=*) + cpdef object recv(self, int flags=*, copy=*, track=*) + diff --git a/src/console/zmq/backend/cython/socket.py b/src/console/zmq/backend/cython/socket.py new file mode 100755 index 00000000..faef8bee --- /dev/null +++ b/src/console/zmq/backend/cython/socket.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'socket.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/cython/utils.pxd b/src/console/zmq/backend/cython/utils.pxd new file mode 100644 index 00000000..1d7117f1 --- /dev/null +++ b/src/console/zmq/backend/cython/utils.pxd @@ -0,0 +1,29 @@ +"""Wrap zmq_utils.h""" + +# +# Copyright (c) 2010 Brian E. Granger & Min Ragan-Kelley +# +# This file is part of pyzmq. +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + + +cdef class Stopwatch: + cdef void *watch # The C handle for the underlying zmq object + diff --git a/src/console/zmq/backend/cython/utils.py b/src/console/zmq/backend/cython/utils.py new file mode 100755 index 00000000..fe928300 --- /dev/null +++ b/src/console/zmq/backend/cython/utils.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'utils.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/backend/select.py b/src/console/zmq/backend/select.py new file mode 100755 index 00000000..0a2e09a2 --- /dev/null +++ b/src/console/zmq/backend/select.py @@ -0,0 +1,39 @@ +"""Import basic exposure of libzmq C API as a backend""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +public_api = [ + 'Context', + 'Socket', + 'Frame', + 'Message', + 'Stopwatch', + 'device', + 'proxy', + 'zmq_poll', + 'strerror', + 'zmq_errno', + 'has', + 'curve_keypair', + 'constants', + 'zmq_version_info', + 'IPC_PATH_MAX_LEN', +] + +def select_backend(name): + """Select the pyzmq backend""" + try: + mod = __import__(name, fromlist=public_api) + except ImportError: + raise + except Exception as e: + import sys + from zmq.utils.sixcerpt import reraise + exc_info = sys.exc_info() + reraise(ImportError, ImportError("Importing %s failed with %s" % (name, e)), exc_info[2]) + + ns = {} + for key in public_api: + ns[key] = getattr(mod, key) + return ns diff --git a/src/console/zmq/devices/__init__.py b/src/console/zmq/devices/__init__.py new file mode 100755 index 00000000..23715963 --- /dev/null +++ b/src/console/zmq/devices/__init__.py @@ -0,0 +1,16 @@ +"""0MQ Device classes for running in background threads or processes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq import device +from zmq.devices import basedevice, proxydevice, monitoredqueue, monitoredqueuedevice + +from zmq.devices.basedevice import * +from zmq.devices.proxydevice import * +from zmq.devices.monitoredqueue import * +from zmq.devices.monitoredqueuedevice import * + +__all__ = ['device'] +for submod in (basedevice, proxydevice, monitoredqueue, monitoredqueuedevice): + __all__.extend(submod.__all__) diff --git a/src/console/zmq/devices/basedevice.py b/src/console/zmq/devices/basedevice.py new file mode 100755 index 00000000..7ba1b7ac --- /dev/null +++ b/src/console/zmq/devices/basedevice.py @@ -0,0 +1,229 @@ +"""Classes for running 0MQ Devices in the background.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import time +from threading import Thread +from multiprocessing import Process + +from zmq import device, QUEUE, Context, ETERM, ZMQError + + +class Device: + """A 0MQ Device to be run in the background. + + You do not pass Socket instances to this, but rather Socket types:: + + Device(device_type, in_socket_type, out_socket_type) + + For instance:: + + dev = Device(zmq.QUEUE, zmq.DEALER, zmq.ROUTER) + + Similar to zmq.device, but socket types instead of sockets themselves are + passed, and the sockets are created in the work thread, to avoid issues + with thread safety. As a result, additional bind_{in|out} and + connect_{in|out} methods and setsockopt_{in|out} allow users to specify + connections for the sockets. + + Parameters + ---------- + device_type : int + The 0MQ Device type + {in|out}_type : int + zmq socket types, to be passed later to context.socket(). e.g. + zmq.PUB, zmq.SUB, zmq.REQ. If out_type is < 0, then in_socket is used + for both in_socket and out_socket. + + Methods + ------- + bind_{in_out}(iface) + passthrough for ``{in|out}_socket.bind(iface)``, to be called in the thread + connect_{in_out}(iface) + passthrough for ``{in|out}_socket.connect(iface)``, to be called in the + thread + setsockopt_{in_out}(opt,value) + passthrough for ``{in|out}_socket.setsockopt(opt, value)``, to be called in + the thread + + Attributes + ---------- + daemon : int + sets whether the thread should be run as a daemon + Default is true, because if it is false, the thread will not + exit unless it is killed + context_factory : callable (class attribute) + Function for creating the Context. This will be Context.instance + in ThreadDevices, and Context in ProcessDevices. The only reason + it is not instance() in ProcessDevices is that there may be a stale + Context instance already initialized, and the forked environment + should *never* try to use it. + """ + + context_factory = Context.instance + """Callable that returns a context. Typically either Context.instance or Context, + depending on whether the device should share the global instance or not. + """ + + def __init__(self, device_type=QUEUE, in_type=None, out_type=None): + self.device_type = device_type + if in_type is None: + raise TypeError("in_type must be specified") + if out_type is None: + raise TypeError("out_type must be specified") + self.in_type = in_type + self.out_type = out_type + self._in_binds = [] + self._in_connects = [] + self._in_sockopts = [] + self._out_binds = [] + self._out_connects = [] + self._out_sockopts = [] + self.daemon = True + self.done = False + + def bind_in(self, addr): + """Enqueue ZMQ address for binding on in_socket. + + See zmq.Socket.bind for details. + """ + self._in_binds.append(addr) + + def connect_in(self, addr): + """Enqueue ZMQ address for connecting on in_socket. + + See zmq.Socket.connect for details. + """ + self._in_connects.append(addr) + + def setsockopt_in(self, opt, value): + """Enqueue setsockopt(opt, value) for in_socket + + See zmq.Socket.setsockopt for details. + """ + self._in_sockopts.append((opt, value)) + + def bind_out(self, addr): + """Enqueue ZMQ address for binding on out_socket. + + See zmq.Socket.bind for details. + """ + self._out_binds.append(addr) + + def connect_out(self, addr): + """Enqueue ZMQ address for connecting on out_socket. + + See zmq.Socket.connect for details. + """ + self._out_connects.append(addr) + + def setsockopt_out(self, opt, value): + """Enqueue setsockopt(opt, value) for out_socket + + See zmq.Socket.setsockopt for details. + """ + self._out_sockopts.append((opt, value)) + + def _setup_sockets(self): + ctx = self.context_factory() + + self._context = ctx + + # create the sockets + ins = ctx.socket(self.in_type) + if self.out_type < 0: + outs = ins + else: + outs = ctx.socket(self.out_type) + + # set sockopts (must be done first, in case of zmq.IDENTITY) + for opt,value in self._in_sockopts: + ins.setsockopt(opt, value) + for opt,value in self._out_sockopts: + outs.setsockopt(opt, value) + + for iface in self._in_binds: + ins.bind(iface) + for iface in self._out_binds: + outs.bind(iface) + + for iface in self._in_connects: + ins.connect(iface) + for iface in self._out_connects: + outs.connect(iface) + + return ins,outs + + def run_device(self): + """The runner method. + + Do not call me directly, instead call ``self.start()``, just like a Thread. + """ + ins,outs = self._setup_sockets() + device(self.device_type, ins, outs) + + def run(self): + """wrap run_device in try/catch ETERM""" + try: + self.run_device() + except ZMQError as e: + if e.errno == ETERM: + # silence TERM errors, because this should be a clean shutdown + pass + else: + raise + finally: + self.done = True + + def start(self): + """Start the device. Override me in subclass for other launchers.""" + return self.run() + + def join(self,timeout=None): + """wait for me to finish, like Thread.join. + + Reimplemented appropriately by subclasses.""" + tic = time.time() + toc = tic + while not self.done and not (timeout is not None and toc-tic > timeout): + time.sleep(.001) + toc = time.time() + + +class BackgroundDevice(Device): + """Base class for launching Devices in background processes and threads.""" + + launcher=None + _launch_class=None + + def start(self): + self.launcher = self._launch_class(target=self.run) + self.launcher.daemon = self.daemon + return self.launcher.start() + + def join(self, timeout=None): + return self.launcher.join(timeout=timeout) + + +class ThreadDevice(BackgroundDevice): + """A Device that will be run in a background Thread. + + See Device for details. + """ + _launch_class=Thread + +class ProcessDevice(BackgroundDevice): + """A Device that will be run in a background Process. + + See Device for details. + """ + _launch_class=Process + context_factory = Context + """Callable that returns a context. Typically either Context.instance or Context, + depending on whether the device should share the global instance or not. + """ + + +__all__ = ['Device', 'ThreadDevice', 'ProcessDevice'] diff --git a/src/console/zmq/devices/monitoredqueue.pxd b/src/console/zmq/devices/monitoredqueue.pxd new file mode 100644 index 00000000..1e26ed86 --- /dev/null +++ b/src/console/zmq/devices/monitoredqueue.pxd @@ -0,0 +1,177 @@ +"""MonitoredQueue class declarations. + +Authors +------- +* MinRK +* Brian Granger +""" + +# +# Copyright (c) 2010 Min Ragan-Kelley, Brian Granger +# +# This file is part of pyzmq, but is derived and adapted from zmq_queue.cpp +# originally from libzmq-2.1.6, used under LGPLv3 +# +# pyzmq is free software; you can redistribute it and/or modify it under +# the terms of the Lesser GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# pyzmq is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Lesser GNU General Public License for more details. +# +# You should have received a copy of the Lesser GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from libzmq cimport * + +#----------------------------------------------------------------------------- +# MonitoredQueue C functions +#----------------------------------------------------------------------------- + +cdef inline int _relay(void *insocket_, void *outsocket_, void *sidesocket_, + zmq_msg_t msg, zmq_msg_t side_msg, zmq_msg_t id_msg, + bint swap_ids) nogil: + cdef int rc + cdef int64_t flag_2 + cdef int flag_3 + cdef int flags + cdef bint more + cdef size_t flagsz + cdef void * flag_ptr + + if ZMQ_VERSION_MAJOR < 3: + flagsz = sizeof (int64_t) + flag_ptr = &flag_2 + else: + flagsz = sizeof (int) + flag_ptr = &flag_3 + + if swap_ids:# both router, must send second identity first + # recv two ids into msg, id_msg + rc = zmq_msg_recv(&msg, insocket_, 0) + if rc < 0: return rc + + rc = zmq_msg_recv(&id_msg, insocket_, 0) + if rc < 0: return rc + + # send second id (id_msg) first + #!!!! always send a copy before the original !!!! + rc = zmq_msg_copy(&side_msg, &id_msg) + if rc < 0: return rc + rc = zmq_msg_send(&side_msg, outsocket_, ZMQ_SNDMORE) + if rc < 0: return rc + rc = zmq_msg_send(&id_msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + # send first id (msg) second + rc = zmq_msg_copy(&side_msg, &msg) + if rc < 0: return rc + rc = zmq_msg_send(&side_msg, outsocket_, ZMQ_SNDMORE) + if rc < 0: return rc + rc = zmq_msg_send(&msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + while (True): + rc = zmq_msg_recv(&msg, insocket_, 0) + if rc < 0: return rc + # assert (rc == 0) + rc = zmq_getsockopt (insocket_, ZMQ_RCVMORE, flag_ptr, &flagsz) + if rc < 0: return rc + flags = 0 + if ZMQ_VERSION_MAJOR < 3: + if flag_2: + flags |= ZMQ_SNDMORE + else: + if flag_3: + flags |= ZMQ_SNDMORE + # LABEL has been removed: + # rc = zmq_getsockopt (insocket_, ZMQ_RCVLABEL, flag_ptr, &flagsz) + # if flag_3: + # flags |= ZMQ_SNDLABEL + # assert (rc == 0) + + rc = zmq_msg_copy(&side_msg, &msg) + if rc < 0: return rc + if flags: + rc = zmq_msg_send(&side_msg, outsocket_, flags) + if rc < 0: return rc + # only SNDMORE for side-socket + rc = zmq_msg_send(&msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + else: + rc = zmq_msg_send(&side_msg, outsocket_, 0) + if rc < 0: return rc + rc = zmq_msg_send(&msg, sidesocket_, 0) + if rc < 0: return rc + break + return rc + +# the MonitoredQueue C function, adapted from zmq::queue.cpp : +cdef inline int c_monitored_queue (void *insocket_, void *outsocket_, + void *sidesocket_, zmq_msg_t *in_msg_ptr, + zmq_msg_t *out_msg_ptr, int swap_ids) nogil: + """The actual C function for a monitored queue device. + + See ``monitored_queue()`` for details. + """ + + cdef zmq_msg_t msg + cdef int rc = zmq_msg_init (&msg) + cdef zmq_msg_t id_msg + rc = zmq_msg_init (&id_msg) + if rc < 0: return rc + cdef zmq_msg_t side_msg + rc = zmq_msg_init (&side_msg) + if rc < 0: return rc + + cdef zmq_pollitem_t items [2] + items [0].socket = insocket_ + items [0].fd = 0 + items [0].events = ZMQ_POLLIN + items [0].revents = 0 + items [1].socket = outsocket_ + items [1].fd = 0 + items [1].events = ZMQ_POLLIN + items [1].revents = 0 + # I don't think sidesocket should be polled? + # items [2].socket = sidesocket_ + # items [2].fd = 0 + # items [2].events = ZMQ_POLLIN + # items [2].revents = 0 + + while (True): + + # // Wait while there are either requests or replies to process. + rc = zmq_poll (&items [0], 2, -1) + if rc < 0: return rc + # // The algorithm below asumes ratio of request and replies processed + # // under full load to be 1:1. Although processing requests replies + # // first is tempting it is suspectible to DoS attacks (overloading + # // the system with unsolicited replies). + # + # // Process a request. + if (items [0].revents & ZMQ_POLLIN): + # send in_prefix to side socket + rc = zmq_msg_copy(&side_msg, in_msg_ptr) + if rc < 0: return rc + rc = zmq_msg_send(&side_msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + # relay the rest of the message + rc = _relay(insocket_, outsocket_, sidesocket_, msg, side_msg, id_msg, swap_ids) + if rc < 0: return rc + if (items [1].revents & ZMQ_POLLIN): + # send out_prefix to side socket + rc = zmq_msg_copy(&side_msg, out_msg_ptr) + if rc < 0: return rc + rc = zmq_msg_send(&side_msg, sidesocket_, ZMQ_SNDMORE) + if rc < 0: return rc + # relay the rest of the message + rc = _relay(outsocket_, insocket_, sidesocket_, msg, side_msg, id_msg, swap_ids) + if rc < 0: return rc + return rc diff --git a/src/console/zmq/devices/monitoredqueue.py b/src/console/zmq/devices/monitoredqueue.py new file mode 100755 index 00000000..6d714e51 --- /dev/null +++ b/src/console/zmq/devices/monitoredqueue.py @@ -0,0 +1,7 @@ +def __bootstrap__(): + global __bootstrap__, __loader__, __file__ + import sys, pkg_resources, imp + __file__ = pkg_resources.resource_filename(__name__,'monitoredqueue.so') + __loader__ = None; del __bootstrap__, __loader__ + imp.load_dynamic(__name__,__file__) +__bootstrap__() diff --git a/src/console/zmq/devices/monitoredqueuedevice.py b/src/console/zmq/devices/monitoredqueuedevice.py new file mode 100755 index 00000000..9723f866 --- /dev/null +++ b/src/console/zmq/devices/monitoredqueuedevice.py @@ -0,0 +1,66 @@ +"""MonitoredQueue classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from zmq import ZMQError, PUB +from zmq.devices.proxydevice import ProxyBase, Proxy, ThreadProxy, ProcessProxy +from zmq.devices.monitoredqueue import monitored_queue + + +class MonitoredQueueBase(ProxyBase): + """Base class for overriding methods.""" + + _in_prefix = b'' + _out_prefix = b'' + + def __init__(self, in_type, out_type, mon_type=PUB, in_prefix=b'in', out_prefix=b'out'): + + ProxyBase.__init__(self, in_type=in_type, out_type=out_type, mon_type=mon_type) + + self._in_prefix = in_prefix + self._out_prefix = out_prefix + + def run_device(self): + ins,outs,mons = self._setup_sockets() + monitored_queue(ins, outs, mons, self._in_prefix, self._out_prefix) + + +class MonitoredQueue(MonitoredQueueBase, Proxy): + """Class for running monitored_queue in the background. + + See zmq.devices.Device for most of the spec. MonitoredQueue differs from Proxy, + only in that it adds a ``prefix`` to messages sent on the monitor socket, + with a different prefix for each direction. + + MQ also supports ROUTER on both sides, which zmq.proxy does not. + + If a message arrives on `in_sock`, it will be prefixed with `in_prefix` on the monitor socket. + If it arrives on out_sock, it will be prefixed with `out_prefix`. + + A PUB socket is the most logical choice for the mon_socket, but it is not required. + """ + pass + + +class ThreadMonitoredQueue(MonitoredQueueBase, ThreadProxy): + """Run zmq.monitored_queue in a background thread. + + See MonitoredQueue and Proxy for details. + """ + pass + + +class ProcessMonitoredQueue(MonitoredQueueBase, ProcessProxy): + """Run zmq.monitored_queue in a background thread. + + See MonitoredQueue and Proxy for details. + """ + + +__all__ = [ + 'MonitoredQueue', + 'ThreadMonitoredQueue', + 'ProcessMonitoredQueue' +] diff --git a/src/console/zmq/devices/proxydevice.py b/src/console/zmq/devices/proxydevice.py new file mode 100755 index 00000000..68be3f15 --- /dev/null +++ b/src/console/zmq/devices/proxydevice.py @@ -0,0 +1,90 @@ +"""Proxy classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq +from zmq.devices.basedevice import Device, ThreadDevice, ProcessDevice + + +class ProxyBase(object): + """Base class for overriding methods.""" + + def __init__(self, in_type, out_type, mon_type=zmq.PUB): + + Device.__init__(self, in_type=in_type, out_type=out_type) + self.mon_type = mon_type + self._mon_binds = [] + self._mon_connects = [] + self._mon_sockopts = [] + + def bind_mon(self, addr): + """Enqueue ZMQ address for binding on mon_socket. + + See zmq.Socket.bind for details. + """ + self._mon_binds.append(addr) + + def connect_mon(self, addr): + """Enqueue ZMQ address for connecting on mon_socket. + + See zmq.Socket.bind for details. + """ + self._mon_connects.append(addr) + + def setsockopt_mon(self, opt, value): + """Enqueue setsockopt(opt, value) for mon_socket + + See zmq.Socket.setsockopt for details. + """ + self._mon_sockopts.append((opt, value)) + + def _setup_sockets(self): + ins,outs = Device._setup_sockets(self) + ctx = self._context + mons = ctx.socket(self.mon_type) + + # set sockopts (must be done first, in case of zmq.IDENTITY) + for opt,value in self._mon_sockopts: + mons.setsockopt(opt, value) + + for iface in self._mon_binds: + mons.bind(iface) + + for iface in self._mon_connects: + mons.connect(iface) + + return ins,outs,mons + + def run_device(self): + ins,outs,mons = self._setup_sockets() + zmq.proxy(ins, outs, mons) + +class Proxy(ProxyBase, Device): + """Threadsafe Proxy object. + + See zmq.devices.Device for most of the spec. This subclass adds a + <method>_mon version of each <method>_{in|out} method, for configuring the + monitor socket. + + A Proxy is a 3-socket ZMQ Device that functions just like a + QUEUE, except each message is also sent out on the monitor socket. + + A PUB socket is the most logical choice for the mon_socket, but it is not required. + """ + pass + +class ThreadProxy(ProxyBase, ThreadDevice): + """Proxy in a Thread. See Proxy for more.""" + pass + +class ProcessProxy(ProxyBase, ProcessDevice): + """Proxy in a Process. See Proxy for more.""" + pass + + +__all__ = [ + 'Proxy', + 'ThreadProxy', + 'ProcessProxy', +] diff --git a/src/console/zmq/error.py b/src/console/zmq/error.py new file mode 100755 index 00000000..48cdaafa --- /dev/null +++ b/src/console/zmq/error.py @@ -0,0 +1,164 @@ +"""0MQ Error classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +class ZMQBaseError(Exception): + """Base exception class for 0MQ errors in Python.""" + pass + +class ZMQError(ZMQBaseError): + """Wrap an errno style error. + + Parameters + ---------- + errno : int + The ZMQ errno or None. If None, then ``zmq_errno()`` is called and + used. + msg : string + Description of the error or None. + """ + errno = None + + def __init__(self, errno=None, msg=None): + """Wrap an errno style error. + + Parameters + ---------- + errno : int + The ZMQ errno or None. If None, then ``zmq_errno()`` is called and + used. + msg : string + Description of the error or None. + """ + from zmq.backend import strerror, zmq_errno + if errno is None: + errno = zmq_errno() + if isinstance(errno, int): + self.errno = errno + if msg is None: + self.strerror = strerror(errno) + else: + self.strerror = msg + else: + if msg is None: + self.strerror = str(errno) + else: + self.strerror = msg + # flush signals, because there could be a SIGINT + # waiting to pounce, resulting in uncaught exceptions. + # Doing this here means getting SIGINT during a blocking + # libzmq call will raise a *catchable* KeyboardInterrupt + # PyErr_CheckSignals() + + def __str__(self): + return self.strerror + + def __repr__(self): + return "ZMQError('%s')"%self.strerror + + +class ZMQBindError(ZMQBaseError): + """An error for ``Socket.bind_to_random_port()``. + + See Also + -------- + .Socket.bind_to_random_port + """ + pass + + +class NotDone(ZMQBaseError): + """Raised when timeout is reached while waiting for 0MQ to finish with a Message + + See Also + -------- + .MessageTracker.wait : object for tracking when ZeroMQ is done + """ + pass + + +class ContextTerminated(ZMQError): + """Wrapper for zmq.ETERM + + .. versionadded:: 13.0 + """ + pass + + +class Again(ZMQError): + """Wrapper for zmq.EAGAIN + + .. versionadded:: 13.0 + """ + pass + + +def _check_rc(rc, errno=None): + """internal utility for checking zmq return condition + + and raising the appropriate Exception class + """ + if rc < 0: + from zmq.backend import zmq_errno + if errno is None: + errno = zmq_errno() + from zmq import EAGAIN, ETERM + if errno == EAGAIN: + raise Again(errno) + elif errno == ETERM: + raise ContextTerminated(errno) + else: + raise ZMQError(errno) + +_zmq_version_info = None +_zmq_version = None + +class ZMQVersionError(NotImplementedError): + """Raised when a feature is not provided by the linked version of libzmq. + + .. versionadded:: 14.2 + """ + min_version = None + def __init__(self, min_version, msg='Feature'): + global _zmq_version + if _zmq_version is None: + from zmq import zmq_version + _zmq_version = zmq_version() + self.msg = msg + self.min_version = min_version + self.version = _zmq_version + + def __repr__(self): + return "ZMQVersionError('%s')" % str(self) + + def __str__(self): + return "%s requires libzmq >= %s, have %s" % (self.msg, self.min_version, self.version) + + +def _check_version(min_version_info, msg='Feature'): + """Check for libzmq + + raises ZMQVersionError if current zmq version is not at least min_version + + min_version_info is a tuple of integers, and will be compared against zmq.zmq_version_info(). + """ + global _zmq_version_info + if _zmq_version_info is None: + from zmq import zmq_version_info + _zmq_version_info = zmq_version_info() + if _zmq_version_info < min_version_info: + min_version = '.'.join(str(v) for v in min_version_info) + raise ZMQVersionError(min_version, msg) + + +__all__ = [ + 'ZMQBaseError', + 'ZMQBindError', + 'ZMQError', + 'NotDone', + 'ContextTerminated', + 'Again', + 'ZMQVersionError', +] diff --git a/src/console/zmq/eventloop/__init__.py b/src/console/zmq/eventloop/__init__.py new file mode 100755 index 00000000..568e8e8d --- /dev/null +++ b/src/console/zmq/eventloop/__init__.py @@ -0,0 +1,5 @@ +"""A Tornado based event loop for PyZMQ.""" + +from zmq.eventloop.ioloop import IOLoop + +__all__ = ['IOLoop']
\ No newline at end of file diff --git a/src/console/zmq/eventloop/ioloop.py b/src/console/zmq/eventloop/ioloop.py new file mode 100755 index 00000000..35f4c418 --- /dev/null +++ b/src/console/zmq/eventloop/ioloop.py @@ -0,0 +1,193 @@ +# coding: utf-8 +"""tornado IOLoop API with zmq compatibility + +If you have tornado ≥ 3.0, this is a subclass of tornado's IOLoop, +otherwise we ship a minimal subset of tornado in zmq.eventloop.minitornado. + +The minimal shipped version of tornado's IOLoop does not include +support for concurrent futures - this will only be available if you +have tornado ≥ 3.0. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import absolute_import, division, with_statement + +import os +import time +import warnings + +from zmq import ( + Poller, + POLLIN, POLLOUT, POLLERR, + ZMQError, ETERM, +) + +try: + import tornado + tornado_version = tornado.version_info +except (ImportError, AttributeError): + tornado_version = () + +try: + # tornado ≥ 3 + from tornado.ioloop import PollIOLoop, PeriodicCallback + from tornado.log import gen_log +except ImportError: + from .minitornado.ioloop import PollIOLoop, PeriodicCallback + from .minitornado.log import gen_log + + +class DelayedCallback(PeriodicCallback): + """Schedules the given callback to be called once. + + The callback is called once, after callback_time milliseconds. + + `start` must be called after the DelayedCallback is created. + + The timeout is calculated from when `start` is called. + """ + def __init__(self, callback, callback_time, io_loop=None): + # PeriodicCallback require callback_time to be positive + warnings.warn("""DelayedCallback is deprecated. + Use loop.add_timeout instead.""", DeprecationWarning) + callback_time = max(callback_time, 1e-3) + super(DelayedCallback, self).__init__(callback, callback_time, io_loop) + + def start(self): + """Starts the timer.""" + self._running = True + self._firstrun = True + self._next_timeout = time.time() + self.callback_time / 1000.0 + self.io_loop.add_timeout(self._next_timeout, self._run) + + def _run(self): + if not self._running: return + self._running = False + try: + self.callback() + except Exception: + gen_log.error("Error in delayed callback", exc_info=True) + + +class ZMQPoller(object): + """A poller that can be used in the tornado IOLoop. + + This simply wraps a regular zmq.Poller, scaling the timeout + by 1000, so that it is in seconds rather than milliseconds. + """ + + def __init__(self): + self._poller = Poller() + + @staticmethod + def _map_events(events): + """translate IOLoop.READ/WRITE/ERROR event masks into zmq.POLLIN/OUT/ERR""" + z_events = 0 + if events & IOLoop.READ: + z_events |= POLLIN + if events & IOLoop.WRITE: + z_events |= POLLOUT + if events & IOLoop.ERROR: + z_events |= POLLERR + return z_events + + @staticmethod + def _remap_events(z_events): + """translate zmq.POLLIN/OUT/ERR event masks into IOLoop.READ/WRITE/ERROR""" + events = 0 + if z_events & POLLIN: + events |= IOLoop.READ + if z_events & POLLOUT: + events |= IOLoop.WRITE + if z_events & POLLERR: + events |= IOLoop.ERROR + return events + + def register(self, fd, events): + return self._poller.register(fd, self._map_events(events)) + + def modify(self, fd, events): + return self._poller.modify(fd, self._map_events(events)) + + def unregister(self, fd): + return self._poller.unregister(fd) + + def poll(self, timeout): + """poll in seconds rather than milliseconds. + + Event masks will be IOLoop.READ/WRITE/ERROR + """ + z_events = self._poller.poll(1000*timeout) + return [ (fd,self._remap_events(evt)) for (fd,evt) in z_events ] + + def close(self): + pass + + +class ZMQIOLoop(PollIOLoop): + """ZMQ subclass of tornado's IOLoop""" + def initialize(self, impl=None, **kwargs): + impl = ZMQPoller() if impl is None else impl + super(ZMQIOLoop, self).initialize(impl=impl, **kwargs) + + @staticmethod + def instance(): + """Returns a global `IOLoop` instance. + + Most applications have a single, global `IOLoop` running on the + main thread. Use this method to get this instance from + another thread. To get the current thread's `IOLoop`, use `current()`. + """ + # install ZMQIOLoop as the active IOLoop implementation + # when using tornado 3 + if tornado_version >= (3,): + PollIOLoop.configure(ZMQIOLoop) + return PollIOLoop.instance() + + def start(self): + try: + super(ZMQIOLoop, self).start() + except ZMQError as e: + if e.errno == ETERM: + # quietly return on ETERM + pass + else: + raise e + + +if tornado_version >= (3,0) and tornado_version < (3,1): + def backport_close(self, all_fds=False): + """backport IOLoop.close to 3.0 from 3.1 (supports fd.close() method)""" + from zmq.eventloop.minitornado.ioloop import PollIOLoop as mini_loop + return mini_loop.close.__get__(self)(all_fds) + ZMQIOLoop.close = backport_close + + +# public API name +IOLoop = ZMQIOLoop + + +def install(): + """set the tornado IOLoop instance with the pyzmq IOLoop. + + After calling this function, tornado's IOLoop.instance() and pyzmq's + IOLoop.instance() will return the same object. + + An assertion error will be raised if tornado's IOLoop has been initialized + prior to calling this function. + """ + from tornado import ioloop + # check if tornado's IOLoop is already initialized to something other + # than the pyzmq IOLoop instance: + assert (not ioloop.IOLoop.initialized()) or \ + ioloop.IOLoop.instance() is IOLoop.instance(), "tornado IOLoop already initialized" + + if tornado_version >= (3,): + # tornado 3 has an official API for registering new defaults, yay! + ioloop.IOLoop.configure(ZMQIOLoop) + else: + # we have to set the global instance explicitly + ioloop.IOLoop._instance = IOLoop.instance() + diff --git a/src/console/zmq/eventloop/minitornado/__init__.py b/src/console/zmq/eventloop/minitornado/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/__init__.py diff --git a/src/console/zmq/eventloop/minitornado/concurrent.py b/src/console/zmq/eventloop/minitornado/concurrent.py new file mode 100755 index 00000000..519b23d5 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/concurrent.py @@ -0,0 +1,11 @@ +"""pyzmq does not ship tornado's futures, +this just raises informative NotImplementedErrors to avoid having to change too much code. +""" + +class NotImplementedFuture(object): + def __init__(self, *args, **kwargs): + raise NotImplementedError("pyzmq does not ship tornado's Futures, " + "install tornado >= 3.0 for future support." + ) + +Future = TracebackFuture = NotImplementedFuture diff --git a/src/console/zmq/eventloop/minitornado/ioloop.py b/src/console/zmq/eventloop/minitornado/ioloop.py new file mode 100755 index 00000000..710a3ecb --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/ioloop.py @@ -0,0 +1,829 @@ +#!/usr/bin/env python +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""An I/O event loop for non-blocking sockets. + +Typical applications will use a single `IOLoop` object, in the +`IOLoop.instance` singleton. The `IOLoop.start` method should usually +be called at the end of the ``main()`` function. Atypical applications may +use more than one `IOLoop`, such as one `IOLoop` per thread, or per `unittest` +case. + +In addition to I/O events, the `IOLoop` can also schedule time-based events. +`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import datetime +import errno +import functools +import heapq +import logging +import numbers +import os +import select +import sys +import threading +import time +import traceback + +from .concurrent import Future, TracebackFuture +from .log import app_log, gen_log +from . import stack_context +from .util import Configurable + +try: + import signal +except ImportError: + signal = None + +try: + import thread # py2 +except ImportError: + import _thread as thread # py3 + +from .platform.auto import set_close_exec, Waker + + +class TimeoutError(Exception): + pass + + +class IOLoop(Configurable): + """A level-triggered I/O loop. + + We use ``epoll`` (Linux) or ``kqueue`` (BSD and Mac OS X) if they + are available, or else we fall back on select(). If you are + implementing a system that needs to handle thousands of + simultaneous connections, you should use a system that supports + either ``epoll`` or ``kqueue``. + + Example usage for a simple TCP server:: + + import errno + import functools + import ioloop + import socket + + def connection_ready(sock, fd, events): + while True: + try: + connection, address = sock.accept() + except socket.error, e: + if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN): + raise + return + connection.setblocking(0) + handle_connection(connection, address) + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(0) + sock.bind(("", port)) + sock.listen(128) + + io_loop = ioloop.IOLoop.instance() + callback = functools.partial(connection_ready, sock) + io_loop.add_handler(sock.fileno(), callback, io_loop.READ) + io_loop.start() + + """ + # Constants from the epoll module + _EPOLLIN = 0x001 + _EPOLLPRI = 0x002 + _EPOLLOUT = 0x004 + _EPOLLERR = 0x008 + _EPOLLHUP = 0x010 + _EPOLLRDHUP = 0x2000 + _EPOLLONESHOT = (1 << 30) + _EPOLLET = (1 << 31) + + # Our events map exactly to the epoll events + NONE = 0 + READ = _EPOLLIN + WRITE = _EPOLLOUT + ERROR = _EPOLLERR | _EPOLLHUP + + # Global lock for creating global IOLoop instance + _instance_lock = threading.Lock() + + _current = threading.local() + + @staticmethod + def instance(): + """Returns a global `IOLoop` instance. + + Most applications have a single, global `IOLoop` running on the + main thread. Use this method to get this instance from + another thread. To get the current thread's `IOLoop`, use `current()`. + """ + if not hasattr(IOLoop, "_instance"): + with IOLoop._instance_lock: + if not hasattr(IOLoop, "_instance"): + # New instance after double check + IOLoop._instance = IOLoop() + return IOLoop._instance + + @staticmethod + def initialized(): + """Returns true if the singleton instance has been created.""" + return hasattr(IOLoop, "_instance") + + def install(self): + """Installs this `IOLoop` object as the singleton instance. + + This is normally not necessary as `instance()` will create + an `IOLoop` on demand, but you may want to call `install` to use + a custom subclass of `IOLoop`. + """ + assert not IOLoop.initialized() + IOLoop._instance = self + + @staticmethod + def current(): + """Returns the current thread's `IOLoop`. + + If an `IOLoop` is currently running or has been marked as current + by `make_current`, returns that instance. Otherwise returns + `IOLoop.instance()`, i.e. the main thread's `IOLoop`. + + A common pattern for classes that depend on ``IOLoops`` is to use + a default argument to enable programs with multiple ``IOLoops`` + but not require the argument for simpler applications:: + + class MyClass(object): + def __init__(self, io_loop=None): + self.io_loop = io_loop or IOLoop.current() + + In general you should use `IOLoop.current` as the default when + constructing an asynchronous object, and use `IOLoop.instance` + when you mean to communicate to the main thread from a different + one. + """ + current = getattr(IOLoop._current, "instance", None) + if current is None: + return IOLoop.instance() + return current + + def make_current(self): + """Makes this the `IOLoop` for the current thread. + + An `IOLoop` automatically becomes current for its thread + when it is started, but it is sometimes useful to call + `make_current` explictly before starting the `IOLoop`, + so that code run at startup time can find the right + instance. + """ + IOLoop._current.instance = self + + @staticmethod + def clear_current(): + IOLoop._current.instance = None + + @classmethod + def configurable_base(cls): + return IOLoop + + @classmethod + def configurable_default(cls): + # this is the only patch to IOLoop: + from zmq.eventloop.ioloop import ZMQIOLoop + return ZMQIOLoop + # the remainder of this method is unused, + # but left for preservation reasons + if hasattr(select, "epoll"): + from tornado.platform.epoll import EPollIOLoop + return EPollIOLoop + if hasattr(select, "kqueue"): + # Python 2.6+ on BSD or Mac + from tornado.platform.kqueue import KQueueIOLoop + return KQueueIOLoop + from tornado.platform.select import SelectIOLoop + return SelectIOLoop + + def initialize(self): + pass + + def close(self, all_fds=False): + """Closes the `IOLoop`, freeing any resources used. + + If ``all_fds`` is true, all file descriptors registered on the + IOLoop will be closed (not just the ones created by the + `IOLoop` itself). + + Many applications will only use a single `IOLoop` that runs for the + entire lifetime of the process. In that case closing the `IOLoop` + is not necessary since everything will be cleaned up when the + process exits. `IOLoop.close` is provided mainly for scenarios + such as unit tests, which create and destroy a large number of + ``IOLoops``. + + An `IOLoop` must be completely stopped before it can be closed. This + means that `IOLoop.stop()` must be called *and* `IOLoop.start()` must + be allowed to return before attempting to call `IOLoop.close()`. + Therefore the call to `close` will usually appear just after + the call to `start` rather than near the call to `stop`. + + .. versionchanged:: 3.1 + If the `IOLoop` implementation supports non-integer objects + for "file descriptors", those objects will have their + ``close`` method when ``all_fds`` is true. + """ + raise NotImplementedError() + + def add_handler(self, fd, handler, events): + """Registers the given handler to receive the given events for fd. + + The ``events`` argument is a bitwise or of the constants + ``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``. + + When an event occurs, ``handler(fd, events)`` will be run. + """ + raise NotImplementedError() + + def update_handler(self, fd, events): + """Changes the events we listen for fd.""" + raise NotImplementedError() + + def remove_handler(self, fd): + """Stop listening for events on fd.""" + raise NotImplementedError() + + def set_blocking_signal_threshold(self, seconds, action): + """Sends a signal if the `IOLoop` is blocked for more than + ``s`` seconds. + + Pass ``seconds=None`` to disable. Requires Python 2.6 on a unixy + platform. + + The action parameter is a Python signal handler. Read the + documentation for the `signal` module for more information. + If ``action`` is None, the process will be killed if it is + blocked for too long. + """ + raise NotImplementedError() + + def set_blocking_log_threshold(self, seconds): + """Logs a stack trace if the `IOLoop` is blocked for more than + ``s`` seconds. + + Equivalent to ``set_blocking_signal_threshold(seconds, + self.log_stack)`` + """ + self.set_blocking_signal_threshold(seconds, self.log_stack) + + def log_stack(self, signal, frame): + """Signal handler to log the stack trace of the current thread. + + For use with `set_blocking_signal_threshold`. + """ + gen_log.warning('IOLoop blocked for %f seconds in\n%s', + self._blocking_signal_threshold, + ''.join(traceback.format_stack(frame))) + + def start(self): + """Starts the I/O loop. + + The loop will run until one of the callbacks calls `stop()`, which + will make the loop stop after the current event iteration completes. + """ + raise NotImplementedError() + + def stop(self): + """Stop the I/O loop. + + If the event loop is not currently running, the next call to `start()` + will return immediately. + + To use asynchronous methods from otherwise-synchronous code (such as + unit tests), you can start and stop the event loop like this:: + + ioloop = IOLoop() + async_method(ioloop=ioloop, callback=ioloop.stop) + ioloop.start() + + ``ioloop.start()`` will return after ``async_method`` has run + its callback, whether that callback was invoked before or + after ``ioloop.start``. + + Note that even after `stop` has been called, the `IOLoop` is not + completely stopped until `IOLoop.start` has also returned. + Some work that was scheduled before the call to `stop` may still + be run before the `IOLoop` shuts down. + """ + raise NotImplementedError() + + def run_sync(self, func, timeout=None): + """Starts the `IOLoop`, runs the given function, and stops the loop. + + If the function returns a `.Future`, the `IOLoop` will run + until the future is resolved. If it raises an exception, the + `IOLoop` will stop and the exception will be re-raised to the + caller. + + The keyword-only argument ``timeout`` may be used to set + a maximum duration for the function. If the timeout expires, + a `TimeoutError` is raised. + + This method is useful in conjunction with `tornado.gen.coroutine` + to allow asynchronous calls in a ``main()`` function:: + + @gen.coroutine + def main(): + # do stuff... + + if __name__ == '__main__': + IOLoop.instance().run_sync(main) + """ + future_cell = [None] + + def run(): + try: + result = func() + except Exception: + future_cell[0] = TracebackFuture() + future_cell[0].set_exc_info(sys.exc_info()) + else: + if isinstance(result, Future): + future_cell[0] = result + else: + future_cell[0] = Future() + future_cell[0].set_result(result) + self.add_future(future_cell[0], lambda future: self.stop()) + self.add_callback(run) + if timeout is not None: + timeout_handle = self.add_timeout(self.time() + timeout, self.stop) + self.start() + if timeout is not None: + self.remove_timeout(timeout_handle) + if not future_cell[0].done(): + raise TimeoutError('Operation timed out after %s seconds' % timeout) + return future_cell[0].result() + + def time(self): + """Returns the current time according to the `IOLoop`'s clock. + + The return value is a floating-point number relative to an + unspecified time in the past. + + By default, the `IOLoop`'s time function is `time.time`. However, + it may be configured to use e.g. `time.monotonic` instead. + Calls to `add_timeout` that pass a number instead of a + `datetime.timedelta` should use this function to compute the + appropriate time, so they can work no matter what time function + is chosen. + """ + return time.time() + + def add_timeout(self, deadline, callback): + """Runs the ``callback`` at the time ``deadline`` from the I/O loop. + + Returns an opaque handle that may be passed to + `remove_timeout` to cancel. + + ``deadline`` may be a number denoting a time (on the same + scale as `IOLoop.time`, normally `time.time`), or a + `datetime.timedelta` object for a deadline relative to the + current time. + + Note that it is not safe to call `add_timeout` from other threads. + Instead, you must use `add_callback` to transfer control to the + `IOLoop`'s thread, and then call `add_timeout` from there. + """ + raise NotImplementedError() + + def remove_timeout(self, timeout): + """Cancels a pending timeout. + + The argument is a handle as returned by `add_timeout`. It is + safe to call `remove_timeout` even if the callback has already + been run. + """ + raise NotImplementedError() + + def add_callback(self, callback, *args, **kwargs): + """Calls the given callback on the next I/O loop iteration. + + It is safe to call this method from any thread at any time, + except from a signal handler. Note that this is the **only** + method in `IOLoop` that makes this thread-safety guarantee; all + other interaction with the `IOLoop` must be done from that + `IOLoop`'s thread. `add_callback()` may be used to transfer + control from other threads to the `IOLoop`'s thread. + + To add a callback from a signal handler, see + `add_callback_from_signal`. + """ + raise NotImplementedError() + + def add_callback_from_signal(self, callback, *args, **kwargs): + """Calls the given callback on the next I/O loop iteration. + + Safe for use from a Python signal handler; should not be used + otherwise. + + Callbacks added with this method will be run without any + `.stack_context`, to avoid picking up the context of the function + that was interrupted by the signal. + """ + raise NotImplementedError() + + def add_future(self, future, callback): + """Schedules a callback on the ``IOLoop`` when the given + `.Future` is finished. + + The callback is invoked with one argument, the + `.Future`. + """ + assert isinstance(future, Future) + callback = stack_context.wrap(callback) + future.add_done_callback( + lambda future: self.add_callback(callback, future)) + + def _run_callback(self, callback): + """Runs a callback with error handling. + + For use in subclasses. + """ + try: + callback() + except Exception: + self.handle_callback_exception(callback) + + def handle_callback_exception(self, callback): + """This method is called whenever a callback run by the `IOLoop` + throws an exception. + + By default simply logs the exception as an error. Subclasses + may override this method to customize reporting of exceptions. + + The exception itself is not passed explicitly, but is available + in `sys.exc_info`. + """ + app_log.error("Exception in callback %r", callback, exc_info=True) + + +class PollIOLoop(IOLoop): + """Base class for IOLoops built around a select-like function. + + For concrete implementations, see `tornado.platform.epoll.EPollIOLoop` + (Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or + `tornado.platform.select.SelectIOLoop` (all platforms). + """ + def initialize(self, impl, time_func=None): + super(PollIOLoop, self).initialize() + self._impl = impl + if hasattr(self._impl, 'fileno'): + set_close_exec(self._impl.fileno()) + self.time_func = time_func or time.time + self._handlers = {} + self._events = {} + self._callbacks = [] + self._callback_lock = threading.Lock() + self._timeouts = [] + self._cancellations = 0 + self._running = False + self._stopped = False + self._closing = False + self._thread_ident = None + self._blocking_signal_threshold = None + + # Create a pipe that we send bogus data to when we want to wake + # the I/O loop when it is idle + self._waker = Waker() + self.add_handler(self._waker.fileno(), + lambda fd, events: self._waker.consume(), + self.READ) + + def close(self, all_fds=False): + with self._callback_lock: + self._closing = True + self.remove_handler(self._waker.fileno()) + if all_fds: + for fd in self._handlers.keys(): + try: + close_method = getattr(fd, 'close', None) + if close_method is not None: + close_method() + else: + os.close(fd) + except Exception: + gen_log.debug("error closing fd %s", fd, exc_info=True) + self._waker.close() + self._impl.close() + + def add_handler(self, fd, handler, events): + self._handlers[fd] = stack_context.wrap(handler) + self._impl.register(fd, events | self.ERROR) + + def update_handler(self, fd, events): + self._impl.modify(fd, events | self.ERROR) + + def remove_handler(self, fd): + self._handlers.pop(fd, None) + self._events.pop(fd, None) + try: + self._impl.unregister(fd) + except Exception: + gen_log.debug("Error deleting fd from IOLoop", exc_info=True) + + def set_blocking_signal_threshold(self, seconds, action): + if not hasattr(signal, "setitimer"): + gen_log.error("set_blocking_signal_threshold requires a signal module " + "with the setitimer method") + return + self._blocking_signal_threshold = seconds + if seconds is not None: + signal.signal(signal.SIGALRM, + action if action is not None else signal.SIG_DFL) + + def start(self): + if not logging.getLogger().handlers: + # The IOLoop catches and logs exceptions, so it's + # important that log output be visible. However, python's + # default behavior for non-root loggers (prior to python + # 3.2) is to print an unhelpful "no handlers could be + # found" message rather than the actual log entry, so we + # must explicitly configure logging if we've made it this + # far without anything. + logging.basicConfig() + if self._stopped: + self._stopped = False + return + old_current = getattr(IOLoop._current, "instance", None) + IOLoop._current.instance = self + self._thread_ident = thread.get_ident() + self._running = True + + # signal.set_wakeup_fd closes a race condition in event loops: + # a signal may arrive at the beginning of select/poll/etc + # before it goes into its interruptible sleep, so the signal + # will be consumed without waking the select. The solution is + # for the (C, synchronous) signal handler to write to a pipe, + # which will then be seen by select. + # + # In python's signal handling semantics, this only matters on the + # main thread (fortunately, set_wakeup_fd only works on the main + # thread and will raise a ValueError otherwise). + # + # If someone has already set a wakeup fd, we don't want to + # disturb it. This is an issue for twisted, which does its + # SIGCHILD processing in response to its own wakeup fd being + # written to. As long as the wakeup fd is registered on the IOLoop, + # the loop will still wake up and everything should work. + old_wakeup_fd = None + if hasattr(signal, 'set_wakeup_fd') and os.name == 'posix': + # requires python 2.6+, unix. set_wakeup_fd exists but crashes + # the python process on windows. + try: + old_wakeup_fd = signal.set_wakeup_fd(self._waker.write_fileno()) + if old_wakeup_fd != -1: + # Already set, restore previous value. This is a little racy, + # but there's no clean get_wakeup_fd and in real use the + # IOLoop is just started once at the beginning. + signal.set_wakeup_fd(old_wakeup_fd) + old_wakeup_fd = None + except ValueError: # non-main thread + pass + + while True: + poll_timeout = 3600.0 + + # Prevent IO event starvation by delaying new callbacks + # to the next iteration of the event loop. + with self._callback_lock: + callbacks = self._callbacks + self._callbacks = [] + for callback in callbacks: + self._run_callback(callback) + + if self._timeouts: + now = self.time() + while self._timeouts: + if self._timeouts[0].callback is None: + # the timeout was cancelled + heapq.heappop(self._timeouts) + self._cancellations -= 1 + elif self._timeouts[0].deadline <= now: + timeout = heapq.heappop(self._timeouts) + self._run_callback(timeout.callback) + else: + seconds = self._timeouts[0].deadline - now + poll_timeout = min(seconds, poll_timeout) + break + if (self._cancellations > 512 + and self._cancellations > (len(self._timeouts) >> 1)): + # Clean up the timeout queue when it gets large and it's + # more than half cancellations. + self._cancellations = 0 + self._timeouts = [x for x in self._timeouts + if x.callback is not None] + heapq.heapify(self._timeouts) + + if self._callbacks: + # If any callbacks or timeouts called add_callback, + # we don't want to wait in poll() before we run them. + poll_timeout = 0.0 + + if not self._running: + break + + if self._blocking_signal_threshold is not None: + # clear alarm so it doesn't fire while poll is waiting for + # events. + signal.setitimer(signal.ITIMER_REAL, 0, 0) + + try: + event_pairs = self._impl.poll(poll_timeout) + except Exception as e: + # Depending on python version and IOLoop implementation, + # different exception types may be thrown and there are + # two ways EINTR might be signaled: + # * e.errno == errno.EINTR + # * e.args is like (errno.EINTR, 'Interrupted system call') + if (getattr(e, 'errno', None) == errno.EINTR or + (isinstance(getattr(e, 'args', None), tuple) and + len(e.args) == 2 and e.args[0] == errno.EINTR)): + continue + else: + raise + + if self._blocking_signal_threshold is not None: + signal.setitimer(signal.ITIMER_REAL, + self._blocking_signal_threshold, 0) + + # Pop one fd at a time from the set of pending fds and run + # its handler. Since that handler may perform actions on + # other file descriptors, there may be reentrant calls to + # this IOLoop that update self._events + self._events.update(event_pairs) + while self._events: + fd, events = self._events.popitem() + try: + self._handlers[fd](fd, events) + except (OSError, IOError) as e: + if e.args[0] == errno.EPIPE: + # Happens when the client closes the connection + pass + else: + app_log.error("Exception in I/O handler for fd %s", + fd, exc_info=True) + except Exception: + app_log.error("Exception in I/O handler for fd %s", + fd, exc_info=True) + # reset the stopped flag so another start/stop pair can be issued + self._stopped = False + if self._blocking_signal_threshold is not None: + signal.setitimer(signal.ITIMER_REAL, 0, 0) + IOLoop._current.instance = old_current + if old_wakeup_fd is not None: + signal.set_wakeup_fd(old_wakeup_fd) + + def stop(self): + self._running = False + self._stopped = True + self._waker.wake() + + def time(self): + return self.time_func() + + def add_timeout(self, deadline, callback): + timeout = _Timeout(deadline, stack_context.wrap(callback), self) + heapq.heappush(self._timeouts, timeout) + return timeout + + def remove_timeout(self, timeout): + # Removing from a heap is complicated, so just leave the defunct + # timeout object in the queue (see discussion in + # http://docs.python.org/library/heapq.html). + # If this turns out to be a problem, we could add a garbage + # collection pass whenever there are too many dead timeouts. + timeout.callback = None + self._cancellations += 1 + + def add_callback(self, callback, *args, **kwargs): + with self._callback_lock: + if self._closing: + raise RuntimeError("IOLoop is closing") + list_empty = not self._callbacks + self._callbacks.append(functools.partial( + stack_context.wrap(callback), *args, **kwargs)) + if list_empty and thread.get_ident() != self._thread_ident: + # If we're in the IOLoop's thread, we know it's not currently + # polling. If we're not, and we added the first callback to an + # empty list, we may need to wake it up (it may wake up on its + # own, but an occasional extra wake is harmless). Waking + # up a polling IOLoop is relatively expensive, so we try to + # avoid it when we can. + self._waker.wake() + + def add_callback_from_signal(self, callback, *args, **kwargs): + with stack_context.NullContext(): + if thread.get_ident() != self._thread_ident: + # if the signal is handled on another thread, we can add + # it normally (modulo the NullContext) + self.add_callback(callback, *args, **kwargs) + else: + # If we're on the IOLoop's thread, we cannot use + # the regular add_callback because it may deadlock on + # _callback_lock. Blindly insert into self._callbacks. + # This is safe because the GIL makes list.append atomic. + # One subtlety is that if the signal interrupted the + # _callback_lock block in IOLoop.start, we may modify + # either the old or new version of self._callbacks, + # but either way will work. + self._callbacks.append(functools.partial( + stack_context.wrap(callback), *args, **kwargs)) + + +class _Timeout(object): + """An IOLoop timeout, a UNIX timestamp and a callback""" + + # Reduce memory overhead when there are lots of pending callbacks + __slots__ = ['deadline', 'callback'] + + def __init__(self, deadline, callback, io_loop): + if isinstance(deadline, numbers.Real): + self.deadline = deadline + elif isinstance(deadline, datetime.timedelta): + self.deadline = io_loop.time() + _Timeout.timedelta_to_seconds(deadline) + else: + raise TypeError("Unsupported deadline %r" % deadline) + self.callback = callback + + @staticmethod + def timedelta_to_seconds(td): + """Equivalent to td.total_seconds() (introduced in python 2.7).""" + return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6) + + # Comparison methods to sort by deadline, with object id as a tiebreaker + # to guarantee a consistent ordering. The heapq module uses __le__ + # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons + # use __lt__). + def __lt__(self, other): + return ((self.deadline, id(self)) < + (other.deadline, id(other))) + + def __le__(self, other): + return ((self.deadline, id(self)) <= + (other.deadline, id(other))) + + +class PeriodicCallback(object): + """Schedules the given callback to be called periodically. + + The callback is called every ``callback_time`` milliseconds. + + `start` must be called after the `PeriodicCallback` is created. + """ + def __init__(self, callback, callback_time, io_loop=None): + self.callback = callback + if callback_time <= 0: + raise ValueError("Periodic callback must have a positive callback_time") + self.callback_time = callback_time + self.io_loop = io_loop or IOLoop.current() + self._running = False + self._timeout = None + + def start(self): + """Starts the timer.""" + self._running = True + self._next_timeout = self.io_loop.time() + self._schedule_next() + + def stop(self): + """Stops the timer.""" + self._running = False + if self._timeout is not None: + self.io_loop.remove_timeout(self._timeout) + self._timeout = None + + def _run(self): + if not self._running: + return + try: + self.callback() + except Exception: + app_log.error("Error in periodic callback", exc_info=True) + self._schedule_next() + + def _schedule_next(self): + if self._running: + current_time = self.io_loop.time() + while self._next_timeout <= current_time: + self._next_timeout += self.callback_time / 1000.0 + self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run) diff --git a/src/console/zmq/eventloop/minitornado/log.py b/src/console/zmq/eventloop/minitornado/log.py new file mode 100755 index 00000000..49051e89 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/log.py @@ -0,0 +1,6 @@ +"""minimal subset of tornado.log for zmq.eventloop.minitornado""" + +import logging + +app_log = logging.getLogger("tornado.application") +gen_log = logging.getLogger("tornado.general") diff --git a/src/console/zmq/eventloop/minitornado/platform/__init__.py b/src/console/zmq/eventloop/minitornado/platform/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/__init__.py diff --git a/src/console/zmq/eventloop/minitornado/platform/auto.py b/src/console/zmq/eventloop/minitornado/platform/auto.py new file mode 100755 index 00000000..b40ccd94 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/auto.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Implementation of platform-specific functionality. + +For each function or class described in `tornado.platform.interface`, +the appropriate platform-specific implementation exists in this module. +Most code that needs access to this functionality should do e.g.:: + + from tornado.platform.auto import set_close_exec +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import os + +if os.name == 'nt': + from .common import Waker + from .windows import set_close_exec +else: + from .posix import set_close_exec, Waker + +try: + # monotime monkey-patches the time module to have a monotonic function + # in versions of python before 3.3. + import monotime +except ImportError: + pass +try: + from time import monotonic as monotonic_time +except ImportError: + monotonic_time = None diff --git a/src/console/zmq/eventloop/minitornado/platform/common.py b/src/console/zmq/eventloop/minitornado/platform/common.py new file mode 100755 index 00000000..2d75dc1e --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/common.py @@ -0,0 +1,91 @@ +"""Lowest-common-denominator implementations of platform functionality.""" +from __future__ import absolute_import, division, print_function, with_statement + +import errno +import socket + +from . import interface + + +class Waker(interface.Waker): + """Create an OS independent asynchronous pipe. + + For use on platforms that don't have os.pipe() (or where pipes cannot + be passed to select()), but do have sockets. This includes Windows + and Jython. + """ + def __init__(self): + # Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py + + self.writer = socket.socket() + # Disable buffering -- pulling the trigger sends 1 byte, + # and we want that sent immediately, to wake up ASAP. + self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + count = 0 + while 1: + count += 1 + # Bind to a local port; for efficiency, let the OS pick + # a free port for us. + # Unfortunately, stress tests showed that we may not + # be able to connect to that port ("Address already in + # use") despite that the OS picked it. This appears + # to be a race bug in the Windows socket implementation. + # So we loop until a connect() succeeds (almost always + # on the first try). See the long thread at + # http://mail.zope.org/pipermail/zope/2005-July/160433.html + # for hideous details. + a = socket.socket() + a.bind(("127.0.0.1", 0)) + a.listen(1) + connect_address = a.getsockname() # assigned (host, port) pair + try: + self.writer.connect(connect_address) + break # success + except socket.error as detail: + if (not hasattr(errno, 'WSAEADDRINUSE') or + detail[0] != errno.WSAEADDRINUSE): + # "Address already in use" is the only error + # I've seen on two WinXP Pro SP2 boxes, under + # Pythons 2.3.5 and 2.4.1. + raise + # (10048, 'Address already in use') + # assert count <= 2 # never triggered in Tim's tests + if count >= 10: # I've never seen it go above 2 + a.close() + self.writer.close() + raise socket.error("Cannot bind trigger!") + # Close `a` and try again. Note: I originally put a short + # sleep() here, but it didn't appear to help or hurt. + a.close() + + self.reader, addr = a.accept() + self.reader.setblocking(0) + self.writer.setblocking(0) + a.close() + self.reader_fd = self.reader.fileno() + + def fileno(self): + return self.reader.fileno() + + def write_fileno(self): + return self.writer.fileno() + + def wake(self): + try: + self.writer.send(b"x") + except (IOError, socket.error): + pass + + def consume(self): + try: + while True: + result = self.reader.recv(1024) + if not result: + break + except (IOError, socket.error): + pass + + def close(self): + self.reader.close() + self.writer.close() diff --git a/src/console/zmq/eventloop/minitornado/platform/interface.py b/src/console/zmq/eventloop/minitornado/platform/interface.py new file mode 100755 index 00000000..07da6bab --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/interface.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Interfaces for platform-specific functionality. + +This module exists primarily for documentation purposes and as base classes +for other tornado.platform modules. Most code should import the appropriate +implementation from `tornado.platform.auto`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + + +def set_close_exec(fd): + """Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor.""" + raise NotImplementedError() + + +class Waker(object): + """A socket-like object that can wake another thread from ``select()``. + + The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to + its ``select`` (or ``epoll`` or ``kqueue``) calls. When another + thread wants to wake up the loop, it calls `wake`. Once it has woken + up, it will call `consume` to do any necessary per-wake cleanup. When + the ``IOLoop`` is closed, it closes its waker too. + """ + def fileno(self): + """Returns the read file descriptor for this waker. + + Must be suitable for use with ``select()`` or equivalent on the + local platform. + """ + raise NotImplementedError() + + def write_fileno(self): + """Returns the write file descriptor for this waker.""" + raise NotImplementedError() + + def wake(self): + """Triggers activity on the waker's file descriptor.""" + raise NotImplementedError() + + def consume(self): + """Called after the listen has woken up to do any necessary cleanup.""" + raise NotImplementedError() + + def close(self): + """Closes the waker's file descriptor(s).""" + raise NotImplementedError() diff --git a/src/console/zmq/eventloop/minitornado/platform/posix.py b/src/console/zmq/eventloop/minitornado/platform/posix.py new file mode 100755 index 00000000..ccffbb66 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/posix.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Posix implementations of platform-specific functionality.""" + +from __future__ import absolute_import, division, print_function, with_statement + +import fcntl +import os + +from . import interface + + +def set_close_exec(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + +class Waker(interface.Waker): + def __init__(self): + r, w = os.pipe() + _set_nonblocking(r) + _set_nonblocking(w) + set_close_exec(r) + set_close_exec(w) + self.reader = os.fdopen(r, "rb", 0) + self.writer = os.fdopen(w, "wb", 0) + + def fileno(self): + return self.reader.fileno() + + def write_fileno(self): + return self.writer.fileno() + + def wake(self): + try: + self.writer.write(b"x") + except IOError: + pass + + def consume(self): + try: + while True: + result = self.reader.read() + if not result: + break + except IOError: + pass + + def close(self): + self.reader.close() + self.writer.close() diff --git a/src/console/zmq/eventloop/minitornado/platform/windows.py b/src/console/zmq/eventloop/minitornado/platform/windows.py new file mode 100755 index 00000000..817bdca1 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/platform/windows.py @@ -0,0 +1,20 @@ +# NOTE: win32 support is currently experimental, and not recommended +# for production use. + + +from __future__ import absolute_import, division, print_function, with_statement +import ctypes +import ctypes.wintypes + +# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx +SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation +SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD) +SetHandleInformation.restype = ctypes.wintypes.BOOL + +HANDLE_FLAG_INHERIT = 0x00000001 + + +def set_close_exec(fd): + success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0) + if not success: + raise ctypes.GetLastError() diff --git a/src/console/zmq/eventloop/minitornado/stack_context.py b/src/console/zmq/eventloop/minitornado/stack_context.py new file mode 100755 index 00000000..226d8042 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/stack_context.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python +# +# Copyright 2010 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""`StackContext` allows applications to maintain threadlocal-like state +that follows execution as it moves to other execution contexts. + +The motivating examples are to eliminate the need for explicit +``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to +allow some additional context to be kept for logging. + +This is slightly magic, but it's an extension of the idea that an +exception handler is a kind of stack-local state and when that stack +is suspended and resumed in a new context that state needs to be +preserved. `StackContext` shifts the burden of restoring that state +from each call site (e.g. wrapping each `.AsyncHTTPClient` callback +in ``async_callback``) to the mechanisms that transfer control from +one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`, +thread pools, etc). + +Example usage:: + + @contextlib.contextmanager + def die_on_error(): + try: + yield + except Exception: + logging.error("exception in asynchronous operation",exc_info=True) + sys.exit(1) + + with StackContext(die_on_error): + # Any exception thrown here *or in callback and its desendents* + # will cause the process to exit instead of spinning endlessly + # in the ioloop. + http_client.fetch(url, callback) + ioloop.start() + +Most applications shouln't have to work with `StackContext` directly. +Here are a few rules of thumb for when it's necessary: + +* If you're writing an asynchronous library that doesn't rely on a + stack_context-aware library like `tornado.ioloop` or `tornado.iostream` + (for example, if you're writing a thread pool), use + `.stack_context.wrap()` before any asynchronous operations to capture the + stack context from where the operation was started. + +* If you're writing an asynchronous library that has some shared + resources (such as a connection pool), create those shared resources + within a ``with stack_context.NullContext():`` block. This will prevent + ``StackContexts`` from leaking from one request to another. + +* If you want to write something like an exception handler that will + persist across asynchronous calls, create a new `StackContext` (or + `ExceptionStackContext`), and make your asynchronous calls in a ``with`` + block that references your `StackContext`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import sys +import threading + +from .util import raise_exc_info + + +class StackContextInconsistentError(Exception): + pass + + +class _State(threading.local): + def __init__(self): + self.contexts = (tuple(), None) +_state = _State() + + +class StackContext(object): + """Establishes the given context as a StackContext that will be transferred. + + Note that the parameter is a callable that returns a context + manager, not the context itself. That is, where for a + non-transferable context manager you would say:: + + with my_context(): + + StackContext takes the function itself rather than its result:: + + with StackContext(my_context): + + The result of ``with StackContext() as cb:`` is a deactivation + callback. Run this callback when the StackContext is no longer + needed to ensure that it is not propagated any further (note that + deactivating a context does not affect any instances of that + context that are currently pending). This is an advanced feature + and not necessary in most applications. + """ + def __init__(self, context_factory): + self.context_factory = context_factory + self.contexts = [] + self.active = True + + def _deactivate(self): + self.active = False + + # StackContext protocol + def enter(self): + context = self.context_factory() + self.contexts.append(context) + context.__enter__() + + def exit(self, type, value, traceback): + context = self.contexts.pop() + context.__exit__(type, value, traceback) + + # Note that some of this code is duplicated in ExceptionStackContext + # below. ExceptionStackContext is more common and doesn't need + # the full generality of this class. + def __enter__(self): + self.old_contexts = _state.contexts + self.new_contexts = (self.old_contexts[0] + (self,), self) + _state.contexts = self.new_contexts + + try: + self.enter() + except: + _state.contexts = self.old_contexts + raise + + return self._deactivate + + def __exit__(self, type, value, traceback): + try: + self.exit(type, value, traceback) + finally: + final_contexts = _state.contexts + _state.contexts = self.old_contexts + + # Generator coroutines and with-statements with non-local + # effects interact badly. Check here for signs of + # the stack getting out of sync. + # Note that this check comes after restoring _state.context + # so that if it fails things are left in a (relatively) + # consistent state. + if final_contexts is not self.new_contexts: + raise StackContextInconsistentError( + 'stack_context inconsistency (may be caused by yield ' + 'within a "with StackContext" block)') + + # Break up a reference to itself to allow for faster GC on CPython. + self.new_contexts = None + + +class ExceptionStackContext(object): + """Specialization of StackContext for exception handling. + + The supplied ``exception_handler`` function will be called in the + event of an uncaught exception in this context. The semantics are + similar to a try/finally clause, and intended use cases are to log + an error, close a socket, or similar cleanup actions. The + ``exc_info`` triple ``(type, value, traceback)`` will be passed to the + exception_handler function. + + If the exception handler returns true, the exception will be + consumed and will not be propagated to other exception handlers. + """ + def __init__(self, exception_handler): + self.exception_handler = exception_handler + self.active = True + + def _deactivate(self): + self.active = False + + def exit(self, type, value, traceback): + if type is not None: + return self.exception_handler(type, value, traceback) + + def __enter__(self): + self.old_contexts = _state.contexts + self.new_contexts = (self.old_contexts[0], self) + _state.contexts = self.new_contexts + + return self._deactivate + + def __exit__(self, type, value, traceback): + try: + if type is not None: + return self.exception_handler(type, value, traceback) + finally: + final_contexts = _state.contexts + _state.contexts = self.old_contexts + + if final_contexts is not self.new_contexts: + raise StackContextInconsistentError( + 'stack_context inconsistency (may be caused by yield ' + 'within a "with StackContext" block)') + + # Break up a reference to itself to allow for faster GC on CPython. + self.new_contexts = None + + +class NullContext(object): + """Resets the `StackContext`. + + Useful when creating a shared resource on demand (e.g. an + `.AsyncHTTPClient`) where the stack that caused the creating is + not relevant to future operations. + """ + def __enter__(self): + self.old_contexts = _state.contexts + _state.contexts = (tuple(), None) + + def __exit__(self, type, value, traceback): + _state.contexts = self.old_contexts + + +def _remove_deactivated(contexts): + """Remove deactivated handlers from the chain""" + # Clean ctx handlers + stack_contexts = tuple([h for h in contexts[0] if h.active]) + + # Find new head + head = contexts[1] + while head is not None and not head.active: + head = head.old_contexts[1] + + # Process chain + ctx = head + while ctx is not None: + parent = ctx.old_contexts[1] + + while parent is not None: + if parent.active: + break + ctx.old_contexts = parent.old_contexts + parent = parent.old_contexts[1] + + ctx = parent + + return (stack_contexts, head) + + +def wrap(fn): + """Returns a callable object that will restore the current `StackContext` + when executed. + + Use this whenever saving a callback to be executed later in a + different execution context (either in a different thread or + asynchronously in the same thread). + """ + # Check if function is already wrapped + if fn is None or hasattr(fn, '_wrapped'): + return fn + + # Capture current stack head + # TODO: Any other better way to store contexts and update them in wrapped function? + cap_contexts = [_state.contexts] + + def wrapped(*args, **kwargs): + ret = None + try: + # Capture old state + current_state = _state.contexts + + # Remove deactivated items + cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0]) + + # Force new state + _state.contexts = contexts + + # Current exception + exc = (None, None, None) + top = None + + # Apply stack contexts + last_ctx = 0 + stack = contexts[0] + + # Apply state + for n in stack: + try: + n.enter() + last_ctx += 1 + except: + # Exception happened. Record exception info and store top-most handler + exc = sys.exc_info() + top = n.old_contexts[1] + + # Execute callback if no exception happened while restoring state + if top is None: + try: + ret = fn(*args, **kwargs) + except: + exc = sys.exc_info() + top = contexts[1] + + # If there was exception, try to handle it by going through the exception chain + if top is not None: + exc = _handle_exception(top, exc) + else: + # Otherwise take shorter path and run stack contexts in reverse order + while last_ctx > 0: + last_ctx -= 1 + c = stack[last_ctx] + + try: + c.exit(*exc) + except: + exc = sys.exc_info() + top = c.old_contexts[1] + break + else: + top = None + + # If if exception happened while unrolling, take longer exception handler path + if top is not None: + exc = _handle_exception(top, exc) + + # If exception was not handled, raise it + if exc != (None, None, None): + raise_exc_info(exc) + finally: + _state.contexts = current_state + return ret + + wrapped._wrapped = True + return wrapped + + +def _handle_exception(tail, exc): + while tail is not None: + try: + if tail.exit(*exc): + exc = (None, None, None) + except: + exc = sys.exc_info() + + tail = tail.old_contexts[1] + + return exc + + +def run_with_stack_context(context, func): + """Run a coroutine ``func`` in the given `StackContext`. + + It is not safe to have a ``yield`` statement within a ``with StackContext`` + block, so it is difficult to use stack context with `.gen.coroutine`. + This helper function runs the function in the correct context while + keeping the ``yield`` and ``with`` statements syntactically separate. + + Example:: + + @gen.coroutine + def incorrect(): + with StackContext(ctx): + # ERROR: this will raise StackContextInconsistentError + yield other_coroutine() + + @gen.coroutine + def correct(): + yield run_with_stack_context(StackContext(ctx), other_coroutine) + + .. versionadded:: 3.1 + """ + with context: + return func() diff --git a/src/console/zmq/eventloop/minitornado/util.py b/src/console/zmq/eventloop/minitornado/util.py new file mode 100755 index 00000000..c1e2eb95 --- /dev/null +++ b/src/console/zmq/eventloop/minitornado/util.py @@ -0,0 +1,184 @@ +"""Miscellaneous utility functions and classes. + +This module is used internally by Tornado. It is not necessarily expected +that the functions and classes defined here will be useful to other +applications, but they are documented here in case they are. + +The one public-facing part of this module is the `Configurable` class +and its `~Configurable.configure` method, which becomes a part of the +interface of its subclasses, including `.AsyncHTTPClient`, `.IOLoop`, +and `.Resolver`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import sys + + +def import_object(name): + """Imports an object by name. + + import_object('x') is equivalent to 'import x'. + import_object('x.y.z') is equivalent to 'from x.y import z'. + + >>> import tornado.escape + >>> import_object('tornado.escape') is tornado.escape + True + >>> import_object('tornado.escape.utf8') is tornado.escape.utf8 + True + >>> import_object('tornado') is tornado + True + >>> import_object('tornado.missing_module') + Traceback (most recent call last): + ... + ImportError: No module named missing_module + """ + if name.count('.') == 0: + return __import__(name, None, None) + + parts = name.split('.') + obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0) + try: + return getattr(obj, parts[-1]) + except AttributeError: + raise ImportError("No module named %s" % parts[-1]) + + +# Fake unicode literal support: Python 3.2 doesn't have the u'' marker for +# literal strings, and alternative solutions like "from __future__ import +# unicode_literals" have other problems (see PEP 414). u() can be applied +# to ascii strings that include \u escapes (but they must not contain +# literal non-ascii characters). +if type('') is not type(b''): + def u(s): + return s + bytes_type = bytes + unicode_type = str + basestring_type = str +else: + def u(s): + return s.decode('unicode_escape') + bytes_type = str + unicode_type = unicode + basestring_type = basestring + + +if sys.version_info > (3,): + exec(""" +def raise_exc_info(exc_info): + raise exc_info[1].with_traceback(exc_info[2]) + +def exec_in(code, glob, loc=None): + if isinstance(code, str): + code = compile(code, '<string>', 'exec', dont_inherit=True) + exec(code, glob, loc) +""") +else: + exec(""" +def raise_exc_info(exc_info): + raise exc_info[0], exc_info[1], exc_info[2] + +def exec_in(code, glob, loc=None): + if isinstance(code, basestring): + # exec(string) inherits the caller's future imports; compile + # the string first to prevent that. + code = compile(code, '<string>', 'exec', dont_inherit=True) + exec code in glob, loc +""") + + +class Configurable(object): + """Base class for configurable interfaces. + + A configurable interface is an (abstract) class whose constructor + acts as a factory function for one of its implementation subclasses. + The implementation subclass as well as optional keyword arguments to + its initializer can be set globally at runtime with `configure`. + + By using the constructor as the factory method, the interface + looks like a normal class, `isinstance` works as usual, etc. This + pattern is most useful when the choice of implementation is likely + to be a global decision (e.g. when `~select.epoll` is available, + always use it instead of `~select.select`), or when a + previously-monolithic class has been split into specialized + subclasses. + + Configurable subclasses must define the class methods + `configurable_base` and `configurable_default`, and use the instance + method `initialize` instead of ``__init__``. + """ + __impl_class = None + __impl_kwargs = None + + def __new__(cls, **kwargs): + base = cls.configurable_base() + args = {} + if cls is base: + impl = cls.configured_class() + if base.__impl_kwargs: + args.update(base.__impl_kwargs) + else: + impl = cls + args.update(kwargs) + instance = super(Configurable, cls).__new__(impl) + # initialize vs __init__ chosen for compatiblity with AsyncHTTPClient + # singleton magic. If we get rid of that we can switch to __init__ + # here too. + instance.initialize(**args) + return instance + + @classmethod + def configurable_base(cls): + """Returns the base class of a configurable hierarchy. + + This will normally return the class in which it is defined. + (which is *not* necessarily the same as the cls classmethod parameter). + """ + raise NotImplementedError() + + @classmethod + def configurable_default(cls): + """Returns the implementation class to be used if none is configured.""" + raise NotImplementedError() + + def initialize(self): + """Initialize a `Configurable` subclass instance. + + Configurable classes should use `initialize` instead of ``__init__``. + """ + + @classmethod + def configure(cls, impl, **kwargs): + """Sets the class to use when the base class is instantiated. + + Keyword arguments will be saved and added to the arguments passed + to the constructor. This can be used to set global defaults for + some parameters. + """ + base = cls.configurable_base() + if isinstance(impl, (unicode_type, bytes_type)): + impl = import_object(impl) + if impl is not None and not issubclass(impl, cls): + raise ValueError("Invalid subclass of %s" % cls) + base.__impl_class = impl + base.__impl_kwargs = kwargs + + @classmethod + def configured_class(cls): + """Returns the currently configured class.""" + base = cls.configurable_base() + if cls.__impl_class is None: + base.__impl_class = cls.configurable_default() + return base.__impl_class + + @classmethod + def _save_configuration(cls): + base = cls.configurable_base() + return (base.__impl_class, base.__impl_kwargs) + + @classmethod + def _restore_configuration(cls, saved): + base = cls.configurable_base() + base.__impl_class = saved[0] + base.__impl_kwargs = saved[1] + diff --git a/src/console/zmq/eventloop/zmqstream.py b/src/console/zmq/eventloop/zmqstream.py new file mode 100755 index 00000000..86a97e44 --- /dev/null +++ b/src/console/zmq/eventloop/zmqstream.py @@ -0,0 +1,529 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""A utility class to send to and recv from a non-blocking socket.""" + +from __future__ import with_statement + +import sys + +import zmq +from zmq.utils import jsonapi + +try: + import cPickle as pickle +except ImportError: + import pickle + +from .ioloop import IOLoop + +try: + # gen_log will only import from >= 3.0 + from tornado.log import gen_log + from tornado import stack_context +except ImportError: + from .minitornado.log import gen_log + from .minitornado import stack_context + +try: + from queue import Queue +except ImportError: + from Queue import Queue + +from zmq.utils.strtypes import bytes, unicode, basestring + +try: + callable +except NameError: + callable = lambda obj: hasattr(obj, '__call__') + + +class ZMQStream(object): + """A utility class to register callbacks when a zmq socket sends and receives + + For use with zmq.eventloop.ioloop + + There are three main methods + + Methods: + + * **on_recv(callback, copy=True):** + register a callback to be run every time the socket has something to receive + * **on_send(callback):** + register a callback to be run every time you call send + * **send(self, msg, flags=0, copy=False, callback=None):** + perform a send that will trigger the callback + if callback is passed, on_send is also called. + + There are also send_multipart(), send_json(), send_pyobj() + + Three other methods for deactivating the callbacks: + + * **stop_on_recv():** + turn off the recv callback + * **stop_on_send():** + turn off the send callback + + which simply call ``on_<evt>(None)``. + + The entire socket interface, excluding direct recv methods, is also + provided, primarily through direct-linking the methods. + e.g. + + >>> stream.bind is stream.socket.bind + True + + """ + + socket = None + io_loop = None + poller = None + + def __init__(self, socket, io_loop=None): + self.socket = socket + self.io_loop = io_loop or IOLoop.instance() + self.poller = zmq.Poller() + + self._send_queue = Queue() + self._recv_callback = None + self._send_callback = None + self._close_callback = None + self._recv_copy = False + self._flushed = False + + self._state = self.io_loop.ERROR + self._init_io_state() + + # shortcircuit some socket methods + self.bind = self.socket.bind + self.bind_to_random_port = self.socket.bind_to_random_port + self.connect = self.socket.connect + self.setsockopt = self.socket.setsockopt + self.getsockopt = self.socket.getsockopt + self.setsockopt_string = self.socket.setsockopt_string + self.getsockopt_string = self.socket.getsockopt_string + self.setsockopt_unicode = self.socket.setsockopt_unicode + self.getsockopt_unicode = self.socket.getsockopt_unicode + + + def stop_on_recv(self): + """Disable callback and automatic receiving.""" + return self.on_recv(None) + + def stop_on_send(self): + """Disable callback on sending.""" + return self.on_send(None) + + def stop_on_err(self): + """DEPRECATED, does nothing""" + gen_log.warn("on_err does nothing, and will be removed") + + def on_err(self, callback): + """DEPRECATED, does nothing""" + gen_log.warn("on_err does nothing, and will be removed") + + def on_recv(self, callback, copy=True): + """Register a callback for when a message is ready to recv. + + There can be only one callback registered at a time, so each + call to `on_recv` replaces previously registered callbacks. + + on_recv(None) disables recv event polling. + + Use on_recv_stream(callback) instead, to register a callback that will receive + both this ZMQStream and the message, instead of just the message. + + Parameters + ---------- + + callback : callable + callback must take exactly one argument, which will be a + list, as returned by socket.recv_multipart() + if callback is None, recv callbacks are disabled. + copy : bool + copy is passed directly to recv, so if copy is False, + callback will receive Message objects. If copy is True, + then callback will receive bytes/str objects. + + Returns : None + """ + + self._check_closed() + assert callback is None or callable(callback) + self._recv_callback = stack_context.wrap(callback) + self._recv_copy = copy + if callback is None: + self._drop_io_state(self.io_loop.READ) + else: + self._add_io_state(self.io_loop.READ) + + def on_recv_stream(self, callback, copy=True): + """Same as on_recv, but callback will get this stream as first argument + + callback must take exactly two arguments, as it will be called as:: + + callback(stream, msg) + + Useful when a single callback should be used with multiple streams. + """ + if callback is None: + self.stop_on_recv() + else: + self.on_recv(lambda msg: callback(self, msg), copy=copy) + + def on_send(self, callback): + """Register a callback to be called on each send + + There will be two arguments:: + + callback(msg, status) + + * `msg` will be the list of sendable objects that was just sent + * `status` will be the return result of socket.send_multipart(msg) - + MessageTracker or None. + + Non-copying sends return a MessageTracker object whose + `done` attribute will be True when the send is complete. + This allows users to track when an object is safe to write to + again. + + The second argument will always be None if copy=True + on the send. + + Use on_send_stream(callback) to register a callback that will be passed + this ZMQStream as the first argument, in addition to the other two. + + on_send(None) disables recv event polling. + + Parameters + ---------- + + callback : callable + callback must take exactly two arguments, which will be + the message being sent (always a list), + and the return result of socket.send_multipart(msg) - + MessageTracker or None. + + if callback is None, send callbacks are disabled. + """ + + self._check_closed() + assert callback is None or callable(callback) + self._send_callback = stack_context.wrap(callback) + + + def on_send_stream(self, callback): + """Same as on_send, but callback will get this stream as first argument + + Callback will be passed three arguments:: + + callback(stream, msg, status) + + Useful when a single callback should be used with multiple streams. + """ + if callback is None: + self.stop_on_send() + else: + self.on_send(lambda msg, status: callback(self, msg, status)) + + + def send(self, msg, flags=0, copy=True, track=False, callback=None): + """Send a message, optionally also register a new callback for sends. + See zmq.socket.send for details. + """ + return self.send_multipart([msg], flags=flags, copy=copy, track=track, callback=callback) + + def send_multipart(self, msg, flags=0, copy=True, track=False, callback=None): + """Send a multipart message, optionally also register a new callback for sends. + See zmq.socket.send_multipart for details. + """ + kwargs = dict(flags=flags, copy=copy, track=track) + self._send_queue.put((msg, kwargs)) + callback = callback or self._send_callback + if callback is not None: + self.on_send(callback) + else: + # noop callback + self.on_send(lambda *args: None) + self._add_io_state(self.io_loop.WRITE) + + def send_string(self, u, flags=0, encoding='utf-8', callback=None): + """Send a unicode message with an encoding. + See zmq.socket.send_unicode for details. + """ + if not isinstance(u, basestring): + raise TypeError("unicode/str objects only") + return self.send(u.encode(encoding), flags=flags, callback=callback) + + send_unicode = send_string + + def send_json(self, obj, flags=0, callback=None): + """Send json-serialized version of an object. + See zmq.socket.send_json for details. + """ + if jsonapi is None: + raise ImportError('jsonlib{1,2}, json or simplejson library is required.') + else: + msg = jsonapi.dumps(obj) + return self.send(msg, flags=flags, callback=callback) + + def send_pyobj(self, obj, flags=0, protocol=-1, callback=None): + """Send a Python object as a message using pickle to serialize. + + See zmq.socket.send_json for details. + """ + msg = pickle.dumps(obj, protocol) + return self.send(msg, flags, callback=callback) + + def _finish_flush(self): + """callback for unsetting _flushed flag.""" + self._flushed = False + + def flush(self, flag=zmq.POLLIN|zmq.POLLOUT, limit=None): + """Flush pending messages. + + This method safely handles all pending incoming and/or outgoing messages, + bypassing the inner loop, passing them to the registered callbacks. + + A limit can be specified, to prevent blocking under high load. + + flush will return the first time ANY of these conditions are met: + * No more events matching the flag are pending. + * the total number of events handled reaches the limit. + + Note that if ``flag|POLLIN != 0``, recv events will be flushed even if no callback + is registered, unlike normal IOLoop operation. This allows flush to be + used to remove *and ignore* incoming messages. + + Parameters + ---------- + flag : int, default=POLLIN|POLLOUT + 0MQ poll flags. + If flag|POLLIN, recv events will be flushed. + If flag|POLLOUT, send events will be flushed. + Both flags can be set at once, which is the default. + limit : None or int, optional + The maximum number of messages to send or receive. + Both send and recv count against this limit. + + Returns + ------- + int : count of events handled (both send and recv) + """ + self._check_closed() + # unset self._flushed, so callbacks will execute, in case flush has + # already been called this iteration + already_flushed = self._flushed + self._flushed = False + # initialize counters + count = 0 + def update_flag(): + """Update the poll flag, to prevent registering POLLOUT events + if we don't have pending sends.""" + return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT) + flag = update_flag() + if not flag: + # nothing to do + return 0 + self.poller.register(self.socket, flag) + events = self.poller.poll(0) + while events and (not limit or count < limit): + s,event = events[0] + if event & zmq.POLLIN: # receiving + self._handle_recv() + count += 1 + if self.socket is None: + # break if socket was closed during callback + break + if event & zmq.POLLOUT and self.sending(): + self._handle_send() + count += 1 + if self.socket is None: + # break if socket was closed during callback + break + + flag = update_flag() + if flag: + self.poller.register(self.socket, flag) + events = self.poller.poll(0) + else: + events = [] + if count: # only bypass loop if we actually flushed something + # skip send/recv callbacks this iteration + self._flushed = True + # reregister them at the end of the loop + if not already_flushed: # don't need to do it again + self.io_loop.add_callback(self._finish_flush) + elif already_flushed: + self._flushed = True + + # update ioloop poll state, which may have changed + self._rebuild_io_state() + return count + + def set_close_callback(self, callback): + """Call the given callback when the stream is closed.""" + self._close_callback = stack_context.wrap(callback) + + def close(self, linger=None): + """Close this stream.""" + if self.socket is not None: + self.io_loop.remove_handler(self.socket) + self.socket.close(linger) + self.socket = None + if self._close_callback: + self._run_callback(self._close_callback) + + def receiving(self): + """Returns True if we are currently receiving from the stream.""" + return self._recv_callback is not None + + def sending(self): + """Returns True if we are currently sending to the stream.""" + return not self._send_queue.empty() + + def closed(self): + return self.socket is None + + def _run_callback(self, callback, *args, **kwargs): + """Wrap running callbacks in try/except to allow us to + close our socket.""" + try: + # Use a NullContext to ensure that all StackContexts are run + # inside our blanket exception handler rather than outside. + with stack_context.NullContext(): + callback(*args, **kwargs) + except: + gen_log.error("Uncaught exception, closing connection.", + exc_info=True) + # Close the socket on an uncaught exception from a user callback + # (It would eventually get closed when the socket object is + # gc'd, but we don't want to rely on gc happening before we + # run out of file descriptors) + self.close() + # Re-raise the exception so that IOLoop.handle_callback_exception + # can see it and log the error + raise + + def _handle_events(self, fd, events): + """This method is the actual handler for IOLoop, that gets called whenever + an event on my socket is posted. It dispatches to _handle_recv, etc.""" + # print "handling events" + if not self.socket: + gen_log.warning("Got events for closed stream %s", fd) + return + try: + # dispatch events: + if events & IOLoop.ERROR: + gen_log.error("got POLLERR event on ZMQStream, which doesn't make sense") + return + if events & IOLoop.READ: + self._handle_recv() + if not self.socket: + return + if events & IOLoop.WRITE: + self._handle_send() + if not self.socket: + return + + # rebuild the poll state + self._rebuild_io_state() + except: + gen_log.error("Uncaught exception, closing connection.", + exc_info=True) + self.close() + raise + + def _handle_recv(self): + """Handle a recv event.""" + if self._flushed: + return + try: + msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy) + except zmq.ZMQError as e: + if e.errno == zmq.EAGAIN: + # state changed since poll event + pass + else: + gen_log.error("RECV Error: %s"%zmq.strerror(e.errno)) + else: + if self._recv_callback: + callback = self._recv_callback + # self._recv_callback = None + self._run_callback(callback, msg) + + # self.update_state() + + + def _handle_send(self): + """Handle a send event.""" + if self._flushed: + return + if not self.sending(): + gen_log.error("Shouldn't have handled a send event") + return + + msg, kwargs = self._send_queue.get() + try: + status = self.socket.send_multipart(msg, **kwargs) + except zmq.ZMQError as e: + gen_log.error("SEND Error: %s", e) + status = e + if self._send_callback: + callback = self._send_callback + self._run_callback(callback, msg, status) + + # self.update_state() + + def _check_closed(self): + if not self.socket: + raise IOError("Stream is closed") + + def _rebuild_io_state(self): + """rebuild io state based on self.sending() and receiving()""" + if self.socket is None: + return + state = self.io_loop.ERROR + if self.receiving(): + state |= self.io_loop.READ + if self.sending(): + state |= self.io_loop.WRITE + if state != self._state: + self._state = state + self._update_handler(state) + + def _add_io_state(self, state): + """Add io_state to poller.""" + if not self._state & state: + self._state = self._state | state + self._update_handler(self._state) + + def _drop_io_state(self, state): + """Stop poller from watching an io_state.""" + if self._state & state: + self._state = self._state & (~state) + self._update_handler(self._state) + + def _update_handler(self, state): + """Update IOLoop handler with state.""" + if self.socket is None: + return + self.io_loop.update_handler(self.socket, state) + + def _init_io_state(self): + """initialize the ioloop event handler""" + with stack_context.NullContext(): + self.io_loop.add_handler(self.socket, self._handle_events, self._state) + diff --git a/src/console/zmq/green/__init__.py b/src/console/zmq/green/__init__.py new file mode 100755 index 00000000..ff7e5965 --- /dev/null +++ b/src/console/zmq/green/__init__.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +#----------------------------------------------------------------------------- +# Copyright (C) 2011-2012 Travis Cline +# +# This file is part of pyzmq +# It is adapted from upstream project zeromq_gevent under the New BSD License +# +# Distributed under the terms of the New BSD License. The full license is in +# the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + +"""zmq.green - gevent compatibility with zeromq. + +Usage +----- + +Instead of importing zmq directly, do so in the following manner: + +.. + + import zmq.green as zmq + + +Any calls that would have blocked the current thread will now only block the +current green thread. + +This compatibility is accomplished by ensuring the nonblocking flag is set +before any blocking operation and the ØMQ file descriptor is polled internally +to trigger needed events. +""" + +from zmq import * +from zmq.green.core import _Context, _Socket +from zmq.green.poll import _Poller +Context = _Context +Socket = _Socket +Poller = _Poller + +from zmq.green.device import device + diff --git a/src/console/zmq/green/core.py b/src/console/zmq/green/core.py new file mode 100755 index 00000000..9fc73e32 --- /dev/null +++ b/src/console/zmq/green/core.py @@ -0,0 +1,287 @@ +#----------------------------------------------------------------------------- +# Copyright (C) 2011-2012 Travis Cline +# +# This file is part of pyzmq +# It is adapted from upstream project zeromq_gevent under the New BSD License +# +# Distributed under the terms of the New BSD License. The full license is in +# the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + +"""This module wraps the :class:`Socket` and :class:`Context` found in :mod:`pyzmq <zmq>` to be non blocking +""" + +from __future__ import print_function + +import sys +import time +import warnings + +import zmq + +from zmq import Context as _original_Context +from zmq import Socket as _original_Socket +from .poll import _Poller + +import gevent +from gevent.event import AsyncResult +from gevent.hub import get_hub + +if hasattr(zmq, 'RCVTIMEO'): + TIMEOS = (zmq.RCVTIMEO, zmq.SNDTIMEO) +else: + TIMEOS = () + +def _stop(evt): + """simple wrapper for stopping an Event, allowing for method rename in gevent 1.0""" + try: + evt.stop() + except AttributeError as e: + # gevent<1.0 compat + evt.cancel() + +class _Socket(_original_Socket): + """Green version of :class:`zmq.Socket` + + The following methods are overridden: + + * send + * recv + + To ensure that the ``zmq.NOBLOCK`` flag is set and that sending or receiving + is deferred to the hub if a ``zmq.EAGAIN`` (retry) error is raised. + + The `__state_changed` method is triggered when the zmq.FD for the socket is + marked as readable and triggers the necessary read and write events (which + are waited for in the recv and send methods). + + Some double underscore prefixes are used to minimize pollution of + :class:`zmq.Socket`'s namespace. + """ + __in_send_multipart = False + __in_recv_multipart = False + __writable = None + __readable = None + _state_event = None + _gevent_bug_timeout = 11.6 # timeout for not trusting gevent + _debug_gevent = False # turn on if you think gevent is missing events + _poller_class = _Poller + + def __init__(self, context, socket_type): + _original_Socket.__init__(self, context, socket_type) + self.__in_send_multipart = False + self.__in_recv_multipart = False + self.__setup_events() + + + def __del__(self): + self.close() + + def close(self, linger=None): + super(_Socket, self).close(linger) + self.__cleanup_events() + + def __cleanup_events(self): + # close the _state_event event, keeps the number of active file descriptors down + if getattr(self, '_state_event', None): + _stop(self._state_event) + self._state_event = None + # if the socket has entered a close state resume any waiting greenlets + self.__writable.set() + self.__readable.set() + + def __setup_events(self): + self.__readable = AsyncResult() + self.__writable = AsyncResult() + self.__readable.set() + self.__writable.set() + + try: + self._state_event = get_hub().loop.io(self.getsockopt(zmq.FD), 1) # read state watcher + self._state_event.start(self.__state_changed) + except AttributeError: + # for gevent<1.0 compatibility + from gevent.core import read_event + self._state_event = read_event(self.getsockopt(zmq.FD), self.__state_changed, persist=True) + + def __state_changed(self, event=None, _evtype=None): + if self.closed: + self.__cleanup_events() + return + try: + # avoid triggering __state_changed from inside __state_changed + events = super(_Socket, self).getsockopt(zmq.EVENTS) + except zmq.ZMQError as exc: + self.__writable.set_exception(exc) + self.__readable.set_exception(exc) + else: + if events & zmq.POLLOUT: + self.__writable.set() + if events & zmq.POLLIN: + self.__readable.set() + + def _wait_write(self): + assert self.__writable.ready(), "Only one greenlet can be waiting on this event" + self.__writable = AsyncResult() + # timeout is because libzmq cannot be trusted to properly signal a new send event: + # this is effectively a maximum poll interval of 1s + tic = time.time() + dt = self._gevent_bug_timeout + if dt: + timeout = gevent.Timeout(seconds=dt) + else: + timeout = None + try: + if timeout: + timeout.start() + self.__writable.get(block=True) + except gevent.Timeout as t: + if t is not timeout: + raise + toc = time.time() + # gevent bug: get can raise timeout even on clean return + # don't display zmq bug warning for gevent bug (this is getting ridiculous) + if self._debug_gevent and timeout and toc-tic > dt and \ + self.getsockopt(zmq.EVENTS) & zmq.POLLOUT: + print("BUG: gevent may have missed a libzmq send event on %i!" % self.FD, file=sys.stderr) + finally: + if timeout: + timeout.cancel() + self.__writable.set() + + def _wait_read(self): + assert self.__readable.ready(), "Only one greenlet can be waiting on this event" + self.__readable = AsyncResult() + # timeout is because libzmq cannot always be trusted to play nice with libevent. + # I can only confirm that this actually happens for send, but lets be symmetrical + # with our dirty hacks. + # this is effectively a maximum poll interval of 1s + tic = time.time() + dt = self._gevent_bug_timeout + if dt: + timeout = gevent.Timeout(seconds=dt) + else: + timeout = None + try: + if timeout: + timeout.start() + self.__readable.get(block=True) + except gevent.Timeout as t: + if t is not timeout: + raise + toc = time.time() + # gevent bug: get can raise timeout even on clean return + # don't display zmq bug warning for gevent bug (this is getting ridiculous) + if self._debug_gevent and timeout and toc-tic > dt and \ + self.getsockopt(zmq.EVENTS) & zmq.POLLIN: + print("BUG: gevent may have missed a libzmq recv event on %i!" % self.FD, file=sys.stderr) + finally: + if timeout: + timeout.cancel() + self.__readable.set() + + def send(self, data, flags=0, copy=True, track=False): + """send, which will only block current greenlet + + state_changed always fires exactly once (success or fail) at the + end of this method. + """ + + # if we're given the NOBLOCK flag act as normal and let the EAGAIN get raised + if flags & zmq.NOBLOCK: + try: + msg = super(_Socket, self).send(data, flags, copy, track) + finally: + if not self.__in_send_multipart: + self.__state_changed() + return msg + # ensure the zmq.NOBLOCK flag is part of flags + flags |= zmq.NOBLOCK + while True: # Attempt to complete this operation indefinitely, blocking the current greenlet + try: + # attempt the actual call + msg = super(_Socket, self).send(data, flags, copy, track) + except zmq.ZMQError as e: + # if the raised ZMQError is not EAGAIN, reraise + if e.errno != zmq.EAGAIN: + if not self.__in_send_multipart: + self.__state_changed() + raise + else: + if not self.__in_send_multipart: + self.__state_changed() + return msg + # defer to the event loop until we're notified the socket is writable + self._wait_write() + + def recv(self, flags=0, copy=True, track=False): + """recv, which will only block current greenlet + + state_changed always fires exactly once (success or fail) at the + end of this method. + """ + if flags & zmq.NOBLOCK: + try: + msg = super(_Socket, self).recv(flags, copy, track) + finally: + if not self.__in_recv_multipart: + self.__state_changed() + return msg + + flags |= zmq.NOBLOCK + while True: + try: + msg = super(_Socket, self).recv(flags, copy, track) + except zmq.ZMQError as e: + if e.errno != zmq.EAGAIN: + if not self.__in_recv_multipart: + self.__state_changed() + raise + else: + if not self.__in_recv_multipart: + self.__state_changed() + return msg + self._wait_read() + + def send_multipart(self, *args, **kwargs): + """wrap send_multipart to prevent state_changed on each partial send""" + self.__in_send_multipart = True + try: + msg = super(_Socket, self).send_multipart(*args, **kwargs) + finally: + self.__in_send_multipart = False + self.__state_changed() + return msg + + def recv_multipart(self, *args, **kwargs): + """wrap recv_multipart to prevent state_changed on each partial recv""" + self.__in_recv_multipart = True + try: + msg = super(_Socket, self).recv_multipart(*args, **kwargs) + finally: + self.__in_recv_multipart = False + self.__state_changed() + return msg + + def get(self, opt): + """trigger state_changed on getsockopt(EVENTS)""" + if opt in TIMEOS: + warnings.warn("TIMEO socket options have no effect in zmq.green", UserWarning) + optval = super(_Socket, self).get(opt) + if opt == zmq.EVENTS: + self.__state_changed() + return optval + + def set(self, opt, val): + """set socket option""" + if opt in TIMEOS: + warnings.warn("TIMEO socket options have no effect in zmq.green", UserWarning) + return super(_Socket, self).set(opt, val) + + +class _Context(_original_Context): + """Replacement for :class:`zmq.Context` + + Ensures that the greened Socket above is used in calls to `socket`. + """ + _socket_class = _Socket diff --git a/src/console/zmq/green/device.py b/src/console/zmq/green/device.py new file mode 100755 index 00000000..4b070237 --- /dev/null +++ b/src/console/zmq/green/device.py @@ -0,0 +1,32 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq +from zmq.green import Poller + +def device(device_type, isocket, osocket): + """Start a zeromq device (gevent-compatible). + + Unlike the true zmq.device, this does not release the GIL. + + Parameters + ---------- + device_type : (QUEUE, FORWARDER, STREAMER) + The type of device to start (ignored). + isocket : Socket + The Socket instance for the incoming traffic. + osocket : Socket + The Socket instance for the outbound traffic. + """ + p = Poller() + if osocket == -1: + osocket = isocket + p.register(isocket, zmq.POLLIN) + p.register(osocket, zmq.POLLIN) + + while True: + events = dict(p.poll()) + if isocket in events: + osocket.send_multipart(isocket.recv_multipart()) + if osocket in events: + isocket.send_multipart(osocket.recv_multipart()) diff --git a/src/console/zmq/green/eventloop/__init__.py b/src/console/zmq/green/eventloop/__init__.py new file mode 100755 index 00000000..c5150efe --- /dev/null +++ b/src/console/zmq/green/eventloop/__init__.py @@ -0,0 +1,3 @@ +from zmq.green.eventloop.ioloop import IOLoop + +__all__ = ['IOLoop']
\ No newline at end of file diff --git a/src/console/zmq/green/eventloop/ioloop.py b/src/console/zmq/green/eventloop/ioloop.py new file mode 100755 index 00000000..e12fd5e9 --- /dev/null +++ b/src/console/zmq/green/eventloop/ioloop.py @@ -0,0 +1,33 @@ +from zmq.eventloop.ioloop import * +from zmq.green import Poller + +RealIOLoop = IOLoop +RealZMQPoller = ZMQPoller + +class IOLoop(RealIOLoop): + + def initialize(self, impl=None): + impl = _poll() if impl is None else impl + super(IOLoop, self).initialize(impl) + + @staticmethod + def instance(): + """Returns a global `IOLoop` instance. + + Most applications have a single, global `IOLoop` running on the + main thread. Use this method to get this instance from + another thread. To get the current thread's `IOLoop`, use `current()`. + """ + # install this class as the active IOLoop implementation + # when using tornado 3 + if tornado_version >= (3,): + PollIOLoop.configure(IOLoop) + return PollIOLoop.instance() + + +class ZMQPoller(RealZMQPoller): + """gevent-compatible version of ioloop.ZMQPoller""" + def __init__(self): + self._poller = Poller() + +_poll = ZMQPoller diff --git a/src/console/zmq/green/eventloop/zmqstream.py b/src/console/zmq/green/eventloop/zmqstream.py new file mode 100755 index 00000000..90fbd1f5 --- /dev/null +++ b/src/console/zmq/green/eventloop/zmqstream.py @@ -0,0 +1,11 @@ +from zmq.eventloop.zmqstream import * + +from zmq.green.eventloop.ioloop import IOLoop + +RealZMQStream = ZMQStream + +class ZMQStream(RealZMQStream): + + def __init__(self, socket, io_loop=None): + io_loop = io_loop or IOLoop.instance() + super(ZMQStream, self).__init__(socket, io_loop=io_loop) diff --git a/src/console/zmq/green/poll.py b/src/console/zmq/green/poll.py new file mode 100755 index 00000000..8f016129 --- /dev/null +++ b/src/console/zmq/green/poll.py @@ -0,0 +1,95 @@ +import zmq +import gevent +from gevent import select + +from zmq import Poller as _original_Poller + + +class _Poller(_original_Poller): + """Replacement for :class:`zmq.Poller` + + Ensures that the greened Poller below is used in calls to + :meth:`zmq.Poller.poll`. + """ + _gevent_bug_timeout = 1.33 # minimum poll interval, for working around gevent bug + + def _get_descriptors(self): + """Returns three elements tuple with socket descriptors ready + for gevent.select.select + """ + rlist = [] + wlist = [] + xlist = [] + + for socket, flags in self.sockets: + if isinstance(socket, zmq.Socket): + rlist.append(socket.getsockopt(zmq.FD)) + continue + elif isinstance(socket, int): + fd = socket + elif hasattr(socket, 'fileno'): + try: + fd = int(socket.fileno()) + except: + raise ValueError('fileno() must return an valid integer fd') + else: + raise TypeError('Socket must be a 0MQ socket, an integer fd ' + 'or have a fileno() method: %r' % socket) + + if flags & zmq.POLLIN: + rlist.append(fd) + if flags & zmq.POLLOUT: + wlist.append(fd) + if flags & zmq.POLLERR: + xlist.append(fd) + + return (rlist, wlist, xlist) + + def poll(self, timeout=-1): + """Overridden method to ensure that the green version of + Poller is used. + + Behaves the same as :meth:`zmq.core.Poller.poll` + """ + + if timeout is None: + timeout = -1 + + if timeout < 0: + timeout = -1 + + rlist = None + wlist = None + xlist = None + + if timeout > 0: + tout = gevent.Timeout.start_new(timeout/1000.0) + + try: + # Loop until timeout or events available + rlist, wlist, xlist = self._get_descriptors() + while True: + events = super(_Poller, self).poll(0) + if events or timeout == 0: + return events + + # wait for activity on sockets in a green way + # set a minimum poll frequency, + # because gevent < 1.0 cannot be trusted to catch edge-triggered FD events + _bug_timeout = gevent.Timeout.start_new(self._gevent_bug_timeout) + try: + select.select(rlist, wlist, xlist) + except gevent.Timeout as t: + if t is not _bug_timeout: + raise + finally: + _bug_timeout.cancel() + + except gevent.Timeout as t: + if t is not tout: + raise + return [] + finally: + if timeout > 0: + tout.cancel() + diff --git a/src/console/zmq/log/__init__.py b/src/console/zmq/log/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/console/zmq/log/__init__.py diff --git a/src/console/zmq/log/handlers.py b/src/console/zmq/log/handlers.py new file mode 100755 index 00000000..5ff21bf3 --- /dev/null +++ b/src/console/zmq/log/handlers.py @@ -0,0 +1,146 @@ +"""pyzmq logging handlers. + +This mainly defines the PUBHandler object for publishing logging messages over +a zmq.PUB socket. + +The PUBHandler can be used with the regular logging module, as in:: + + >>> import logging + >>> handler = PUBHandler('tcp://127.0.0.1:12345') + >>> handler.root_topic = 'foo' + >>> logger = logging.getLogger('foobar') + >>> logger.setLevel(logging.DEBUG) + >>> logger.addHandler(handler) + +After this point, all messages logged by ``logger`` will be published on the +PUB socket. + +Code adapted from StarCluster: + + http://github.com/jtriley/StarCluster/blob/master/starcluster/logger.py +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import logging +from logging import INFO, DEBUG, WARN, ERROR, FATAL + +import zmq +from zmq.utils.strtypes import bytes, unicode, cast_bytes + + +TOPIC_DELIM="::" # delimiter for splitting topics on the receiving end. + + +class PUBHandler(logging.Handler): + """A basic logging handler that emits log messages through a PUB socket. + + Takes a PUB socket already bound to interfaces or an interface to bind to. + + Example:: + + sock = context.socket(zmq.PUB) + sock.bind('inproc://log') + handler = PUBHandler(sock) + + Or:: + + handler = PUBHandler('inproc://loc') + + These are equivalent. + + Log messages handled by this handler are broadcast with ZMQ topics + ``this.root_topic`` comes first, followed by the log level + (DEBUG,INFO,etc.), followed by any additional subtopics specified in the + message by: log.debug("subtopic.subsub::the real message") + """ + root_topic="" + socket = None + + formatters = { + logging.DEBUG: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"), + logging.INFO: logging.Formatter("%(message)s\n"), + logging.WARN: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"), + logging.ERROR: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s - %(exc_info)s\n"), + logging.CRITICAL: logging.Formatter( + "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n")} + + def __init__(self, interface_or_socket, context=None): + logging.Handler.__init__(self) + if isinstance(interface_or_socket, zmq.Socket): + self.socket = interface_or_socket + self.ctx = self.socket.context + else: + self.ctx = context or zmq.Context() + self.socket = self.ctx.socket(zmq.PUB) + self.socket.bind(interface_or_socket) + + def format(self,record): + """Format a record.""" + return self.formatters[record.levelno].format(record) + + def emit(self, record): + """Emit a log message on my socket.""" + try: + topic, record.msg = record.msg.split(TOPIC_DELIM,1) + except Exception: + topic = "" + try: + bmsg = cast_bytes(self.format(record)) + except Exception: + self.handleError(record) + return + + topic_list = [] + + if self.root_topic: + topic_list.append(self.root_topic) + + topic_list.append(record.levelname) + + if topic: + topic_list.append(topic) + + btopic = b'.'.join(cast_bytes(t) for t in topic_list) + + self.socket.send_multipart([btopic, bmsg]) + + +class TopicLogger(logging.Logger): + """A simple wrapper that takes an additional argument to log methods. + + All the regular methods exist, but instead of one msg argument, two + arguments: topic, msg are passed. + + That is:: + + logger.debug('msg') + + Would become:: + + logger.debug('topic.sub', 'msg') + """ + def log(self, level, topic, msg, *args, **kwargs): + """Log 'msg % args' with level and topic. + + To pass exception information, use the keyword argument exc_info + with a True value:: + + logger.log(level, "zmq.fun", "We have a %s", + "mysterious problem", exc_info=1) + """ + logging.Logger.log(self, level, '%s::%s'%(topic,msg), *args, **kwargs) + +# Generate the methods of TopicLogger, since they are just adding a +# topic prefix to a message. +for name in "debug warn warning error critical fatal".split(): + meth = getattr(logging.Logger,name) + setattr(TopicLogger, name, + lambda self, level, topic, msg, *args, **kwargs: + meth(self, level, topic+TOPIC_DELIM+msg,*args, **kwargs)) + diff --git a/src/console/zmq/ssh/__init__.py b/src/console/zmq/ssh/__init__.py new file mode 100755 index 00000000..57f09568 --- /dev/null +++ b/src/console/zmq/ssh/__init__.py @@ -0,0 +1 @@ +from zmq.ssh.tunnel import * diff --git a/src/console/zmq/ssh/forward.py b/src/console/zmq/ssh/forward.py new file mode 100755 index 00000000..2d619462 --- /dev/null +++ b/src/console/zmq/ssh/forward.py @@ -0,0 +1,91 @@ +# +# This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1. +# Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com> +# Edits Copyright (C) 2010 The IPython Team +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA. + +""" +Sample script showing how to do local port forwarding over paramiko. + +This script connects to the requested SSH server and sets up local port +forwarding (the openssh -L option) from a local port through a tunneled +connection to a destination reachable from the SSH server machine. +""" + +from __future__ import print_function + +import logging +import select +try: # Python 3 + import socketserver +except ImportError: # Python 2 + import SocketServer as socketserver + +logger = logging.getLogger('ssh') + +class ForwardServer (socketserver.ThreadingTCPServer): + daemon_threads = True + allow_reuse_address = True + + +class Handler (socketserver.BaseRequestHandler): + + def handle(self): + try: + chan = self.ssh_transport.open_channel('direct-tcpip', + (self.chain_host, self.chain_port), + self.request.getpeername()) + except Exception as e: + logger.debug('Incoming request to %s:%d failed: %s' % (self.chain_host, + self.chain_port, + repr(e))) + return + if chan is None: + logger.debug('Incoming request to %s:%d was rejected by the SSH server.' % + (self.chain_host, self.chain_port)) + return + + logger.debug('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(), + chan.getpeername(), (self.chain_host, self.chain_port))) + while True: + r, w, x = select.select([self.request, chan], [], []) + if self.request in r: + data = self.request.recv(1024) + if len(data) == 0: + break + chan.send(data) + if chan in r: + data = chan.recv(1024) + if len(data) == 0: + break + self.request.send(data) + chan.close() + self.request.close() + logger.debug('Tunnel closed ') + + +def forward_tunnel(local_port, remote_host, remote_port, transport): + # this is a little convoluted, but lets me configure things for the Handler + # object. (SocketServer doesn't give Handlers any way to access the outer + # server normally.) + class SubHander (Handler): + chain_host = remote_host + chain_port = remote_port + ssh_transport = transport + ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever() + + +__all__ = ['forward_tunnel'] diff --git a/src/console/zmq/ssh/tunnel.py b/src/console/zmq/ssh/tunnel.py new file mode 100755 index 00000000..5a0c5433 --- /dev/null +++ b/src/console/zmq/ssh/tunnel.py @@ -0,0 +1,376 @@ +"""Basic ssh tunnel utilities, and convenience functions for tunneling +zeromq connections. +""" + +# Copyright (C) 2010-2011 IPython Development Team +# Copyright (C) 2011- PyZMQ Developers +# +# Redistributed from IPython under the terms of the BSD License. + + +from __future__ import print_function + +import atexit +import os +import signal +import socket +import sys +import warnings +from getpass import getpass, getuser +from multiprocessing import Process + +try: + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + import paramiko + SSHException = paramiko.ssh_exception.SSHException +except ImportError: + paramiko = None + class SSHException(Exception): + pass +else: + from .forward import forward_tunnel + +try: + import pexpect +except ImportError: + pexpect = None + + +_random_ports = set() + +def select_random_ports(n): + """Selects and return n random ports that are available.""" + ports = [] + for i in range(n): + sock = socket.socket() + sock.bind(('', 0)) + while sock.getsockname()[1] in _random_ports: + sock.close() + sock = socket.socket() + sock.bind(('', 0)) + ports.append(sock) + for i, sock in enumerate(ports): + port = sock.getsockname()[1] + sock.close() + ports[i] = port + _random_ports.add(port) + return ports + + +#----------------------------------------------------------------------------- +# Check for passwordless login +#----------------------------------------------------------------------------- + +def try_passwordless_ssh(server, keyfile, paramiko=None): + """Attempt to make an ssh connection without a password. + This is mainly used for requiring password input only once + when many tunnels may be connected to the same server. + + If paramiko is None, the default for the platform is chosen. + """ + if paramiko is None: + paramiko = sys.platform == 'win32' + if not paramiko: + f = _try_passwordless_openssh + else: + f = _try_passwordless_paramiko + return f(server, keyfile) + +def _try_passwordless_openssh(server, keyfile): + """Try passwordless login with shell ssh command.""" + if pexpect is None: + raise ImportError("pexpect unavailable, use paramiko") + cmd = 'ssh -f '+ server + if keyfile: + cmd += ' -i ' + keyfile + cmd += ' exit' + + # pop SSH_ASKPASS from env + env = os.environ.copy() + env.pop('SSH_ASKPASS', None) + + ssh_newkey = 'Are you sure you want to continue connecting' + p = pexpect.spawn(cmd, env=env) + while True: + try: + i = p.expect([ssh_newkey, '[Pp]assword:'], timeout=.1) + if i==0: + raise SSHException('The authenticity of the host can\'t be established.') + except pexpect.TIMEOUT: + continue + except pexpect.EOF: + return True + else: + return False + +def _try_passwordless_paramiko(server, keyfile): + """Try passwordless login with paramiko.""" + if paramiko is None: + msg = "Paramiko unavaliable, " + if sys.platform == 'win32': + msg += "Paramiko is required for ssh tunneled connections on Windows." + else: + msg += "use OpenSSH." + raise ImportError(msg) + username, server, port = _split_server(server) + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + try: + client.connect(server, port, username=username, key_filename=keyfile, + look_for_keys=True) + except paramiko.AuthenticationException: + return False + else: + client.close() + return True + + +def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramiko=None, timeout=60): + """Connect a socket to an address via an ssh tunnel. + + This is a wrapper for socket.connect(addr), when addr is not accessible + from the local machine. It simply creates an ssh tunnel using the remaining args, + and calls socket.connect('tcp://localhost:lport') where lport is the randomly + selected local port of the tunnel. + + """ + new_url, tunnel = open_tunnel(addr, server, keyfile=keyfile, password=password, paramiko=paramiko, timeout=timeout) + socket.connect(new_url) + return tunnel + + +def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeout=60): + """Open a tunneled connection from a 0MQ url. + + For use inside tunnel_connection. + + Returns + ------- + + (url, tunnel) : (str, object) + The 0MQ url that has been forwarded, and the tunnel object + """ + + lport = select_random_ports(1)[0] + transport, addr = addr.split('://') + ip,rport = addr.split(':') + rport = int(rport) + if paramiko is None: + paramiko = sys.platform == 'win32' + if paramiko: + tunnelf = paramiko_tunnel + else: + tunnelf = openssh_tunnel + + tunnel = tunnelf(lport, rport, server, remoteip=ip, keyfile=keyfile, password=password, timeout=timeout) + return 'tcp://127.0.0.1:%i'%lport, tunnel + +def openssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60): + """Create an ssh tunnel using command-line ssh that connects port lport + on this machine to localhost:rport on server. The tunnel + will automatically close when not in use, remaining open + for a minimum of timeout seconds for an initial connection. + + This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, + as seen from `server`. + + keyfile and password may be specified, but ssh config is checked for defaults. + + Parameters + ---------- + + lport : int + local port for connecting to the tunnel from this machine. + rport : int + port on the remote machine to connect to. + server : str + The ssh server to connect to. The full ssh server string will be parsed. + user@server:port + remoteip : str [Default: 127.0.0.1] + The remote ip, specifying the destination of the tunnel. + Default is localhost, which means that the tunnel would redirect + localhost:lport on this machine to localhost:rport on the *server*. + + keyfile : str; path to public key file + This specifies a key to be used in ssh login, default None. + Regular default ssh keys will be used without specifying this argument. + password : str; + Your ssh password to the ssh server. Note that if this is left None, + you will be prompted for it if passwordless key based login is unavailable. + timeout : int [default: 60] + The time (in seconds) after which no activity will result in the tunnel + closing. This prevents orphaned tunnels from running forever. + """ + if pexpect is None: + raise ImportError("pexpect unavailable, use paramiko_tunnel") + ssh="ssh " + if keyfile: + ssh += "-i " + keyfile + + if ':' in server: + server, port = server.split(':') + ssh += " -p %s" % port + + cmd = "%s -O check %s" % (ssh, server) + (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) + if not exitstatus: + pid = int(output[output.find("(pid=")+5:output.find(")")]) + cmd = "%s -O forward -L 127.0.0.1:%i:%s:%i %s" % ( + ssh, lport, remoteip, rport, server) + (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) + if not exitstatus: + atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1)) + return pid + cmd = "%s -f -S none -L 127.0.0.1:%i:%s:%i %s sleep %i" % ( + ssh, lport, remoteip, rport, server, timeout) + + # pop SSH_ASKPASS from env + env = os.environ.copy() + env.pop('SSH_ASKPASS', None) + + ssh_newkey = 'Are you sure you want to continue connecting' + tunnel = pexpect.spawn(cmd, env=env) + failed = False + while True: + try: + i = tunnel.expect([ssh_newkey, '[Pp]assword:'], timeout=.1) + if i==0: + raise SSHException('The authenticity of the host can\'t be established.') + except pexpect.TIMEOUT: + continue + except pexpect.EOF: + if tunnel.exitstatus: + print(tunnel.exitstatus) + print(tunnel.before) + print(tunnel.after) + raise RuntimeError("tunnel '%s' failed to start"%(cmd)) + else: + return tunnel.pid + else: + if failed: + print("Password rejected, try again") + password=None + if password is None: + password = getpass("%s's password: "%(server)) + tunnel.sendline(password) + failed = True + +def _stop_tunnel(cmd): + pexpect.run(cmd) + +def _split_server(server): + if '@' in server: + username,server = server.split('@', 1) + else: + username = getuser() + if ':' in server: + server, port = server.split(':') + port = int(port) + else: + port = 22 + return username, server, port + +def paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60): + """launch a tunner with paramiko in a subprocess. This should only be used + when shell ssh is unavailable (e.g. Windows). + + This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, + as seen from `server`. + + If you are familiar with ssh tunnels, this creates the tunnel: + + ssh server -L localhost:lport:remoteip:rport + + keyfile and password may be specified, but ssh config is checked for defaults. + + + Parameters + ---------- + + lport : int + local port for connecting to the tunnel from this machine. + rport : int + port on the remote machine to connect to. + server : str + The ssh server to connect to. The full ssh server string will be parsed. + user@server:port + remoteip : str [Default: 127.0.0.1] + The remote ip, specifying the destination of the tunnel. + Default is localhost, which means that the tunnel would redirect + localhost:lport on this machine to localhost:rport on the *server*. + + keyfile : str; path to public key file + This specifies a key to be used in ssh login, default None. + Regular default ssh keys will be used without specifying this argument. + password : str; + Your ssh password to the ssh server. Note that if this is left None, + you will be prompted for it if passwordless key based login is unavailable. + timeout : int [default: 60] + The time (in seconds) after which no activity will result in the tunnel + closing. This prevents orphaned tunnels from running forever. + + """ + if paramiko is None: + raise ImportError("Paramiko not available") + + if password is None: + if not _try_passwordless_paramiko(server, keyfile): + password = getpass("%s's password: "%(server)) + + p = Process(target=_paramiko_tunnel, + args=(lport, rport, server, remoteip), + kwargs=dict(keyfile=keyfile, password=password)) + p.daemon=False + p.start() + atexit.register(_shutdown_process, p) + return p + +def _shutdown_process(p): + if p.is_alive(): + p.terminate() + +def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None): + """Function for actually starting a paramiko tunnel, to be passed + to multiprocessing.Process(target=this), and not called directly. + """ + username, server, port = _split_server(server) + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + + try: + client.connect(server, port, username=username, key_filename=keyfile, + look_for_keys=True, password=password) +# except paramiko.AuthenticationException: +# if password is None: +# password = getpass("%s@%s's password: "%(username, server)) +# client.connect(server, port, username=username, password=password) +# else: +# raise + except Exception as e: + print('*** Failed to connect to %s:%d: %r' % (server, port, e)) + sys.exit(1) + + # Don't let SIGINT kill the tunnel subprocess + signal.signal(signal.SIGINT, signal.SIG_IGN) + + try: + forward_tunnel(lport, remoteip, rport, client.get_transport()) + except KeyboardInterrupt: + print('SIGINT: Port forwarding stopped cleanly') + sys.exit(0) + except Exception as e: + print("Port forwarding stopped uncleanly: %s"%e) + sys.exit(255) + +if sys.platform == 'win32': + ssh_tunnel = paramiko_tunnel +else: + ssh_tunnel = openssh_tunnel + + +__all__ = ['tunnel_connection', 'ssh_tunnel', 'openssh_tunnel', 'paramiko_tunnel', 'try_passwordless_ssh'] + + diff --git a/src/console/zmq/sugar/__init__.py b/src/console/zmq/sugar/__init__.py new file mode 100755 index 00000000..d0510a44 --- /dev/null +++ b/src/console/zmq/sugar/__init__.py @@ -0,0 +1,27 @@ +"""pure-Python sugar wrappers for core 0MQ objects.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from zmq.sugar import ( + constants, context, frame, poll, socket, tracker, version +) +from zmq import error + +__all__ = ['constants'] +for submod in ( + constants, context, error, frame, poll, socket, tracker, version +): + __all__.extend(submod.__all__) + +from zmq.error import * +from zmq.sugar.context import * +from zmq.sugar.tracker import * +from zmq.sugar.socket import * +from zmq.sugar.constants import * +from zmq.sugar.frame import * +from zmq.sugar.poll import * +# from zmq.sugar.stopwatch import * +# from zmq.sugar._device import * +from zmq.sugar.version import * diff --git a/src/console/zmq/sugar/attrsettr.py b/src/console/zmq/sugar/attrsettr.py new file mode 100755 index 00000000..4bbd36d6 --- /dev/null +++ b/src/console/zmq/sugar/attrsettr.py @@ -0,0 +1,52 @@ +# coding: utf-8 +"""Mixin for mapping set/getattr to self.set/get""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from . import constants + +class AttributeSetter(object): + + def __setattr__(self, key, value): + """set zmq options by attribute""" + + # regular setattr only allowed for class-defined attributes + for obj in [self] + self.__class__.mro(): + if key in obj.__dict__: + object.__setattr__(self, key, value) + return + + upper_key = key.upper() + try: + opt = getattr(constants, upper_key) + except AttributeError: + raise AttributeError("%s has no such option: %s" % ( + self.__class__.__name__, upper_key) + ) + else: + self._set_attr_opt(upper_key, opt, value) + + def _set_attr_opt(self, name, opt, value): + """override if setattr should do something other than call self.set""" + self.set(opt, value) + + def __getattr__(self, key): + """get zmq options by attribute""" + upper_key = key.upper() + try: + opt = getattr(constants, upper_key) + except AttributeError: + raise AttributeError("%s has no such option: %s" % ( + self.__class__.__name__, upper_key) + ) + else: + return self._get_attr_opt(upper_key, opt) + + def _get_attr_opt(self, name, opt): + """override if getattr should do something other than call self.get""" + return self.get(opt) + + +__all__ = ['AttributeSetter'] diff --git a/src/console/zmq/sugar/constants.py b/src/console/zmq/sugar/constants.py new file mode 100755 index 00000000..88281176 --- /dev/null +++ b/src/console/zmq/sugar/constants.py @@ -0,0 +1,98 @@ +"""0MQ Constants.""" + +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +from zmq.backend import constants +from zmq.utils.constant_names import ( + base_names, + switched_sockopt_names, + int_sockopt_names, + int64_sockopt_names, + bytes_sockopt_names, + fd_sockopt_names, + ctx_opt_names, + msg_opt_names, +) + +#----------------------------------------------------------------------------- +# Python module level constants +#----------------------------------------------------------------------------- + +__all__ = [ + 'int_sockopts', + 'int64_sockopts', + 'bytes_sockopts', + 'ctx_opts', + 'ctx_opt_names', + ] + +int_sockopts = set() +int64_sockopts = set() +bytes_sockopts = set() +fd_sockopts = set() +ctx_opts = set() +msg_opts = set() + + +if constants.VERSION < 30000: + int64_sockopt_names.extend(switched_sockopt_names) +else: + int_sockopt_names.extend(switched_sockopt_names) + +_UNDEFINED = -9999 + +def _add_constant(name, container=None): + """add a constant to be defined + + optionally add it to one of the sets for use in get/setopt checkers + """ + c = getattr(constants, name, _UNDEFINED) + if c == _UNDEFINED: + return + globals()[name] = c + __all__.append(name) + if container is not None: + container.add(c) + return c + +for name in base_names: + _add_constant(name) + +for name in int_sockopt_names: + _add_constant(name, int_sockopts) + +for name in int64_sockopt_names: + _add_constant(name, int64_sockopts) + +for name in bytes_sockopt_names: + _add_constant(name, bytes_sockopts) + +for name in fd_sockopt_names: + _add_constant(name, fd_sockopts) + +for name in ctx_opt_names: + _add_constant(name, ctx_opts) + +for name in msg_opt_names: + _add_constant(name, msg_opts) + +# ensure some aliases are always defined +aliases = [ + ('DONTWAIT', 'NOBLOCK'), + ('XREQ', 'DEALER'), + ('XREP', 'ROUTER'), +] +for group in aliases: + undefined = set() + found = None + for name in group: + value = getattr(constants, name, -1) + if value != -1: + found = value + else: + undefined.add(name) + if found is not None: + for name in undefined: + globals()[name] = found + __all__.append(name) diff --git a/src/console/zmq/sugar/context.py b/src/console/zmq/sugar/context.py new file mode 100755 index 00000000..86a9c5dc --- /dev/null +++ b/src/console/zmq/sugar/context.py @@ -0,0 +1,192 @@ +# coding: utf-8 +"""Python bindings for 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import atexit +import weakref + +from zmq.backend import Context as ContextBase +from . import constants +from .attrsettr import AttributeSetter +from .constants import ENOTSUP, ctx_opt_names +from .socket import Socket +from zmq.error import ZMQError + +from zmq.utils.interop import cast_int_addr + + +class Context(ContextBase, AttributeSetter): + """Create a zmq Context + + A zmq Context creates sockets via its ``ctx.socket`` method. + """ + sockopts = None + _instance = None + _shadow = False + _exiting = False + + def __init__(self, io_threads=1, **kwargs): + super(Context, self).__init__(io_threads=io_threads, **kwargs) + if kwargs.get('shadow', False): + self._shadow = True + else: + self._shadow = False + self.sockopts = {} + + self._exiting = False + if not self._shadow: + ctx_ref = weakref.ref(self) + def _notify_atexit(): + ctx = ctx_ref() + if ctx is not None: + ctx._exiting = True + atexit.register(_notify_atexit) + + def __del__(self): + """deleting a Context should terminate it, without trying non-threadsafe destroy""" + if not self._shadow and not self._exiting: + self.term() + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.term() + + @classmethod + def shadow(cls, address): + """Shadow an existing libzmq context + + address is the integer address of the libzmq context + or an FFI pointer to it. + + .. versionadded:: 14.1 + """ + address = cast_int_addr(address) + return cls(shadow=address) + + @classmethod + def shadow_pyczmq(cls, ctx): + """Shadow an existing pyczmq context + + ctx is the FFI `zctx_t *` pointer + + .. versionadded:: 14.1 + """ + from pyczmq import zctx + + underlying = zctx.underlying(ctx) + address = cast_int_addr(underlying) + return cls(shadow=address) + + # static method copied from tornado IOLoop.instance + @classmethod + def instance(cls, io_threads=1): + """Returns a global Context instance. + + Most single-threaded applications have a single, global Context. + Use this method instead of passing around Context instances + throughout your code. + + A common pattern for classes that depend on Contexts is to use + a default argument to enable programs with multiple Contexts + but not require the argument for simpler applications: + + class MyClass(object): + def __init__(self, context=None): + self.context = context or Context.instance() + """ + if cls._instance is None or cls._instance.closed: + cls._instance = cls(io_threads=io_threads) + return cls._instance + + #------------------------------------------------------------------------- + # Hooks for ctxopt completion + #------------------------------------------------------------------------- + + def __dir__(self): + keys = dir(self.__class__) + + for collection in ( + ctx_opt_names, + ): + keys.extend(collection) + return keys + + #------------------------------------------------------------------------- + # Creating Sockets + #------------------------------------------------------------------------- + + @property + def _socket_class(self): + return Socket + + def socket(self, socket_type): + """Create a Socket associated with this Context. + + Parameters + ---------- + socket_type : int + The socket type, which can be any of the 0MQ socket types: + REQ, REP, PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, etc. + """ + if self.closed: + raise ZMQError(ENOTSUP) + s = self._socket_class(self, socket_type) + for opt, value in self.sockopts.items(): + try: + s.setsockopt(opt, value) + except ZMQError: + # ignore ZMQErrors, which are likely for socket options + # that do not apply to a particular socket type, e.g. + # SUBSCRIBE for non-SUB sockets. + pass + return s + + def setsockopt(self, opt, value): + """set default socket options for new sockets created by this Context + + .. versionadded:: 13.0 + """ + self.sockopts[opt] = value + + def getsockopt(self, opt): + """get default socket options for new sockets created by this Context + + .. versionadded:: 13.0 + """ + return self.sockopts[opt] + + def _set_attr_opt(self, name, opt, value): + """set default sockopts as attributes""" + if name in constants.ctx_opt_names: + return self.set(opt, value) + else: + self.sockopts[opt] = value + + def _get_attr_opt(self, name, opt): + """get default sockopts as attributes""" + if name in constants.ctx_opt_names: + return self.get(opt) + else: + if opt not in self.sockopts: + raise AttributeError(name) + else: + return self.sockopts[opt] + + def __delattr__(self, key): + """delete default sockopts as attributes""" + key = key.upper() + try: + opt = getattr(constants, key) + except AttributeError: + raise AttributeError("no such socket option: %s" % key) + else: + if opt not in self.sockopts: + raise AttributeError(key) + else: + del self.sockopts[opt] + +__all__ = ['Context'] diff --git a/src/console/zmq/sugar/frame.py b/src/console/zmq/sugar/frame.py new file mode 100755 index 00000000..9f556c86 --- /dev/null +++ b/src/console/zmq/sugar/frame.py @@ -0,0 +1,19 @@ +# coding: utf-8 +"""0MQ Frame pure Python methods.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from .attrsettr import AttributeSetter +from zmq.backend import Frame as FrameBase + + +class Frame(FrameBase, AttributeSetter): + def __getitem__(self, key): + # map Frame['User-Id'] to Frame.get('User-Id') + return self.get(key) + +# keep deprecated alias +Message = Frame +__all__ = ['Frame', 'Message']
\ No newline at end of file diff --git a/src/console/zmq/sugar/poll.py b/src/console/zmq/sugar/poll.py new file mode 100755 index 00000000..c7b1d1bb --- /dev/null +++ b/src/console/zmq/sugar/poll.py @@ -0,0 +1,161 @@ +"""0MQ polling related functions and classes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import zmq +from zmq.backend import zmq_poll +from .constants import POLLIN, POLLOUT, POLLERR + +#----------------------------------------------------------------------------- +# Polling related methods +#----------------------------------------------------------------------------- + + +class Poller(object): + """A stateful poll interface that mirrors Python's built-in poll.""" + sockets = None + _map = {} + + def __init__(self): + self.sockets = [] + self._map = {} + + def __contains__(self, socket): + return socket in self._map + + def register(self, socket, flags=POLLIN|POLLOUT): + """p.register(socket, flags=POLLIN|POLLOUT) + + Register a 0MQ socket or native fd for I/O monitoring. + + register(s,0) is equivalent to unregister(s). + + Parameters + ---------- + socket : zmq.Socket or native socket + A zmq.Socket or any Python object having a ``fileno()`` + method that returns a valid file descriptor. + flags : int + The events to watch for. Can be POLLIN, POLLOUT or POLLIN|POLLOUT. + If `flags=0`, socket will be unregistered. + """ + if flags: + if socket in self._map: + idx = self._map[socket] + self.sockets[idx] = (socket, flags) + else: + idx = len(self.sockets) + self.sockets.append((socket, flags)) + self._map[socket] = idx + elif socket in self._map: + # uregister sockets registered with no events + self.unregister(socket) + else: + # ignore new sockets with no events + pass + + def modify(self, socket, flags=POLLIN|POLLOUT): + """Modify the flags for an already registered 0MQ socket or native fd.""" + self.register(socket, flags) + + def unregister(self, socket): + """Remove a 0MQ socket or native fd for I/O monitoring. + + Parameters + ---------- + socket : Socket + The socket instance to stop polling. + """ + idx = self._map.pop(socket) + self.sockets.pop(idx) + # shift indices after deletion + for socket, flags in self.sockets[idx:]: + self._map[socket] -= 1 + + def poll(self, timeout=None): + """Poll the registered 0MQ or native fds for I/O. + + Parameters + ---------- + timeout : float, int + The timeout in milliseconds. If None, no `timeout` (infinite). This + is in milliseconds to be compatible with ``select.poll()``. The + underlying zmq_poll uses microseconds and we convert to that in + this function. + + Returns + ------- + events : list of tuples + The list of events that are ready to be processed. + This is a list of tuples of the form ``(socket, event)``, where the 0MQ Socket + or integer fd is the first element, and the poll event mask (POLLIN, POLLOUT) is the second. + It is common to call ``events = dict(poller.poll())``, + which turns the list of tuples into a mapping of ``socket : event``. + """ + if timeout is None or timeout < 0: + timeout = -1 + elif isinstance(timeout, float): + timeout = int(timeout) + return zmq_poll(self.sockets, timeout=timeout) + + +def select(rlist, wlist, xlist, timeout=None): + """select(rlist, wlist, xlist, timeout=None) -> (rlist, wlist, xlist) + + Return the result of poll as a lists of sockets ready for r/w/exception. + + This has the same interface as Python's built-in ``select.select()`` function. + + Parameters + ---------- + timeout : float, int, optional + The timeout in seconds. If None, no timeout (infinite). This is in seconds to be + compatible with ``select.select()``. The underlying zmq_poll uses microseconds + and we convert to that in this function. + rlist : list of sockets/FDs + sockets/FDs to be polled for read events + wlist : list of sockets/FDs + sockets/FDs to be polled for write events + xlist : list of sockets/FDs + sockets/FDs to be polled for error events + + Returns + ------- + (rlist, wlist, xlist) : tuple of lists of sockets (length 3) + Lists correspond to sockets available for read/write/error events respectively. + """ + if timeout is None: + timeout = -1 + # Convert from sec -> us for zmq_poll. + # zmq_poll accepts 3.x style timeout in ms + timeout = int(timeout*1000.0) + if timeout < 0: + timeout = -1 + sockets = [] + for s in set(rlist + wlist + xlist): + flags = 0 + if s in rlist: + flags |= POLLIN + if s in wlist: + flags |= POLLOUT + if s in xlist: + flags |= POLLERR + sockets.append((s, flags)) + return_sockets = zmq_poll(sockets, timeout) + rlist, wlist, xlist = [], [], [] + for s, flags in return_sockets: + if flags & POLLIN: + rlist.append(s) + if flags & POLLOUT: + wlist.append(s) + if flags & POLLERR: + xlist.append(s) + return rlist, wlist, xlist + +#----------------------------------------------------------------------------- +# Symbols to export +#----------------------------------------------------------------------------- + +__all__ = [ 'Poller', 'select' ] diff --git a/src/console/zmq/sugar/socket.py b/src/console/zmq/sugar/socket.py new file mode 100755 index 00000000..c91589d7 --- /dev/null +++ b/src/console/zmq/sugar/socket.py @@ -0,0 +1,495 @@ +# coding: utf-8 +"""0MQ Socket pure Python methods.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import codecs +import random +import warnings + +import zmq +from zmq.backend import Socket as SocketBase +from .poll import Poller +from . import constants +from .attrsettr import AttributeSetter +from zmq.error import ZMQError, ZMQBindError +from zmq.utils import jsonapi +from zmq.utils.strtypes import bytes,unicode,basestring +from zmq.utils.interop import cast_int_addr + +from .constants import ( + SNDMORE, ENOTSUP, POLLIN, + int64_sockopt_names, + int_sockopt_names, + bytes_sockopt_names, + fd_sockopt_names, +) +try: + import cPickle + pickle = cPickle +except: + cPickle = None + import pickle + +try: + DEFAULT_PROTOCOL = pickle.DEFAULT_PROTOCOL +except AttributeError: + DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL + + +class Socket(SocketBase, AttributeSetter): + """The ZMQ socket object + + To create a Socket, first create a Context:: + + ctx = zmq.Context.instance() + + then call ``ctx.socket(socket_type)``:: + + s = ctx.socket(zmq.ROUTER) + + """ + _shadow = False + + def __del__(self): + if not self._shadow: + self.close() + + # socket as context manager: + def __enter__(self): + """Sockets are context managers + + .. versionadded:: 14.4 + """ + return self + + def __exit__(self, *args, **kwargs): + self.close() + + #------------------------------------------------------------------------- + # Socket creation + #------------------------------------------------------------------------- + + @classmethod + def shadow(cls, address): + """Shadow an existing libzmq socket + + address is the integer address of the libzmq socket + or an FFI pointer to it. + + .. versionadded:: 14.1 + """ + address = cast_int_addr(address) + return cls(shadow=address) + + #------------------------------------------------------------------------- + # Deprecated aliases + #------------------------------------------------------------------------- + + @property + def socket_type(self): + warnings.warn("Socket.socket_type is deprecated, use Socket.type", + DeprecationWarning + ) + return self.type + + #------------------------------------------------------------------------- + # Hooks for sockopt completion + #------------------------------------------------------------------------- + + def __dir__(self): + keys = dir(self.__class__) + for collection in ( + bytes_sockopt_names, + int_sockopt_names, + int64_sockopt_names, + fd_sockopt_names, + ): + keys.extend(collection) + return keys + + #------------------------------------------------------------------------- + # Getting/Setting options + #------------------------------------------------------------------------- + setsockopt = SocketBase.set + getsockopt = SocketBase.get + + def set_string(self, option, optval, encoding='utf-8'): + """set socket options with a unicode object + + This is simply a wrapper for setsockopt to protect from encoding ambiguity. + + See the 0MQ documentation for details on specific options. + + Parameters + ---------- + option : int + The name of the option to set. Can be any of: SUBSCRIBE, + UNSUBSCRIBE, IDENTITY + optval : unicode string (unicode on py2, str on py3) + The value of the option to set. + encoding : str + The encoding to be used, default is utf8 + """ + if not isinstance(optval, unicode): + raise TypeError("unicode strings only") + return self.set(option, optval.encode(encoding)) + + setsockopt_unicode = setsockopt_string = set_string + + def get_string(self, option, encoding='utf-8'): + """get the value of a socket option + + See the 0MQ documentation for details on specific options. + + Parameters + ---------- + option : int + The option to retrieve. + + Returns + ------- + optval : unicode string (unicode on py2, str on py3) + The value of the option as a unicode string. + """ + + if option not in constants.bytes_sockopts: + raise TypeError("option %i will not return a string to be decoded"%option) + return self.getsockopt(option).decode(encoding) + + getsockopt_unicode = getsockopt_string = get_string + + def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=100): + """bind this socket to a random port in a range + + Parameters + ---------- + addr : str + The address string without the port to pass to ``Socket.bind()``. + min_port : int, optional + The minimum port in the range of ports to try (inclusive). + max_port : int, optional + The maximum port in the range of ports to try (exclusive). + max_tries : int, optional + The maximum number of bind attempts to make. + + Returns + ------- + port : int + The port the socket was bound to. + + Raises + ------ + ZMQBindError + if `max_tries` reached before successful bind + """ + for i in range(max_tries): + try: + port = random.randrange(min_port, max_port) + self.bind('%s:%s' % (addr, port)) + except ZMQError as exception: + if not exception.errno == zmq.EADDRINUSE: + raise + else: + return port + raise ZMQBindError("Could not bind socket to random port.") + + def get_hwm(self): + """get the High Water Mark + + On libzmq ≥ 3, this gets SNDHWM if available, otherwise RCVHWM + """ + major = zmq.zmq_version_info()[0] + if major >= 3: + # return sndhwm, fallback on rcvhwm + try: + return self.getsockopt(zmq.SNDHWM) + except zmq.ZMQError as e: + pass + + return self.getsockopt(zmq.RCVHWM) + else: + return self.getsockopt(zmq.HWM) + + def set_hwm(self, value): + """set the High Water Mark + + On libzmq ≥ 3, this sets both SNDHWM and RCVHWM + """ + major = zmq.zmq_version_info()[0] + if major >= 3: + raised = None + try: + self.sndhwm = value + except Exception as e: + raised = e + try: + self.rcvhwm = value + except Exception: + raised = e + + if raised: + raise raised + else: + return self.setsockopt(zmq.HWM, value) + + hwm = property(get_hwm, set_hwm, + """property for High Water Mark + + Setting hwm sets both SNDHWM and RCVHWM as appropriate. + It gets SNDHWM if available, otherwise RCVHWM. + """ + ) + + #------------------------------------------------------------------------- + # Sending and receiving messages + #------------------------------------------------------------------------- + + def send_multipart(self, msg_parts, flags=0, copy=True, track=False): + """send a sequence of buffers as a multipart message + + The zmq.SNDMORE flag is added to all msg parts before the last. + + Parameters + ---------- + msg_parts : iterable + A sequence of objects to send as a multipart message. Each element + can be any sendable object (Frame, bytes, buffer-providers) + flags : int, optional + SNDMORE is handled automatically for frames before the last. + copy : bool, optional + Should the frame(s) be sent in a copying or non-copying manner. + track : bool, optional + Should the frame(s) be tracked for notification that ZMQ has + finished with it (ignored if copy=True). + + Returns + ------- + None : if copy or not track + MessageTracker : if track and not copy + a MessageTracker object, whose `pending` property will + be True until the last send is completed. + """ + for msg in msg_parts[:-1]: + self.send(msg, SNDMORE|flags, copy=copy, track=track) + # Send the last part without the extra SNDMORE flag. + return self.send(msg_parts[-1], flags, copy=copy, track=track) + + def recv_multipart(self, flags=0, copy=True, track=False): + """receive a multipart message as a list of bytes or Frame objects + + Parameters + ---------- + flags : int, optional + Any supported flag: NOBLOCK. If NOBLOCK is set, this method + will raise a ZMQError with EAGAIN if a message is not ready. + If NOBLOCK is not set, then this method will block until a + message arrives. + copy : bool, optional + Should the message frame(s) be received in a copying or non-copying manner? + If False a Frame object is returned for each part, if True a copy of + the bytes is made for each frame. + track : bool, optional + Should the message frame(s) be tracked for notification that ZMQ has + finished with it? (ignored if copy=True) + + Returns + ------- + msg_parts : list + A list of frames in the multipart message; either Frames or bytes, + depending on `copy`. + + """ + parts = [self.recv(flags, copy=copy, track=track)] + # have first part already, only loop while more to receive + while self.getsockopt(zmq.RCVMORE): + part = self.recv(flags, copy=copy, track=track) + parts.append(part) + + return parts + + def send_string(self, u, flags=0, copy=True, encoding='utf-8'): + """send a Python unicode string as a message with an encoding + + 0MQ communicates with raw bytes, so you must encode/decode + text (unicode on py2, str on py3) around 0MQ. + + Parameters + ---------- + u : Python unicode string (unicode on py2, str on py3) + The unicode string to send. + flags : int, optional + Any valid send flag. + encoding : str [default: 'utf-8'] + The encoding to be used + """ + if not isinstance(u, basestring): + raise TypeError("unicode/str objects only") + return self.send(u.encode(encoding), flags=flags, copy=copy) + + send_unicode = send_string + + def recv_string(self, flags=0, encoding='utf-8'): + """receive a unicode string, as sent by send_string + + Parameters + ---------- + flags : int + Any valid recv flag. + encoding : str [default: 'utf-8'] + The encoding to be used + + Returns + ------- + s : unicode string (unicode on py2, str on py3) + The Python unicode string that arrives as encoded bytes. + """ + b = self.recv(flags=flags) + return b.decode(encoding) + + recv_unicode = recv_string + + def send_pyobj(self, obj, flags=0, protocol=DEFAULT_PROTOCOL): + """send a Python object as a message using pickle to serialize + + Parameters + ---------- + obj : Python object + The Python object to send. + flags : int + Any valid send flag. + protocol : int + The pickle protocol number to use. The default is pickle.DEFAULT_PROTOCOl + where defined, and pickle.HIGHEST_PROTOCOL elsewhere. + """ + msg = pickle.dumps(obj, protocol) + return self.send(msg, flags) + + def recv_pyobj(self, flags=0): + """receive a Python object as a message using pickle to serialize + + Parameters + ---------- + flags : int + Any valid recv flag. + + Returns + ------- + obj : Python object + The Python object that arrives as a message. + """ + s = self.recv(flags) + return pickle.loads(s) + + def send_json(self, obj, flags=0, **kwargs): + """send a Python object as a message using json to serialize + + Keyword arguments are passed on to json.dumps + + Parameters + ---------- + obj : Python object + The Python object to send + flags : int + Any valid send flag + """ + msg = jsonapi.dumps(obj, **kwargs) + return self.send(msg, flags) + + def recv_json(self, flags=0, **kwargs): + """receive a Python object as a message using json to serialize + + Keyword arguments are passed on to json.loads + + Parameters + ---------- + flags : int + Any valid recv flag. + + Returns + ------- + obj : Python object + The Python object that arrives as a message. + """ + msg = self.recv(flags) + return jsonapi.loads(msg, **kwargs) + + _poller_class = Poller + + def poll(self, timeout=None, flags=POLLIN): + """poll the socket for events + + The default is to poll forever for incoming + events. Timeout is in milliseconds, if specified. + + Parameters + ---------- + timeout : int [default: None] + The timeout (in milliseconds) to wait for an event. If unspecified + (or specified None), will wait forever for an event. + flags : bitfield (int) [default: POLLIN] + The event flags to poll for (any combination of POLLIN|POLLOUT). + The default is to check for incoming events (POLLIN). + + Returns + ------- + events : bitfield (int) + The events that are ready and waiting. Will be 0 if no events were ready + by the time timeout was reached. + """ + + if self.closed: + raise ZMQError(ENOTSUP) + + p = self._poller_class() + p.register(self, flags) + evts = dict(p.poll(timeout)) + # return 0 if no events, otherwise return event bitfield + return evts.get(self, 0) + + def get_monitor_socket(self, events=None, addr=None): + """Return a connected PAIR socket ready to receive the event notifications. + + .. versionadded:: libzmq-4.0 + .. versionadded:: 14.0 + + Parameters + ---------- + events : bitfield (int) [default: ZMQ_EVENTS_ALL] + The bitmask defining which events are wanted. + addr : string [default: None] + The optional endpoint for the monitoring sockets. + + Returns + ------- + socket : (PAIR) + The socket is already connected and ready to receive messages. + """ + # safe-guard, method only available on libzmq >= 4 + if zmq.zmq_version_info() < (4,): + raise NotImplementedError("get_monitor_socket requires libzmq >= 4, have %s" % zmq.zmq_version()) + if addr is None: + # create endpoint name from internal fd + addr = "inproc://monitor.s-%d" % self.FD + if events is None: + # use all events + events = zmq.EVENT_ALL + # attach monitoring socket + self.monitor(addr, events) + # create new PAIR socket and connect it + ret = self.context.socket(zmq.PAIR) + ret.connect(addr) + return ret + + def disable_monitor(self): + """Shutdown the PAIR socket (created using get_monitor_socket) + that is serving socket events. + + .. versionadded:: 14.4 + """ + self.monitor(None, 0) + + +__all__ = ['Socket'] diff --git a/src/console/zmq/sugar/tracker.py b/src/console/zmq/sugar/tracker.py new file mode 100755 index 00000000..fb8c007f --- /dev/null +++ b/src/console/zmq/sugar/tracker.py @@ -0,0 +1,120 @@ +"""Tracker for zero-copy messages with 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time + +try: + # below 3.3 + from threading import _Event as Event +except (ImportError, AttributeError): + # python throws ImportError, cython throws AttributeError + from threading import Event + +from zmq.error import NotDone +from zmq.backend import Frame + +class MessageTracker(object): + """MessageTracker(*towatch) + + A class for tracking if 0MQ is done using one or more messages. + + When you send a 0MQ message, it is not sent immediately. The 0MQ IO thread + sends the message at some later time. Often you want to know when 0MQ has + actually sent the message though. This is complicated by the fact that + a single 0MQ message can be sent multiple times using different sockets. + This class allows you to track all of the 0MQ usages of a message. + + Parameters + ---------- + *towatch : tuple of Event, MessageTracker, Message instances. + This list of objects to track. This class can track the low-level + Events used by the Message class, other MessageTrackers or + actual Messages. + """ + events = None + peers = None + + def __init__(self, *towatch): + """MessageTracker(*towatch) + + Create a message tracker to track a set of mesages. + + Parameters + ---------- + *towatch : tuple of Event, MessageTracker, Message instances. + This list of objects to track. This class can track the low-level + Events used by the Message class, other MessageTrackers or + actual Messages. + """ + self.events = set() + self.peers = set() + for obj in towatch: + if isinstance(obj, Event): + self.events.add(obj) + elif isinstance(obj, MessageTracker): + self.peers.add(obj) + elif isinstance(obj, Frame): + if not obj.tracker: + raise ValueError("Not a tracked message") + self.peers.add(obj.tracker) + else: + raise TypeError("Require Events or Message Frames, not %s"%type(obj)) + + @property + def done(self): + """Is 0MQ completely done with the message(s) being tracked?""" + for evt in self.events: + if not evt.is_set(): + return False + for pm in self.peers: + if not pm.done: + return False + return True + + def wait(self, timeout=-1): + """mt.wait(timeout=-1) + + Wait for 0MQ to be done with the message or until `timeout`. + + Parameters + ---------- + timeout : float [default: -1, wait forever] + Maximum time in (s) to wait before raising NotDone. + + Returns + ------- + None + if done before `timeout` + + Raises + ------ + NotDone + if `timeout` reached before I am done. + """ + tic = time.time() + if timeout is False or timeout < 0: + remaining = 3600*24*7 # a week + else: + remaining = timeout + done = False + for evt in self.events: + if remaining < 0: + raise NotDone + evt.wait(timeout=remaining) + if not evt.is_set(): + raise NotDone + toc = time.time() + remaining -= (toc-tic) + tic = toc + + for peer in self.peers: + if remaining < 0: + raise NotDone + peer.wait(timeout=remaining) + toc = time.time() + remaining -= (toc-tic) + tic = toc + +__all__ = ['MessageTracker']
\ No newline at end of file diff --git a/src/console/zmq/sugar/version.py b/src/console/zmq/sugar/version.py new file mode 100755 index 00000000..ea8fbbc4 --- /dev/null +++ b/src/console/zmq/sugar/version.py @@ -0,0 +1,48 @@ +"""PyZMQ and 0MQ version functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from zmq.backend import zmq_version_info + + +VERSION_MAJOR = 14 +VERSION_MINOR = 5 +VERSION_PATCH = 0 +VERSION_EXTRA = "" +__version__ = '%i.%i.%i' % (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) + +if VERSION_EXTRA: + __version__ = "%s-%s" % (__version__, VERSION_EXTRA) + version_info = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, float('inf')) +else: + version_info = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) + +__revision__ = '' + +def pyzmq_version(): + """return the version of pyzmq as a string""" + if __revision__: + return '@'.join([__version__,__revision__[:6]]) + else: + return __version__ + +def pyzmq_version_info(): + """return the pyzmq version as a tuple of at least three numbers + + If pyzmq is a development version, `inf` will be appended after the third integer. + """ + return version_info + + +def zmq_version(): + """return the version of libzmq as a string""" + return "%i.%i.%i" % zmq_version_info() + + +__all__ = ['zmq_version', 'zmq_version_info', + 'pyzmq_version','pyzmq_version_info', + '__version__', '__revision__' +] + diff --git a/src/console/zmq/tests/__init__.py b/src/console/zmq/tests/__init__.py new file mode 100755 index 00000000..325a3f19 --- /dev/null +++ b/src/console/zmq/tests/__init__.py @@ -0,0 +1,211 @@ +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +import functools +import sys +import time +from threading import Thread + +from unittest import TestCase + +import zmq +from zmq.utils import jsonapi + +try: + import gevent + from zmq import green as gzmq + have_gevent = True +except ImportError: + have_gevent = False + +try: + from unittest import SkipTest +except ImportError: + try: + from nose import SkipTest + except ImportError: + class SkipTest(Exception): + pass + +PYPY = 'PyPy' in sys.version + +#----------------------------------------------------------------------------- +# skip decorators (directly from unittest) +#----------------------------------------------------------------------------- + +_id = lambda x: x + +def skip(reason): + """ + Unconditionally skip a test. + """ + def decorator(test_item): + if not (isinstance(test_item, type) and issubclass(test_item, TestCase)): + @functools.wraps(test_item) + def skip_wrapper(*args, **kwargs): + raise SkipTest(reason) + test_item = skip_wrapper + + test_item.__unittest_skip__ = True + test_item.__unittest_skip_why__ = reason + return test_item + return decorator + +def skip_if(condition, reason="Skipped"): + """ + Skip a test if the condition is true. + """ + if condition: + return skip(reason) + return _id + +skip_pypy = skip_if(PYPY, "Doesn't work on PyPy") + +#----------------------------------------------------------------------------- +# Base test class +#----------------------------------------------------------------------------- + +class BaseZMQTestCase(TestCase): + green = False + + @property + def Context(self): + if self.green: + return gzmq.Context + else: + return zmq.Context + + def socket(self, socket_type): + s = self.context.socket(socket_type) + self.sockets.append(s) + return s + + def setUp(self): + if self.green and not have_gevent: + raise SkipTest("requires gevent") + self.context = self.Context.instance() + self.sockets = [] + + def tearDown(self): + contexts = set([self.context]) + while self.sockets: + sock = self.sockets.pop() + contexts.add(sock.context) # in case additional contexts are created + sock.close(0) + for ctx in contexts: + t = Thread(target=ctx.term) + t.daemon = True + t.start() + t.join(timeout=2) + if t.is_alive(): + # reset Context.instance, so the failure to term doesn't corrupt subsequent tests + zmq.sugar.context.Context._instance = None + raise RuntimeError("context could not terminate, open sockets likely remain in test") + + def create_bound_pair(self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'): + """Create a bound socket pair using a random port.""" + s1 = self.context.socket(type1) + s1.setsockopt(zmq.LINGER, 0) + port = s1.bind_to_random_port(interface) + s2 = self.context.socket(type2) + s2.setsockopt(zmq.LINGER, 0) + s2.connect('%s:%s' % (interface, port)) + self.sockets.extend([s1,s2]) + return s1, s2 + + def ping_pong(self, s1, s2, msg): + s1.send(msg) + msg2 = s2.recv() + s2.send(msg2) + msg3 = s1.recv() + return msg3 + + def ping_pong_json(self, s1, s2, o): + if jsonapi.jsonmod is None: + raise SkipTest("No json library") + s1.send_json(o) + o2 = s2.recv_json() + s2.send_json(o2) + o3 = s1.recv_json() + return o3 + + def ping_pong_pyobj(self, s1, s2, o): + s1.send_pyobj(o) + o2 = s2.recv_pyobj() + s2.send_pyobj(o2) + o3 = s1.recv_pyobj() + return o3 + + def assertRaisesErrno(self, errno, func, *args, **kwargs): + try: + func(*args, **kwargs) + except zmq.ZMQError as e: + self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \ +got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) + else: + self.fail("Function did not raise any error") + + def _select_recv(self, multipart, socket, **kwargs): + """call recv[_multipart] in a way that raises if there is nothing to receive""" + if zmq.zmq_version_info() >= (3,1,0): + # zmq 3.1 has a bug, where poll can return false positives, + # so we wait a little bit just in case + # See LIBZMQ-280 on JIRA + time.sleep(0.1) + + r,w,x = zmq.select([socket], [], [], timeout=5) + assert len(r) > 0, "Should have received a message" + kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0) + + recv = socket.recv_multipart if multipart else socket.recv + return recv(**kwargs) + + def recv(self, socket, **kwargs): + """call recv in a way that raises if there is nothing to receive""" + return self._select_recv(False, socket, **kwargs) + + def recv_multipart(self, socket, **kwargs): + """call recv_multipart in a way that raises if there is nothing to receive""" + return self._select_recv(True, socket, **kwargs) + + +class PollZMQTestCase(BaseZMQTestCase): + pass + +class GreenTest: + """Mixin for making green versions of test classes""" + green = True + + def assertRaisesErrno(self, errno, func, *args, **kwargs): + if errno == zmq.EAGAIN: + raise SkipTest("Skipping because we're green.") + try: + func(*args, **kwargs) + except zmq.ZMQError: + e = sys.exc_info()[1] + self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \ +got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) + else: + self.fail("Function did not raise any error") + + def tearDown(self): + contexts = set([self.context]) + while self.sockets: + sock = self.sockets.pop() + contexts.add(sock.context) # in case additional contexts are created + sock.close() + try: + gevent.joinall([gevent.spawn(ctx.term) for ctx in contexts], timeout=2, raise_error=True) + except gevent.Timeout: + raise RuntimeError("context could not terminate, open sockets likely remain in test") + + def skip_green(self): + raise SkipTest("Skipping because we are green") + +def skip_green(f): + def skipping_test(self, *args, **kwargs): + if self.green: + raise SkipTest("Skipping because we are green") + else: + return f(self, *args, **kwargs) + return skipping_test diff --git a/src/console/zmq/tests/test_auth.py b/src/console/zmq/tests/test_auth.py new file mode 100755 index 00000000..d350f61f --- /dev/null +++ b/src/console/zmq/tests/test_auth.py @@ -0,0 +1,431 @@ +# -*- coding: utf8 -*- + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging +import os +import shutil +import sys +import tempfile + +import zmq.auth +from zmq.auth.ioloop import IOLoopAuthenticator +from zmq.auth.thread import ThreadAuthenticator + +from zmq.eventloop import ioloop, zmqstream +from zmq.tests import (BaseZMQTestCase, SkipTest) + +class BaseAuthTestCase(BaseZMQTestCase): + def setUp(self): + if zmq.zmq_version_info() < (4,0): + raise SkipTest("security is new in libzmq 4.0") + try: + zmq.curve_keypair() + except zmq.ZMQError: + raise SkipTest("security requires libzmq to be linked against libsodium") + super(BaseAuthTestCase, self).setUp() + # enable debug logging while we run tests + logging.getLogger('zmq.auth').setLevel(logging.DEBUG) + self.auth = self.make_auth() + self.auth.start() + self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs() + + def make_auth(self): + raise NotImplementedError() + + def tearDown(self): + if self.auth: + self.auth.stop() + self.auth = None + self.remove_certs(self.base_dir) + super(BaseAuthTestCase, self).tearDown() + + def create_certs(self): + """Create CURVE certificates for a test""" + + # Create temporary CURVE keypairs for this test run. We create all keys in a + # temp directory and then move them into the appropriate private or public + # directory. + + base_dir = tempfile.mkdtemp() + keys_dir = os.path.join(base_dir, 'certificates') + public_keys_dir = os.path.join(base_dir, 'public_keys') + secret_keys_dir = os.path.join(base_dir, 'private_keys') + + os.mkdir(keys_dir) + os.mkdir(public_keys_dir) + os.mkdir(secret_keys_dir) + + server_public_file, server_secret_file = zmq.auth.create_certificates(keys_dir, "server") + client_public_file, client_secret_file = zmq.auth.create_certificates(keys_dir, "client") + + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key"): + shutil.move(os.path.join(keys_dir, key_file), + os.path.join(public_keys_dir, '.')) + + for key_file in os.listdir(keys_dir): + if key_file.endswith(".key_secret"): + shutil.move(os.path.join(keys_dir, key_file), + os.path.join(secret_keys_dir, '.')) + + return (base_dir, public_keys_dir, secret_keys_dir) + + def remove_certs(self, base_dir): + """Remove certificates for a test""" + shutil.rmtree(base_dir) + + def load_certs(self, secret_keys_dir): + """Return server and client certificate keys""" + server_secret_file = os.path.join(secret_keys_dir, "server.key_secret") + client_secret_file = os.path.join(secret_keys_dir, "client.key_secret") + + server_public, server_secret = zmq.auth.load_certificate(server_secret_file) + client_public, client_secret = zmq.auth.load_certificate(client_secret_file) + + return server_public, server_secret, client_public, client_secret + + +class TestThreadAuthentication(BaseAuthTestCase): + """Test authentication running in a thread""" + + def make_auth(self): + return ThreadAuthenticator(self.context) + + def can_connect(self, server, client): + """Check if client can connect to server using tcp transport""" + result = False + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + msg = [b"Hello World"] + server.send_multipart(msg) + if client.poll(1000): + rcvd_msg = client.recv_multipart() + self.assertEqual(rcvd_msg, msg) + result = True + return result + + def test_null(self): + """threaded auth - NULL""" + # A default NULL connection should always succeed, and not + # go through our authentication infrastructure at all. + self.auth.stop() + self.auth = None + + server = self.socket(zmq.PUSH) + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + + # By setting a domain we switch on authentication for NULL sockets, + # though no policies are configured yet. The client connection + # should still be allowed. + server = self.socket(zmq.PUSH) + server.zap_domain = b'global' + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + + def test_blacklist(self): + """threaded auth - Blacklist""" + # Blacklist 127.0.0.1, connection should fail + self.auth.deny('127.0.0.1') + server = self.socket(zmq.PUSH) + # By setting a domain we switch on authentication for NULL sockets, + # though no policies are configured yet. + server.zap_domain = b'global' + client = self.socket(zmq.PULL) + self.assertFalse(self.can_connect(server, client)) + + def test_whitelist(self): + """threaded auth - Whitelist""" + # Whitelist 127.0.0.1, connection should pass" + self.auth.allow('127.0.0.1') + server = self.socket(zmq.PUSH) + # By setting a domain we switch on authentication for NULL sockets, + # though no policies are configured yet. + server.zap_domain = b'global' + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + + def test_plain(self): + """threaded auth - PLAIN""" + + # Try PLAIN authentication - without configuring server, connection should fail + server = self.socket(zmq.PUSH) + server.plain_server = True + client = self.socket(zmq.PULL) + client.plain_username = b'admin' + client.plain_password = b'Password' + self.assertFalse(self.can_connect(server, client)) + + # Try PLAIN authentication - with server configured, connection should pass + server = self.socket(zmq.PUSH) + server.plain_server = True + client = self.socket(zmq.PULL) + client.plain_username = b'admin' + client.plain_password = b'Password' + self.auth.configure_plain(domain='*', passwords={'admin': 'Password'}) + self.assertTrue(self.can_connect(server, client)) + + # Try PLAIN authentication - with bogus credentials, connection should fail + server = self.socket(zmq.PUSH) + server.plain_server = True + client = self.socket(zmq.PULL) + client.plain_username = b'admin' + client.plain_password = b'Bogus' + self.assertFalse(self.can_connect(server, client)) + + # Remove authenticator and check that a normal connection works + self.auth.stop() + self.auth = None + + server = self.socket(zmq.PUSH) + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + client.close() + server.close() + + def test_curve(self): + """threaded auth - CURVE""" + self.auth.allow('127.0.0.1') + certs = self.load_certs(self.secret_keys_dir) + server_public, server_secret, client_public, client_secret = certs + + #Try CURVE authentication - without configuring server, connection should fail + server = self.socket(zmq.PUSH) + server.curve_publickey = server_public + server.curve_secretkey = server_secret + server.curve_server = True + client = self.socket(zmq.PULL) + client.curve_publickey = client_public + client.curve_secretkey = client_secret + client.curve_serverkey = server_public + self.assertFalse(self.can_connect(server, client)) + + #Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass + self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) + server = self.socket(zmq.PUSH) + server.curve_publickey = server_public + server.curve_secretkey = server_secret + server.curve_server = True + client = self.socket(zmq.PULL) + client.curve_publickey = client_public + client.curve_secretkey = client_secret + client.curve_serverkey = server_public + self.assertTrue(self.can_connect(server, client)) + + # Try CURVE authentication - with server configured, connection should pass + self.auth.configure_curve(domain='*', location=self.public_keys_dir) + server = self.socket(zmq.PUSH) + server.curve_publickey = server_public + server.curve_secretkey = server_secret + server.curve_server = True + client = self.socket(zmq.PULL) + client.curve_publickey = client_public + client.curve_secretkey = client_secret + client.curve_serverkey = server_public + self.assertTrue(self.can_connect(server, client)) + + # Remove authenticator and check that a normal connection works + self.auth.stop() + self.auth = None + + # Try connecting using NULL and no authentication enabled, connection should pass + server = self.socket(zmq.PUSH) + client = self.socket(zmq.PULL) + self.assertTrue(self.can_connect(server, client)) + + +def with_ioloop(method, expect_success=True): + """decorator for running tests with an IOLoop""" + def test_method(self): + r = method(self) + + loop = self.io_loop + if expect_success: + self.pullstream.on_recv(self.on_message_succeed) + else: + self.pullstream.on_recv(self.on_message_fail) + + t = loop.time() + loop.add_callback(self.attempt_connection) + loop.add_callback(self.send_msg) + if expect_success: + loop.add_timeout(t + 1, self.on_test_timeout_fail) + else: + loop.add_timeout(t + 1, self.on_test_timeout_succeed) + + loop.start() + if self.fail_msg: + self.fail(self.fail_msg) + + return r + return test_method + +def should_auth(method): + return with_ioloop(method, True) + +def should_not_auth(method): + return with_ioloop(method, False) + +class TestIOLoopAuthentication(BaseAuthTestCase): + """Test authentication running in ioloop""" + + def setUp(self): + self.fail_msg = None + self.io_loop = ioloop.IOLoop() + super(TestIOLoopAuthentication, self).setUp() + self.server = self.socket(zmq.PUSH) + self.client = self.socket(zmq.PULL) + self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop) + self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop) + + def make_auth(self): + return IOLoopAuthenticator(self.context, io_loop=self.io_loop) + + def tearDown(self): + if self.auth: + self.auth.stop() + self.auth = None + self.io_loop.close(all_fds=True) + super(TestIOLoopAuthentication, self).tearDown() + + def attempt_connection(self): + """Check if client can connect to server using tcp transport""" + iface = 'tcp://127.0.0.1' + port = self.server.bind_to_random_port(iface) + self.client.connect("%s:%i" % (iface, port)) + + def send_msg(self): + """Send a message from server to a client""" + msg = [b"Hello World"] + self.pushstream.send_multipart(msg) + + def on_message_succeed(self, frames): + """A message was received, as expected.""" + if frames != [b"Hello World"]: + self.fail_msg = "Unexpected message received" + self.io_loop.stop() + + def on_message_fail(self, frames): + """A message was received, unexpectedly.""" + self.fail_msg = 'Received messaged unexpectedly, security failed' + self.io_loop.stop() + + def on_test_timeout_succeed(self): + """Test timer expired, indicates test success""" + self.io_loop.stop() + + def on_test_timeout_fail(self): + """Test timer expired, indicates test failure""" + self.fail_msg = 'Test timed out' + self.io_loop.stop() + + @should_auth + def test_none(self): + """ioloop auth - NONE""" + # A default NULL connection should always succeed, and not + # go through our authentication infrastructure at all. + # no auth should be running + self.auth.stop() + self.auth = None + + @should_auth + def test_null(self): + """ioloop auth - NULL""" + # By setting a domain we switch on authentication for NULL sockets, + # though no policies are configured yet. The client connection + # should still be allowed. + self.server.zap_domain = b'global' + + @should_not_auth + def test_blacklist(self): + """ioloop auth - Blacklist""" + # Blacklist 127.0.0.1, connection should fail + self.auth.deny('127.0.0.1') + self.server.zap_domain = b'global' + + @should_auth + def test_whitelist(self): + """ioloop auth - Whitelist""" + # Whitelist 127.0.0.1, which overrides the blacklist, connection should pass" + self.auth.allow('127.0.0.1') + + self.server.setsockopt(zmq.ZAP_DOMAIN, b'global') + + @should_not_auth + def test_plain_unconfigured_server(self): + """ioloop auth - PLAIN, unconfigured server""" + self.client.plain_username = b'admin' + self.client.plain_password = b'Password' + # Try PLAIN authentication - without configuring server, connection should fail + self.server.plain_server = True + + @should_auth + def test_plain_configured_server(self): + """ioloop auth - PLAIN, configured server""" + self.client.plain_username = b'admin' + self.client.plain_password = b'Password' + # Try PLAIN authentication - with server configured, connection should pass + self.server.plain_server = True + self.auth.configure_plain(domain='*', passwords={'admin': 'Password'}) + + @should_not_auth + def test_plain_bogus_credentials(self): + """ioloop auth - PLAIN, bogus credentials""" + self.client.plain_username = b'admin' + self.client.plain_password = b'Bogus' + self.server.plain_server = True + + self.auth.configure_plain(domain='*', passwords={'admin': 'Password'}) + + @should_not_auth + def test_curve_unconfigured_server(self): + """ioloop auth - CURVE, unconfigured server""" + certs = self.load_certs(self.secret_keys_dir) + server_public, server_secret, client_public, client_secret = certs + + self.auth.allow('127.0.0.1') + + self.server.curve_publickey = server_public + self.server.curve_secretkey = server_secret + self.server.curve_server = True + + self.client.curve_publickey = client_public + self.client.curve_secretkey = client_secret + self.client.curve_serverkey = server_public + + @should_auth + def test_curve_allow_any(self): + """ioloop auth - CURVE, CURVE_ALLOW_ANY""" + certs = self.load_certs(self.secret_keys_dir) + server_public, server_secret, client_public, client_secret = certs + + self.auth.allow('127.0.0.1') + self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) + + self.server.curve_publickey = server_public + self.server.curve_secretkey = server_secret + self.server.curve_server = True + + self.client.curve_publickey = client_public + self.client.curve_secretkey = client_secret + self.client.curve_serverkey = server_public + + @should_auth + def test_curve_configured_server(self): + """ioloop auth - CURVE, configured server""" + self.auth.allow('127.0.0.1') + certs = self.load_certs(self.secret_keys_dir) + server_public, server_secret, client_public, client_secret = certs + + self.auth.configure_curve(domain='*', location=self.public_keys_dir) + + self.server.curve_publickey = server_public + self.server.curve_secretkey = server_secret + self.server.curve_server = True + + self.client.curve_publickey = client_public + self.client.curve_secretkey = client_secret + self.client.curve_serverkey = server_public diff --git a/src/console/zmq/tests/test_cffi_backend.py b/src/console/zmq/tests/test_cffi_backend.py new file mode 100755 index 00000000..1f85eebf --- /dev/null +++ b/src/console/zmq/tests/test_cffi_backend.py @@ -0,0 +1,310 @@ +# -*- coding: utf8 -*- + +import sys +import time + +from unittest import TestCase + +from zmq.tests import BaseZMQTestCase, SkipTest + +try: + from zmq.backend.cffi import ( + zmq_version_info, + PUSH, PULL, IDENTITY, + REQ, REP, POLLIN, POLLOUT, + ) + from zmq.backend.cffi._cffi import ffi, C + have_ffi_backend = True +except ImportError: + have_ffi_backend = False + + +class TestCFFIBackend(TestCase): + + def setUp(self): + if not have_ffi_backend or not 'PyPy' in sys.version: + raise SkipTest('PyPy Tests Only') + + def test_zmq_version_info(self): + version = zmq_version_info() + + assert version[0] in range(2,11) + + def test_zmq_ctx_new_destroy(self): + ctx = C.zmq_ctx_new() + + assert ctx != ffi.NULL + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_socket_open_close(self): + ctx = C.zmq_ctx_new() + socket = C.zmq_socket(ctx, PUSH) + + assert ctx != ffi.NULL + assert ffi.NULL != socket + assert 0 == C.zmq_close(socket) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_setsockopt(self): + ctx = C.zmq_ctx_new() + socket = C.zmq_socket(ctx, PUSH) + + identity = ffi.new('char[3]', 'zmq') + ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3) + + assert ret == 0 + assert ctx != ffi.NULL + assert ffi.NULL != socket + assert 0 == C.zmq_close(socket) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_getsockopt(self): + ctx = C.zmq_ctx_new() + socket = C.zmq_socket(ctx, PUSH) + + identity = ffi.new('char[]', 'zmq') + ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3) + assert ret == 0 + + option_len = ffi.new('size_t*', 3) + option = ffi.new('char*') + ret = C.zmq_getsockopt(socket, + IDENTITY, + ffi.cast('void*', option), + option_len) + + assert ret == 0 + assert ffi.string(ffi.cast('char*', option))[0] == "z" + assert ffi.string(ffi.cast('char*', option))[1] == "m" + assert ffi.string(ffi.cast('char*', option))[2] == "q" + assert ctx != ffi.NULL + assert ffi.NULL != socket + assert 0 == C.zmq_close(socket) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_bind(self): + ctx = C.zmq_ctx_new() + socket = C.zmq_socket(ctx, 8) + + assert 0 == C.zmq_bind(socket, 'tcp://*:4444') + assert ctx != ffi.NULL + assert ffi.NULL != socket + assert 0 == C.zmq_close(socket) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_bind_connect(self): + ctx = C.zmq_ctx_new() + + socket1 = C.zmq_socket(ctx, PUSH) + socket2 = C.zmq_socket(ctx, PULL) + + assert 0 == C.zmq_bind(socket1, 'tcp://*:4444') + assert 0 == C.zmq_connect(socket2, 'tcp://127.0.0.1:4444') + assert ctx != ffi.NULL + assert ffi.NULL != socket1 + assert ffi.NULL != socket2 + assert 0 == C.zmq_close(socket1) + assert 0 == C.zmq_close(socket2) + assert 0 == C.zmq_ctx_destroy(ctx) + + def test_zmq_msg_init_close(self): + zmq_msg = ffi.new('zmq_msg_t*') + + assert ffi.NULL != zmq_msg + assert 0 == C.zmq_msg_init(zmq_msg) + assert 0 == C.zmq_msg_close(zmq_msg) + + def test_zmq_msg_init_size(self): + zmq_msg = ffi.new('zmq_msg_t*') + + assert ffi.NULL != zmq_msg + assert 0 == C.zmq_msg_init_size(zmq_msg, 10) + assert 0 == C.zmq_msg_close(zmq_msg) + + def test_zmq_msg_init_data(self): + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[5]', 'Hello') + + assert 0 == C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + 5, + ffi.NULL, + ffi.NULL) + + assert ffi.NULL != zmq_msg + assert 0 == C.zmq_msg_close(zmq_msg) + + def test_zmq_msg_data(self): + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[]', 'Hello') + assert 0 == C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + 5, + ffi.NULL, + ffi.NULL) + + data = C.zmq_msg_data(zmq_msg) + + assert ffi.NULL != zmq_msg + assert ffi.string(ffi.cast("char*", data)) == 'Hello' + assert 0 == C.zmq_msg_close(zmq_msg) + + + def test_zmq_send(self): + ctx = C.zmq_ctx_new() + + sender = C.zmq_socket(ctx, REQ) + receiver = C.zmq_socket(ctx, REP) + + assert 0 == C.zmq_bind(receiver, 'tcp://*:7777') + assert 0 == C.zmq_connect(sender, 'tcp://127.0.0.1:7777') + + time.sleep(0.1) + + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[5]', 'Hello') + + C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + ffi.cast('size_t', 5), + ffi.NULL, + ffi.NULL) + + assert 5 == C.zmq_msg_send(zmq_msg, sender, 0) + assert 0 == C.zmq_msg_close(zmq_msg) + assert C.zmq_close(sender) == 0 + assert C.zmq_close(receiver) == 0 + assert C.zmq_ctx_destroy(ctx) == 0 + + def test_zmq_recv(self): + ctx = C.zmq_ctx_new() + + sender = C.zmq_socket(ctx, REQ) + receiver = C.zmq_socket(ctx, REP) + + assert 0 == C.zmq_bind(receiver, 'tcp://*:2222') + assert 0 == C.zmq_connect(sender, 'tcp://127.0.0.1:2222') + + time.sleep(0.1) + + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[5]', 'Hello') + + C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + ffi.cast('size_t', 5), + ffi.NULL, + ffi.NULL) + + zmq_msg2 = ffi.new('zmq_msg_t*') + C.zmq_msg_init(zmq_msg2) + + assert 5 == C.zmq_msg_send(zmq_msg, sender, 0) + assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0) + assert 5 == C.zmq_msg_size(zmq_msg2) + assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2), + C.zmq_msg_size(zmq_msg2))[:] + assert C.zmq_close(sender) == 0 + assert C.zmq_close(receiver) == 0 + assert C.zmq_ctx_destroy(ctx) == 0 + + def test_zmq_poll(self): + ctx = C.zmq_ctx_new() + + sender = C.zmq_socket(ctx, REQ) + receiver = C.zmq_socket(ctx, REP) + + r1 = C.zmq_bind(receiver, 'tcp://*:3333') + r2 = C.zmq_connect(sender, 'tcp://127.0.0.1:3333') + + zmq_msg = ffi.new('zmq_msg_t*') + message = ffi.new('char[5]', 'Hello') + + C.zmq_msg_init_data(zmq_msg, + ffi.cast('void*', message), + ffi.cast('size_t', 5), + ffi.NULL, + ffi.NULL) + + receiver_pollitem = ffi.new('zmq_pollitem_t*') + receiver_pollitem.socket = receiver + receiver_pollitem.fd = 0 + receiver_pollitem.events = POLLIN | POLLOUT + receiver_pollitem.revents = 0 + + ret = C.zmq_poll(ffi.NULL, 0, 0) + assert ret == 0 + + ret = C.zmq_poll(receiver_pollitem, 1, 0) + assert ret == 0 + + ret = C.zmq_msg_send(zmq_msg, sender, 0) + print(ffi.string(C.zmq_strerror(C.zmq_errno()))) + assert ret == 5 + + time.sleep(0.2) + + ret = C.zmq_poll(receiver_pollitem, 1, 0) + assert ret == 1 + + assert int(receiver_pollitem.revents) & POLLIN + assert not int(receiver_pollitem.revents) & POLLOUT + + zmq_msg2 = ffi.new('zmq_msg_t*') + C.zmq_msg_init(zmq_msg2) + + ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0) + assert ret_recv == 5 + + assert 5 == C.zmq_msg_size(zmq_msg2) + assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2), + C.zmq_msg_size(zmq_msg2))[:] + + sender_pollitem = ffi.new('zmq_pollitem_t*') + sender_pollitem.socket = sender + sender_pollitem.fd = 0 + sender_pollitem.events = POLLIN | POLLOUT + sender_pollitem.revents = 0 + + ret = C.zmq_poll(sender_pollitem, 1, 0) + assert ret == 0 + + zmq_msg_again = ffi.new('zmq_msg_t*') + message_again = ffi.new('char[11]', 'Hello Again') + + C.zmq_msg_init_data(zmq_msg_again, + ffi.cast('void*', message_again), + ffi.cast('size_t', 11), + ffi.NULL, + ffi.NULL) + + assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0) + + time.sleep(0.2) + + assert 0 <= C.zmq_poll(sender_pollitem, 1, 0) + assert int(sender_pollitem.revents) & POLLIN + assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0) + assert 11 == C.zmq_msg_size(zmq_msg2) + assert b"Hello Again" == ffi.buffer(C.zmq_msg_data(zmq_msg2), + int(C.zmq_msg_size(zmq_msg2)))[:] + assert 0 == C.zmq_close(sender) + assert 0 == C.zmq_close(receiver) + assert 0 == C.zmq_ctx_destroy(ctx) + assert 0 == C.zmq_msg_close(zmq_msg) + assert 0 == C.zmq_msg_close(zmq_msg2) + assert 0 == C.zmq_msg_close(zmq_msg_again) + + def test_zmq_stopwatch_functions(self): + stopwatch = C.zmq_stopwatch_start() + ret = C.zmq_stopwatch_stop(stopwatch) + + assert ffi.NULL != stopwatch + assert 0 < int(ret) + + def test_zmq_sleep(self): + try: + C.zmq_sleep(1) + except Exception as e: + raise AssertionError("Error executing zmq_sleep(int)") + diff --git a/src/console/zmq/tests/test_constants.py b/src/console/zmq/tests/test_constants.py new file mode 100755 index 00000000..d32b2b48 --- /dev/null +++ b/src/console/zmq/tests/test_constants.py @@ -0,0 +1,104 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import json +from unittest import TestCase + +import zmq + +from zmq.utils import constant_names +from zmq.sugar import constants as sugar_constants +from zmq.backend import constants as backend_constants + +all_set = set(constant_names.all_names) + +class TestConstants(TestCase): + + def _duplicate_test(self, namelist, listname): + """test that a given list has no duplicates""" + dupes = {} + for name in set(namelist): + cnt = namelist.count(name) + if cnt > 1: + dupes[name] = cnt + if dupes: + self.fail("The following names occur more than once in %s: %s" % (listname, json.dumps(dupes, indent=2))) + + def test_duplicate_all(self): + return self._duplicate_test(constant_names.all_names, "all_names") + + def _change_key(self, change, version): + """return changed-in key""" + return "%s-in %d.%d.%d" % tuple([change] + list(version)) + + def test_duplicate_changed(self): + all_changed = [] + for change in ("new", "removed"): + d = getattr(constant_names, change + "_in") + for version, namelist in d.items(): + all_changed.extend(namelist) + self._duplicate_test(namelist, self._change_key(change, version)) + + self._duplicate_test(all_changed, "all-changed") + + def test_changed_in_all(self): + missing = {} + for change in ("new", "removed"): + d = getattr(constant_names, change + "_in") + for version, namelist in d.items(): + key = self._change_key(change, version) + for name in namelist: + if name not in all_set: + if key not in missing: + missing[key] = [] + missing[key].append(name) + + if missing: + self.fail( + "The following names are missing in `all_names`: %s" % json.dumps(missing, indent=2) + ) + + def test_no_negative_constants(self): + for name in sugar_constants.__all__: + self.assertNotEqual(getattr(zmq, name), sugar_constants._UNDEFINED) + + def test_undefined_constants(self): + all_aliases = [] + for alias_group in sugar_constants.aliases: + all_aliases.extend(alias_group) + + for name in all_set.difference(all_aliases): + raw = getattr(backend_constants, name) + if raw == sugar_constants._UNDEFINED: + self.assertRaises(AttributeError, getattr, zmq, name) + else: + self.assertEqual(getattr(zmq, name), raw) + + def test_new(self): + zmq_version = zmq.zmq_version_info() + for version, new_names in constant_names.new_in.items(): + should_have = zmq_version >= version + for name in new_names: + try: + value = getattr(zmq, name) + except AttributeError: + if should_have: + self.fail("AttributeError: zmq.%s" % name) + else: + if not should_have: + self.fail("Shouldn't have: zmq.%s=%s" % (name, value)) + + def test_removed(self): + zmq_version = zmq.zmq_version_info() + for version, new_names in constant_names.removed_in.items(): + should_have = zmq_version < version + for name in new_names: + try: + value = getattr(zmq, name) + except AttributeError: + if should_have: + self.fail("AttributeError: zmq.%s" % name) + else: + if not should_have: + self.fail("Shouldn't have: zmq.%s=%s" % (name, value)) + diff --git a/src/console/zmq/tests/test_context.py b/src/console/zmq/tests/test_context.py new file mode 100755 index 00000000..e3280778 --- /dev/null +++ b/src/console/zmq/tests/test_context.py @@ -0,0 +1,257 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import gc +import sys +import time +from threading import Thread, Event + +import zmq +from zmq.tests import ( + BaseZMQTestCase, have_gevent, GreenTest, skip_green, PYPY, SkipTest, +) + + +class TestContext(BaseZMQTestCase): + + def test_init(self): + c1 = self.Context() + self.assert_(isinstance(c1, self.Context)) + del c1 + c2 = self.Context() + self.assert_(isinstance(c2, self.Context)) + del c2 + c3 = self.Context() + self.assert_(isinstance(c3, self.Context)) + del c3 + + def test_dir(self): + ctx = self.Context() + self.assertTrue('socket' in dir(ctx)) + if zmq.zmq_version_info() > (3,): + self.assertTrue('IO_THREADS' in dir(ctx)) + ctx.term() + + def test_term(self): + c = self.Context() + c.term() + self.assert_(c.closed) + + def test_context_manager(self): + with self.Context() as c: + pass + self.assert_(c.closed) + + def test_fail_init(self): + self.assertRaisesErrno(zmq.EINVAL, self.Context, -1) + + def test_term_hang(self): + rep,req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER) + req.setsockopt(zmq.LINGER, 0) + req.send(b'hello', copy=False) + req.close() + rep.close() + self.context.term() + + def test_instance(self): + ctx = self.Context.instance() + c2 = self.Context.instance(io_threads=2) + self.assertTrue(c2 is ctx) + c2.term() + c3 = self.Context.instance() + c4 = self.Context.instance() + self.assertFalse(c3 is c2) + self.assertFalse(c3.closed) + self.assertTrue(c3 is c4) + + def test_many_sockets(self): + """opening and closing many sockets shouldn't cause problems""" + ctx = self.Context() + for i in range(16): + sockets = [ ctx.socket(zmq.REP) for i in range(65) ] + [ s.close() for s in sockets ] + # give the reaper a chance + time.sleep(1e-2) + ctx.term() + + def test_sockopts(self): + """setting socket options with ctx attributes""" + ctx = self.Context() + ctx.linger = 5 + self.assertEqual(ctx.linger, 5) + s = ctx.socket(zmq.REQ) + self.assertEqual(s.linger, 5) + self.assertEqual(s.getsockopt(zmq.LINGER), 5) + s.close() + # check that subscribe doesn't get set on sockets that don't subscribe: + ctx.subscribe = b'' + s = ctx.socket(zmq.REQ) + s.close() + + ctx.term() + + + def test_destroy(self): + """Context.destroy should close sockets""" + ctx = self.Context() + sockets = [ ctx.socket(zmq.REP) for i in range(65) ] + + # close half of the sockets + [ s.close() for s in sockets[::2] ] + + ctx.destroy() + # reaper is not instantaneous + time.sleep(1e-2) + for s in sockets: + self.assertTrue(s.closed) + + def test_destroy_linger(self): + """Context.destroy should set linger on closing sockets""" + req,rep = self.create_bound_pair(zmq.REQ, zmq.REP) + req.send(b'hi') + time.sleep(1e-2) + self.context.destroy(linger=0) + # reaper is not instantaneous + time.sleep(1e-2) + for s in (req,rep): + self.assertTrue(s.closed) + + def test_term_noclose(self): + """Context.term won't close sockets""" + ctx = self.Context() + s = ctx.socket(zmq.REQ) + self.assertFalse(s.closed) + t = Thread(target=ctx.term) + t.start() + t.join(timeout=0.1) + self.assertTrue(t.is_alive(), "Context should be waiting") + s.close() + t.join(timeout=0.1) + self.assertFalse(t.is_alive(), "Context should have closed") + + def test_gc(self): + """test close&term by garbage collection alone""" + if PYPY: + raise SkipTest("GC doesn't work ") + + # test credit @dln (GH #137): + def gcf(): + def inner(): + ctx = self.Context() + s = ctx.socket(zmq.PUSH) + inner() + gc.collect() + t = Thread(target=gcf) + t.start() + t.join(timeout=1) + self.assertFalse(t.is_alive(), "Garbage collection should have cleaned up context") + + def test_cyclic_destroy(self): + """ctx.destroy should succeed when cyclic ref prevents gc""" + # test credit @dln (GH #137): + class CyclicReference(object): + def __init__(self, parent=None): + self.parent = parent + + def crash(self, sock): + self.sock = sock + self.child = CyclicReference(self) + + def crash_zmq(): + ctx = self.Context() + sock = ctx.socket(zmq.PULL) + c = CyclicReference() + c.crash(sock) + ctx.destroy() + + crash_zmq() + + def test_term_thread(self): + """ctx.term should not crash active threads (#139)""" + ctx = self.Context() + evt = Event() + evt.clear() + + def block(): + s = ctx.socket(zmq.REP) + s.bind_to_random_port('tcp://127.0.0.1') + evt.set() + try: + s.recv() + except zmq.ZMQError as e: + self.assertEqual(e.errno, zmq.ETERM) + return + finally: + s.close() + self.fail("recv should have been interrupted with ETERM") + t = Thread(target=block) + t.start() + + evt.wait(1) + self.assertTrue(evt.is_set(), "sync event never fired") + time.sleep(0.01) + ctx.term() + t.join(timeout=1) + self.assertFalse(t.is_alive(), "term should have interrupted s.recv()") + + def test_destroy_no_sockets(self): + ctx = self.Context() + s = ctx.socket(zmq.PUB) + s.bind_to_random_port('tcp://127.0.0.1') + s.close() + ctx.destroy() + assert s.closed + assert ctx.closed + + def test_ctx_opts(self): + if zmq.zmq_version_info() < (3,): + raise SkipTest("context options require libzmq 3") + ctx = self.Context() + ctx.set(zmq.MAX_SOCKETS, 2) + self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2) + ctx.max_sockets = 100 + self.assertEqual(ctx.max_sockets, 100) + self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100) + + def test_shadow(self): + ctx = self.Context() + ctx2 = self.Context.shadow(ctx.underlying) + self.assertEqual(ctx.underlying, ctx2.underlying) + s = ctx.socket(zmq.PUB) + s.close() + del ctx2 + self.assertFalse(ctx.closed) + s = ctx.socket(zmq.PUB) + ctx2 = self.Context.shadow(ctx.underlying) + s2 = ctx2.socket(zmq.PUB) + s.close() + s2.close() + ctx.term() + self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB) + del ctx2 + + def test_shadow_pyczmq(self): + try: + from pyczmq import zctx, zsocket, zstr + except Exception: + raise SkipTest("Requires pyczmq") + + ctx = zctx.new() + a = zsocket.new(ctx, zmq.PUSH) + zsocket.bind(a, "inproc://a") + ctx2 = self.Context.shadow_pyczmq(ctx) + b = ctx2.socket(zmq.PULL) + b.connect("inproc://a") + zstr.send(a, b'hi') + rcvd = self.recv(b) + self.assertEqual(rcvd, b'hi') + b.close() + + +if False: # disable green context tests + class TestContextGreen(GreenTest, TestContext): + """gevent subclass of context tests""" + # skip tests that use real threads: + test_gc = GreenTest.skip_green + test_term_thread = GreenTest.skip_green + test_destroy_linger = GreenTest.skip_green diff --git a/src/console/zmq/tests/test_device.py b/src/console/zmq/tests/test_device.py new file mode 100755 index 00000000..f8305074 --- /dev/null +++ b/src/console/zmq/tests/test_device.py @@ -0,0 +1,146 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time + +import zmq +from zmq import devices +from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest, PYPY +from zmq.utils.strtypes import (bytes,unicode,basestring) + +if PYPY: + # cleanup of shared Context doesn't work on PyPy + devices.Device.context_factory = zmq.Context + +class TestDevice(BaseZMQTestCase): + + def test_device_types(self): + for devtype in (zmq.STREAMER, zmq.FORWARDER, zmq.QUEUE): + dev = devices.Device(devtype, zmq.PAIR, zmq.PAIR) + self.assertEqual(dev.device_type, devtype) + del dev + + def test_device_attributes(self): + dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB) + self.assertEqual(dev.in_type, zmq.SUB) + self.assertEqual(dev.out_type, zmq.PUB) + self.assertEqual(dev.device_type, zmq.QUEUE) + self.assertEqual(dev.daemon, True) + del dev + + def test_tsdevice_attributes(self): + dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB) + self.assertEqual(dev.in_type, zmq.SUB) + self.assertEqual(dev.out_type, zmq.PUB) + self.assertEqual(dev.device_type, zmq.QUEUE) + self.assertEqual(dev.daemon, True) + del dev + + + def test_single_socket_forwarder_connect(self): + dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) + req = self.context.socket(zmq.REQ) + port = req.bind_to_random_port('tcp://127.0.0.1') + dev.connect_in('tcp://127.0.0.1:%i'%port) + dev.start() + time.sleep(.25) + msg = b'hello' + req.send(msg) + self.assertEqual(msg, self.recv(req)) + del dev + req.close() + dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) + req = self.context.socket(zmq.REQ) + port = req.bind_to_random_port('tcp://127.0.0.1') + dev.connect_out('tcp://127.0.0.1:%i'%port) + dev.start() + time.sleep(.25) + msg = b'hello again' + req.send(msg) + self.assertEqual(msg, self.recv(req)) + del dev + req.close() + + def test_single_socket_forwarder_bind(self): + dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) + # select random port: + binder = self.context.socket(zmq.REQ) + port = binder.bind_to_random_port('tcp://127.0.0.1') + binder.close() + time.sleep(0.1) + req = self.context.socket(zmq.REQ) + req.connect('tcp://127.0.0.1:%i'%port) + dev.bind_in('tcp://127.0.0.1:%i'%port) + dev.start() + time.sleep(.25) + msg = b'hello' + req.send(msg) + self.assertEqual(msg, self.recv(req)) + del dev + req.close() + dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) + # select random port: + binder = self.context.socket(zmq.REQ) + port = binder.bind_to_random_port('tcp://127.0.0.1') + binder.close() + time.sleep(0.1) + req = self.context.socket(zmq.REQ) + req.connect('tcp://127.0.0.1:%i'%port) + dev.bind_in('tcp://127.0.0.1:%i'%port) + dev.start() + time.sleep(.25) + msg = b'hello again' + req.send(msg) + self.assertEqual(msg, self.recv(req)) + del dev + req.close() + + def test_proxy(self): + if zmq.zmq_version_info() < (3,2): + raise SkipTest("Proxies only in libzmq >= 3") + dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH) + binder = self.context.socket(zmq.REQ) + iface = 'tcp://127.0.0.1' + port = binder.bind_to_random_port(iface) + port2 = binder.bind_to_random_port(iface) + port3 = binder.bind_to_random_port(iface) + binder.close() + time.sleep(0.1) + dev.bind_in("%s:%i" % (iface, port)) + dev.bind_out("%s:%i" % (iface, port2)) + dev.bind_mon("%s:%i" % (iface, port3)) + dev.start() + time.sleep(0.25) + msg = b'hello' + push = self.context.socket(zmq.PUSH) + push.connect("%s:%i" % (iface, port)) + pull = self.context.socket(zmq.PULL) + pull.connect("%s:%i" % (iface, port2)) + mon = self.context.socket(zmq.PULL) + mon.connect("%s:%i" % (iface, port3)) + push.send(msg) + self.sockets.extend([push, pull, mon]) + self.assertEqual(msg, self.recv(pull)) + self.assertEqual(msg, self.recv(mon)) + +if have_gevent: + import gevent + import zmq.green + + class TestDeviceGreen(GreenTest, BaseZMQTestCase): + + def test_green_device(self): + rep = self.context.socket(zmq.REP) + req = self.context.socket(zmq.REQ) + self.sockets.extend([req, rep]) + port = rep.bind_to_random_port('tcp://127.0.0.1') + g = gevent.spawn(zmq.green.device, zmq.QUEUE, rep, rep) + req.connect('tcp://127.0.0.1:%i' % port) + req.send(b'hi') + timeout = gevent.Timeout(3) + timeout.start() + receiver = gevent.spawn(req.recv) + self.assertEqual(receiver.get(2), b'hi') + timeout.cancel() + g.kill(block=True) + diff --git a/src/console/zmq/tests/test_error.py b/src/console/zmq/tests/test_error.py new file mode 100755 index 00000000..a2eee14a --- /dev/null +++ b/src/console/zmq/tests/test_error.py @@ -0,0 +1,43 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys +import time + +import zmq +from zmq import ZMQError, strerror, Again, ContextTerminated +from zmq.tests import BaseZMQTestCase + +if sys.version_info[0] >= 3: + long = int + +class TestZMQError(BaseZMQTestCase): + + def test_strerror(self): + """test that strerror gets the right type.""" + for i in range(10): + e = strerror(i) + self.assertTrue(isinstance(e, str)) + + def test_zmqerror(self): + for errno in range(10): + e = ZMQError(errno) + self.assertEqual(e.errno, errno) + self.assertEqual(str(e), strerror(errno)) + + def test_again(self): + s = self.context.socket(zmq.REP) + self.assertRaises(Again, s.recv, zmq.NOBLOCK) + self.assertRaisesErrno(zmq.EAGAIN, s.recv, zmq.NOBLOCK) + s.close() + + def atest_ctxterm(self): + s = self.context.socket(zmq.REP) + t = Thread(target=self.context.term) + t.start() + self.assertRaises(ContextTerminated, s.recv, zmq.NOBLOCK) + self.assertRaisesErrno(zmq.TERM, s.recv, zmq.NOBLOCK) + s.close() + t.join() + diff --git a/src/console/zmq/tests/test_etc.py b/src/console/zmq/tests/test_etc.py new file mode 100755 index 00000000..ad224064 --- /dev/null +++ b/src/console/zmq/tests/test_etc.py @@ -0,0 +1,15 @@ +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +import sys + +import zmq + +from . import skip_if + +@skip_if(zmq.zmq_version_info() < (4,1), "libzmq < 4.1") +def test_has(): + assert not zmq.has('something weird') + has_ipc = zmq.has('ipc') + not_windows = not sys.platform.startswith('win') + assert has_ipc == not_windows diff --git a/src/console/zmq/tests/test_imports.py b/src/console/zmq/tests/test_imports.py new file mode 100755 index 00000000..c0ddfaac --- /dev/null +++ b/src/console/zmq/tests/test_imports.py @@ -0,0 +1,62 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys +from unittest import TestCase + +class TestImports(TestCase): + """Test Imports - the quickest test to ensure that we haven't + introduced version-incompatible syntax errors.""" + + def test_toplevel(self): + """test toplevel import""" + import zmq + + def test_core(self): + """test core imports""" + from zmq import Context + from zmq import Socket + from zmq import Poller + from zmq import Frame + from zmq import constants + from zmq import device, proxy + from zmq import Stopwatch + from zmq import ( + zmq_version, + zmq_version_info, + pyzmq_version, + pyzmq_version_info, + ) + + def test_devices(self): + """test device imports""" + import zmq.devices + from zmq.devices import basedevice + from zmq.devices import monitoredqueue + from zmq.devices import monitoredqueuedevice + + def test_log(self): + """test log imports""" + import zmq.log + from zmq.log import handlers + + def test_eventloop(self): + """test eventloop imports""" + import zmq.eventloop + from zmq.eventloop import ioloop + from zmq.eventloop import zmqstream + from zmq.eventloop.minitornado.platform import auto + from zmq.eventloop.minitornado import ioloop + + def test_utils(self): + """test util imports""" + import zmq.utils + from zmq.utils import strtypes + from zmq.utils import jsonapi + + def test_ssh(self): + """test ssh imports""" + from zmq.ssh import tunnel + + + diff --git a/src/console/zmq/tests/test_ioloop.py b/src/console/zmq/tests/test_ioloop.py new file mode 100755 index 00000000..2a8b1153 --- /dev/null +++ b/src/console/zmq/tests/test_ioloop.py @@ -0,0 +1,113 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +import os +import threading + +import zmq +from zmq.tests import BaseZMQTestCase +from zmq.eventloop import ioloop +from zmq.eventloop.minitornado.ioloop import _Timeout +try: + from tornado.ioloop import PollIOLoop, IOLoop as BaseIOLoop +except ImportError: + from zmq.eventloop.minitornado.ioloop import IOLoop as BaseIOLoop + + +def printer(): + os.system("say hello") + raise Exception + print (time.time()) + + +class Delay(threading.Thread): + def __init__(self, f, delay=1): + self.f=f + self.delay=delay + self.aborted=False + self.cond=threading.Condition() + super(Delay, self).__init__() + + def run(self): + self.cond.acquire() + self.cond.wait(self.delay) + self.cond.release() + if not self.aborted: + self.f() + + def abort(self): + self.aborted=True + self.cond.acquire() + self.cond.notify() + self.cond.release() + + +class TestIOLoop(BaseZMQTestCase): + + def test_simple(self): + """simple IOLoop creation test""" + loop = ioloop.IOLoop() + dc = ioloop.PeriodicCallback(loop.stop, 200, loop) + pc = ioloop.PeriodicCallback(lambda : None, 10, loop) + pc.start() + dc.start() + t = Delay(loop.stop,1) + t.start() + loop.start() + if t.isAlive(): + t.abort() + else: + self.fail("IOLoop failed to exit") + + def test_timeout_compare(self): + """test timeout comparisons""" + loop = ioloop.IOLoop() + t = _Timeout(1, 2, loop) + t2 = _Timeout(1, 3, loop) + self.assertEqual(t < t2, id(t) < id(t2)) + t2 = _Timeout(2,1, loop) + self.assertTrue(t < t2) + + def test_poller_events(self): + """Tornado poller implementation maps events correctly""" + req,rep = self.create_bound_pair(zmq.REQ, zmq.REP) + poller = ioloop.ZMQPoller() + poller.register(req, ioloop.IOLoop.READ) + poller.register(rep, ioloop.IOLoop.READ) + events = dict(poller.poll(0)) + self.assertEqual(events.get(rep), None) + self.assertEqual(events.get(req), None) + + poller.register(req, ioloop.IOLoop.WRITE) + poller.register(rep, ioloop.IOLoop.WRITE) + events = dict(poller.poll(1)) + self.assertEqual(events.get(req), ioloop.IOLoop.WRITE) + self.assertEqual(events.get(rep), None) + + poller.register(rep, ioloop.IOLoop.READ) + req.send(b'hi') + events = dict(poller.poll(1)) + self.assertEqual(events.get(rep), ioloop.IOLoop.READ) + self.assertEqual(events.get(req), None) + + def test_instance(self): + """Test IOLoop.instance returns the right object""" + loop = ioloop.IOLoop.instance() + self.assertEqual(loop.__class__, ioloop.IOLoop) + loop = BaseIOLoop.instance() + self.assertEqual(loop.__class__, ioloop.IOLoop) + + def test_close_all(self): + """Test close(all_fds=True)""" + loop = ioloop.IOLoop.instance() + req,rep = self.create_bound_pair(zmq.REQ, zmq.REP) + loop.add_handler(req, lambda msg: msg, ioloop.IOLoop.READ) + loop.add_handler(rep, lambda msg: msg, ioloop.IOLoop.READ) + self.assertEqual(req.closed, False) + self.assertEqual(rep.closed, False) + loop.close(all_fds=True) + self.assertEqual(req.closed, True) + self.assertEqual(rep.closed, True) + + diff --git a/src/console/zmq/tests/test_log.py b/src/console/zmq/tests/test_log.py new file mode 100755 index 00000000..9206f095 --- /dev/null +++ b/src/console/zmq/tests/test_log.py @@ -0,0 +1,116 @@ +# encoding: utf-8 + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import logging +import time +from unittest import TestCase + +import zmq +from zmq.log import handlers +from zmq.utils.strtypes import b, u +from zmq.tests import BaseZMQTestCase + + +class TestPubLog(BaseZMQTestCase): + + iface = 'inproc://zmqlog' + topic= 'zmq' + + @property + def logger(self): + # print dir(self) + logger = logging.getLogger('zmqtest') + logger.setLevel(logging.DEBUG) + return logger + + def connect_handler(self, topic=None): + topic = self.topic if topic is None else topic + logger = self.logger + pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB) + handler = handlers.PUBHandler(pub) + handler.setLevel(logging.DEBUG) + handler.root_topic = topic + logger.addHandler(handler) + sub.setsockopt(zmq.SUBSCRIBE, b(topic)) + time.sleep(0.1) + return logger, handler, sub + + def test_init_iface(self): + logger = self.logger + ctx = self.context + handler = handlers.PUBHandler(self.iface) + self.assertFalse(handler.ctx is ctx) + self.sockets.append(handler.socket) + # handler.ctx.term() + handler = handlers.PUBHandler(self.iface, self.context) + self.sockets.append(handler.socket) + self.assertTrue(handler.ctx is ctx) + handler.setLevel(logging.DEBUG) + handler.root_topic = self.topic + logger.addHandler(handler) + sub = ctx.socket(zmq.SUB) + self.sockets.append(sub) + sub.setsockopt(zmq.SUBSCRIBE, b(self.topic)) + sub.connect(self.iface) + import time; time.sleep(0.25) + msg1 = 'message' + logger.info(msg1) + + (topic, msg2) = sub.recv_multipart() + self.assertEqual(topic, b'zmq.INFO') + self.assertEqual(msg2, b(msg1)+b'\n') + logger.removeHandler(handler) + + def test_init_socket(self): + pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB) + logger = self.logger + handler = handlers.PUBHandler(pub) + handler.setLevel(logging.DEBUG) + handler.root_topic = self.topic + logger.addHandler(handler) + + self.assertTrue(handler.socket is pub) + self.assertTrue(handler.ctx is pub.context) + self.assertTrue(handler.ctx is self.context) + sub.setsockopt(zmq.SUBSCRIBE, b(self.topic)) + import time; time.sleep(0.1) + msg1 = 'message' + logger.info(msg1) + + (topic, msg2) = sub.recv_multipart() + self.assertEqual(topic, b'zmq.INFO') + self.assertEqual(msg2, b(msg1)+b'\n') + logger.removeHandler(handler) + + def test_root_topic(self): + logger, handler, sub = self.connect_handler() + handler.socket.bind(self.iface) + sub2 = sub.context.socket(zmq.SUB) + self.sockets.append(sub2) + sub2.connect(self.iface) + sub2.setsockopt(zmq.SUBSCRIBE, b'') + handler.root_topic = b'twoonly' + msg1 = 'ignored' + logger.info(msg1) + self.assertRaisesErrno(zmq.EAGAIN, sub.recv, zmq.NOBLOCK) + topic,msg2 = sub2.recv_multipart() + self.assertEqual(topic, b'twoonly.INFO') + self.assertEqual(msg2, b(msg1)+b'\n') + + logger.removeHandler(handler) + + def test_unicode_message(self): + logger, handler, sub = self.connect_handler() + base_topic = b(self.topic + '.INFO') + for msg, expected in [ + (u('hello'), [base_topic, b('hello\n')]), + (u('héllo'), [base_topic, b('héllo\n')]), + (u('tøpic::héllo'), [base_topic + b('.tøpic'), b('héllo\n')]), + ]: + logger.info(msg) + received = sub.recv_multipart() + self.assertEqual(received, expected) + diff --git a/src/console/zmq/tests/test_message.py b/src/console/zmq/tests/test_message.py new file mode 100755 index 00000000..d8770bdf --- /dev/null +++ b/src/console/zmq/tests/test_message.py @@ -0,0 +1,362 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import copy +import sys +try: + from sys import getrefcount as grc +except ImportError: + grc = None + +import time +from pprint import pprint +from unittest import TestCase + +import zmq +from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy, PYPY +from zmq.utils.strtypes import unicode, bytes, b, u + + +# some useful constants: + +x = b'x' + +try: + view = memoryview +except NameError: + view = buffer + +if grc: + rc0 = grc(x) + v = view(x) + view_rc = grc(x) - rc0 + +def await_gc(obj, rc): + """wait for refcount on an object to drop to an expected value + + Necessary because of the zero-copy gc thread, + which can take some time to receive its DECREF message. + """ + for i in range(50): + # rc + 2 because of the refs in this function + if grc(obj) <= rc + 2: + return + time.sleep(0.05) + +class TestFrame(BaseZMQTestCase): + + @skip_pypy + def test_above_30(self): + """Message above 30 bytes are never copied by 0MQ.""" + for i in range(5, 16): # 32, 64,..., 65536 + s = (2**i)*x + self.assertEqual(grc(s), 2) + m = zmq.Frame(s) + self.assertEqual(grc(s), 4) + del m + await_gc(s, 2) + self.assertEqual(grc(s), 2) + del s + + def test_str(self): + """Test the str representations of the Frames.""" + for i in range(16): + s = (2**i)*x + m = zmq.Frame(s) + m_str = str(m) + m_str_b = b(m_str) # py3compat + self.assertEqual(s, m_str_b) + + def test_bytes(self): + """Test the Frame.bytes property.""" + for i in range(1,16): + s = (2**i)*x + m = zmq.Frame(s) + b = m.bytes + self.assertEqual(s, m.bytes) + if not PYPY: + # check that it copies + self.assert_(b is not s) + # check that it copies only once + self.assert_(b is m.bytes) + + def test_unicode(self): + """Test the unicode representations of the Frames.""" + s = u('asdf') + self.assertRaises(TypeError, zmq.Frame, s) + for i in range(16): + s = (2**i)*u('§') + m = zmq.Frame(s.encode('utf8')) + self.assertEqual(s, unicode(m.bytes,'utf8')) + + def test_len(self): + """Test the len of the Frames.""" + for i in range(16): + s = (2**i)*x + m = zmq.Frame(s) + self.assertEqual(len(s), len(m)) + + @skip_pypy + def test_lifecycle1(self): + """Run through a ref counting cycle with a copy.""" + for i in range(5, 16): # 32, 64,..., 65536 + s = (2**i)*x + rc = 2 + self.assertEqual(grc(s), rc) + m = zmq.Frame(s) + rc += 2 + self.assertEqual(grc(s), rc) + m2 = copy.copy(m) + rc += 1 + self.assertEqual(grc(s), rc) + buf = m2.buffer + + rc += view_rc + self.assertEqual(grc(s), rc) + + self.assertEqual(s, b(str(m))) + self.assertEqual(s, bytes(m2)) + self.assertEqual(s, m.bytes) + # self.assert_(s is str(m)) + # self.assert_(s is str(m2)) + del m2 + rc -= 1 + self.assertEqual(grc(s), rc) + rc -= view_rc + del buf + self.assertEqual(grc(s), rc) + del m + rc -= 2 + await_gc(s, rc) + self.assertEqual(grc(s), rc) + self.assertEqual(rc, 2) + del s + + @skip_pypy + def test_lifecycle2(self): + """Run through a different ref counting cycle with a copy.""" + for i in range(5, 16): # 32, 64,..., 65536 + s = (2**i)*x + rc = 2 + self.assertEqual(grc(s), rc) + m = zmq.Frame(s) + rc += 2 + self.assertEqual(grc(s), rc) + m2 = copy.copy(m) + rc += 1 + self.assertEqual(grc(s), rc) + buf = m.buffer + rc += view_rc + self.assertEqual(grc(s), rc) + self.assertEqual(s, b(str(m))) + self.assertEqual(s, bytes(m2)) + self.assertEqual(s, m2.bytes) + self.assertEqual(s, m.bytes) + # self.assert_(s is str(m)) + # self.assert_(s is str(m2)) + del buf + self.assertEqual(grc(s), rc) + del m + # m.buffer is kept until m is del'd + rc -= view_rc + rc -= 1 + self.assertEqual(grc(s), rc) + del m2 + rc -= 2 + await_gc(s, rc) + self.assertEqual(grc(s), rc) + self.assertEqual(rc, 2) + del s + + @skip_pypy + def test_tracker(self): + m = zmq.Frame(b'asdf', track=True) + self.assertFalse(m.tracker.done) + pm = zmq.MessageTracker(m) + self.assertFalse(pm.done) + del m + for i in range(10): + if pm.done: + break + time.sleep(0.1) + self.assertTrue(pm.done) + + def test_no_tracker(self): + m = zmq.Frame(b'asdf', track=False) + self.assertEqual(m.tracker, None) + m2 = copy.copy(m) + self.assertEqual(m2.tracker, None) + self.assertRaises(ValueError, zmq.MessageTracker, m) + + @skip_pypy + def test_multi_tracker(self): + m = zmq.Frame(b'asdf', track=True) + m2 = zmq.Frame(b'whoda', track=True) + mt = zmq.MessageTracker(m,m2) + self.assertFalse(m.tracker.done) + self.assertFalse(mt.done) + self.assertRaises(zmq.NotDone, mt.wait, 0.1) + del m + time.sleep(0.1) + self.assertRaises(zmq.NotDone, mt.wait, 0.1) + self.assertFalse(mt.done) + del m2 + self.assertTrue(mt.wait() is None) + self.assertTrue(mt.done) + + + def test_buffer_in(self): + """test using a buffer as input""" + ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√") + m = zmq.Frame(view(ins)) + + def test_bad_buffer_in(self): + """test using a bad object""" + self.assertRaises(TypeError, zmq.Frame, 5) + self.assertRaises(TypeError, zmq.Frame, object()) + + def test_buffer_out(self): + """receiving buffered output""" + ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√") + m = zmq.Frame(ins) + outb = m.buffer + self.assertTrue(isinstance(outb, view)) + self.assert_(outb is m.buffer) + self.assert_(m.buffer is m.buffer) + + def test_multisend(self): + """ensure that a message remains intact after multiple sends""" + a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + s = b"message" + m = zmq.Frame(s) + self.assertEqual(s, m.bytes) + + a.send(m, copy=False) + time.sleep(0.1) + self.assertEqual(s, m.bytes) + a.send(m, copy=False) + time.sleep(0.1) + self.assertEqual(s, m.bytes) + a.send(m, copy=True) + time.sleep(0.1) + self.assertEqual(s, m.bytes) + a.send(m, copy=True) + time.sleep(0.1) + self.assertEqual(s, m.bytes) + for i in range(4): + r = b.recv() + self.assertEqual(s,r) + self.assertEqual(s, m.bytes) + + def test_buffer_numpy(self): + """test non-copying numpy array messages""" + try: + import numpy + except ImportError: + raise SkipTest("numpy required") + rand = numpy.random.randint + shapes = [ rand(2,16) for i in range(5) ] + for i in range(1,len(shapes)+1): + shape = shapes[:i] + A = numpy.random.random(shape) + m = zmq.Frame(A) + if view.__name__ == 'buffer': + self.assertEqual(A.data, m.buffer) + B = numpy.frombuffer(m.buffer,dtype=A.dtype).reshape(A.shape) + else: + self.assertEqual(memoryview(A), m.buffer) + B = numpy.array(m.buffer,dtype=A.dtype).reshape(A.shape) + self.assertEqual((A==B).all(), True) + + def test_memoryview(self): + """test messages from memoryview""" + major,minor = sys.version_info[:2] + if not (major >= 3 or (major == 2 and minor >= 7)): + raise SkipTest("memoryviews only in python >= 2.7") + + s = b'carrotjuice' + v = memoryview(s) + m = zmq.Frame(s) + buf = m.buffer + s2 = buf.tobytes() + self.assertEqual(s2,s) + self.assertEqual(m.bytes,s) + + def test_noncopying_recv(self): + """check for clobbering message buffers""" + null = b'\0'*64 + sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + for i in range(32): + # try a few times + sb.send(null, copy=False) + m = sa.recv(copy=False) + mb = m.bytes + # buf = view(m) + buf = m.buffer + del m + for i in range(5): + ff=b'\xff'*(40 + i*10) + sb.send(ff, copy=False) + m2 = sa.recv(copy=False) + if view.__name__ == 'buffer': + b = bytes(buf) + else: + b = buf.tobytes() + self.assertEqual(b, null) + self.assertEqual(mb, null) + self.assertEqual(m2.bytes, ff) + + @skip_pypy + def test_buffer_numpy(self): + """test non-copying numpy array messages""" + try: + import numpy + except ImportError: + raise SkipTest("requires numpy") + if sys.version_info < (2,7): + raise SkipTest("requires new-style buffer interface (py >= 2.7)") + rand = numpy.random.randint + shapes = [ rand(2,5) for i in range(5) ] + a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + dtypes = [int, float, '>i4', 'B'] + for i in range(1,len(shapes)+1): + shape = shapes[:i] + for dt in dtypes: + A = numpy.empty(shape, dtype=dt) + while numpy.isnan(A).any(): + # don't let nan sneak in + A = numpy.ndarray(shape, dtype=dt) + a.send(A, copy=False) + msg = b.recv(copy=False) + + B = numpy.frombuffer(msg, A.dtype).reshape(A.shape) + self.assertEqual(A.shape, B.shape) + self.assertTrue((A==B).all()) + A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')]) + A['a'] = 1024 + A['b'] = 1e9 + A['c'] = 'hello there' + a.send(A, copy=False) + msg = b.recv(copy=False) + + B = numpy.frombuffer(msg, A.dtype).reshape(A.shape) + self.assertEqual(A.shape, B.shape) + self.assertTrue((A==B).all()) + + def test_frame_more(self): + """test Frame.more attribute""" + frame = zmq.Frame(b"hello") + self.assertFalse(frame.more) + sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + sa.send_multipart([b'hi', b'there']) + frame = self.recv(sb, copy=False) + self.assertTrue(frame.more) + if zmq.zmq_version_info()[0] >= 3 and not PYPY: + self.assertTrue(frame.get(zmq.MORE)) + frame = self.recv(sb, copy=False) + self.assertFalse(frame.more) + if zmq.zmq_version_info()[0] >= 3 and not PYPY: + self.assertFalse(frame.get(zmq.MORE)) + diff --git a/src/console/zmq/tests/test_monitor.py b/src/console/zmq/tests/test_monitor.py new file mode 100755 index 00000000..4f035388 --- /dev/null +++ b/src/console/zmq/tests/test_monitor.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import sys +import time +import struct + +from unittest import TestCase + +import zmq +from zmq.tests import BaseZMQTestCase, skip_if, skip_pypy +from zmq.utils.monitor import recv_monitor_message + +skip_lt_4 = skip_if(zmq.zmq_version_info() < (4,), "requires zmq >= 4") + +class TestSocketMonitor(BaseZMQTestCase): + + @skip_lt_4 + def test_monitor(self): + """Test monitoring interface for sockets.""" + s_rep = self.context.socket(zmq.REP) + s_req = self.context.socket(zmq.REQ) + self.sockets.extend([s_rep, s_req]) + s_req.bind("tcp://127.0.0.1:6666") + # try monitoring the REP socket + + s_rep.monitor("inproc://monitor.rep", zmq.EVENT_ALL) + # create listening socket for monitor + s_event = self.context.socket(zmq.PAIR) + self.sockets.append(s_event) + s_event.connect("inproc://monitor.rep") + s_event.linger = 0 + # test receive event for connect event + s_rep.connect("tcp://127.0.0.1:6666") + m = recv_monitor_message(s_event) + if m['event'] == zmq.EVENT_CONNECT_DELAYED: + self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666") + # test receive event for connected event + m = recv_monitor_message(s_event) + self.assertEqual(m['event'], zmq.EVENT_CONNECTED) + self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666") + + # test monitor can be disabled. + s_rep.disable_monitor() + m = recv_monitor_message(s_event) + self.assertEqual(m['event'], zmq.EVENT_MONITOR_STOPPED) + + + @skip_lt_4 + def test_monitor_connected(self): + """Test connected monitoring socket.""" + s_rep = self.context.socket(zmq.REP) + s_req = self.context.socket(zmq.REQ) + self.sockets.extend([s_rep, s_req]) + s_req.bind("tcp://127.0.0.1:6667") + # try monitoring the REP socket + # create listening socket for monitor + s_event = s_rep.get_monitor_socket() + s_event.linger = 0 + self.sockets.append(s_event) + # test receive event for connect event + s_rep.connect("tcp://127.0.0.1:6667") + m = recv_monitor_message(s_event) + if m['event'] == zmq.EVENT_CONNECT_DELAYED: + self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667") + # test receive event for connected event + m = recv_monitor_message(s_event) + self.assertEqual(m['event'], zmq.EVENT_CONNECTED) + self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667") diff --git a/src/console/zmq/tests/test_monqueue.py b/src/console/zmq/tests/test_monqueue.py new file mode 100755 index 00000000..e855602e --- /dev/null +++ b/src/console/zmq/tests/test_monqueue.py @@ -0,0 +1,227 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +from unittest import TestCase + +import zmq +from zmq import devices + +from zmq.tests import BaseZMQTestCase, SkipTest, PYPY +from zmq.utils.strtypes import unicode + + +if PYPY or zmq.zmq_version_info() >= (4,1): + # cleanup of shared Context doesn't work on PyPy + # there also seems to be a bug in cleanup in libzmq-4.1 (zeromq/libzmq#1052) + devices.Device.context_factory = zmq.Context + + +class TestMonitoredQueue(BaseZMQTestCase): + + sockets = [] + + def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'): + self.device = devices.ThreadMonitoredQueue(zmq.PAIR, zmq.PAIR, zmq.PUB, + in_prefix, out_prefix) + alice = self.context.socket(zmq.PAIR) + bob = self.context.socket(zmq.PAIR) + mon = self.context.socket(zmq.SUB) + + aport = alice.bind_to_random_port('tcp://127.0.0.1') + bport = bob.bind_to_random_port('tcp://127.0.0.1') + mport = mon.bind_to_random_port('tcp://127.0.0.1') + mon.setsockopt(zmq.SUBSCRIBE, mon_sub) + + self.device.connect_in("tcp://127.0.0.1:%i"%aport) + self.device.connect_out("tcp://127.0.0.1:%i"%bport) + self.device.connect_mon("tcp://127.0.0.1:%i"%mport) + self.device.start() + time.sleep(.2) + try: + # this is currenlty necessary to ensure no dropped monitor messages + # see LIBZMQ-248 for more info + mon.recv_multipart(zmq.NOBLOCK) + except zmq.ZMQError: + pass + self.sockets.extend([alice, bob, mon]) + return alice, bob, mon + + + def teardown_device(self): + for socket in self.sockets: + socket.close() + del socket + del self.device + + def test_reply(self): + alice, bob, mon = self.build_device() + alices = b"hello bob".split() + alice.send_multipart(alices) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + self.teardown_device() + + def test_queue(self): + alice, bob, mon = self.build_device() + alices = b"hello bob".split() + alice.send_multipart(alices) + alices2 = b"hello again".split() + alice.send_multipart(alices2) + alices3 = b"hello again and again".split() + alice.send_multipart(alices3) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices2, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices3, bobs) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + self.teardown_device() + + def test_monitor(self): + alice, bob, mon = self.build_device() + alices = b"hello bob".split() + alice.send_multipart(alices) + alices2 = b"hello again".split() + alice.send_multipart(alices2) + alices3 = b"hello again and again".split() + alice.send_multipart(alices3) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'in']+bobs, mons) + bobs = self.recv_multipart(bob) + self.assertEqual(alices2, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices3, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'in']+alices2, mons) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'in']+alices3, mons) + mons = self.recv_multipart(mon) + self.assertEqual([b'out']+bobs, mons) + self.teardown_device() + + def test_prefix(self): + alice, bob, mon = self.build_device(b"", b'foo', b'bar') + alices = b"hello bob".split() + alice.send_multipart(alices) + alices2 = b"hello again".split() + alice.send_multipart(alices2) + alices3 = b"hello again and again".split() + alice.send_multipart(alices3) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'foo']+bobs, mons) + bobs = self.recv_multipart(bob) + self.assertEqual(alices2, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices3, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'foo']+alices2, mons) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'foo']+alices3, mons) + mons = self.recv_multipart(mon) + self.assertEqual([b'bar']+bobs, mons) + self.teardown_device() + + def test_monitor_subscribe(self): + alice, bob, mon = self.build_device(b"out") + alices = b"hello bob".split() + alice.send_multipart(alices) + alices2 = b"hello again".split() + alice.send_multipart(alices2) + alices3 = b"hello again and again".split() + alice.send_multipart(alices3) + bobs = self.recv_multipart(bob) + self.assertEqual(alices, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices2, bobs) + bobs = self.recv_multipart(bob) + self.assertEqual(alices3, bobs) + bobs = b"hello alice".split() + bob.send_multipart(bobs) + alices = self.recv_multipart(alice) + self.assertEqual(alices, bobs) + mons = self.recv_multipart(mon) + self.assertEqual([b'out']+bobs, mons) + self.teardown_device() + + def test_router_router(self): + """test router-router MQ devices""" + dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out') + self.device = dev + dev.setsockopt_in(zmq.LINGER, 0) + dev.setsockopt_out(zmq.LINGER, 0) + dev.setsockopt_mon(zmq.LINGER, 0) + + binder = self.context.socket(zmq.DEALER) + porta = binder.bind_to_random_port('tcp://127.0.0.1') + portb = binder.bind_to_random_port('tcp://127.0.0.1') + binder.close() + time.sleep(0.1) + a = self.context.socket(zmq.DEALER) + a.identity = b'a' + b = self.context.socket(zmq.DEALER) + b.identity = b'b' + self.sockets.extend([a, b]) + + a.connect('tcp://127.0.0.1:%i'%porta) + dev.bind_in('tcp://127.0.0.1:%i'%porta) + b.connect('tcp://127.0.0.1:%i'%portb) + dev.bind_out('tcp://127.0.0.1:%i'%portb) + dev.start() + time.sleep(0.2) + if zmq.zmq_version_info() >= (3,1,0): + # flush erroneous poll state, due to LIBZMQ-280 + ping_msg = [ b'ping', b'pong' ] + for s in (a,b): + s.send_multipart(ping_msg) + try: + s.recv(zmq.NOBLOCK) + except zmq.ZMQError: + pass + msg = [ b'hello', b'there' ] + a.send_multipart([b'b']+msg) + bmsg = self.recv_multipart(b) + self.assertEqual(bmsg, [b'a']+msg) + b.send_multipart(bmsg) + amsg = self.recv_multipart(a) + self.assertEqual(amsg, [b'b']+msg) + self.teardown_device() + + def test_default_mq_args(self): + self.device = dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.DEALER, zmq.PUB) + dev.setsockopt_in(zmq.LINGER, 0) + dev.setsockopt_out(zmq.LINGER, 0) + dev.setsockopt_mon(zmq.LINGER, 0) + # this will raise if default args are wrong + dev.start() + self.teardown_device() + + def test_mq_check_prefix(self): + ins = self.context.socket(zmq.ROUTER) + outs = self.context.socket(zmq.DEALER) + mons = self.context.socket(zmq.PUB) + self.sockets.extend([ins, outs, mons]) + + ins = unicode('in') + outs = unicode('out') + self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons) diff --git a/src/console/zmq/tests/test_multipart.py b/src/console/zmq/tests/test_multipart.py new file mode 100755 index 00000000..24d41be0 --- /dev/null +++ b/src/console/zmq/tests/test_multipart.py @@ -0,0 +1,35 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import zmq + + +from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest + + +class TestMultipart(BaseZMQTestCase): + + def test_router_dealer(self): + router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER) + + msg1 = b'message1' + dealer.send(msg1) + ident = self.recv(router) + more = router.rcvmore + self.assertEqual(more, True) + msg2 = self.recv(router) + self.assertEqual(msg1, msg2) + more = router.rcvmore + self.assertEqual(more, False) + + def test_basic_multipart(self): + a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + msg = [ b'hi', b'there', b'b'] + a.send_multipart(msg) + recvd = b.recv_multipart() + self.assertEqual(msg, recvd) + +if have_gevent: + class TestMultipartGreen(GreenTest, TestMultipart): + pass diff --git a/src/console/zmq/tests/test_pair.py b/src/console/zmq/tests/test_pair.py new file mode 100755 index 00000000..e88c1e8b --- /dev/null +++ b/src/console/zmq/tests/test_pair.py @@ -0,0 +1,53 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import zmq + + +from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest + + +x = b' ' +class TestPair(BaseZMQTestCase): + + def test_basic(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + msg1 = b'message1' + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + + def test_multiple(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + for i in range(10): + msg = i*x + s1.send(msg) + + for i in range(10): + msg = i*x + s2.send(msg) + + for i in range(10): + msg = s1.recv() + self.assertEqual(msg, i*x) + + for i in range(10): + msg = s2.recv() + self.assertEqual(msg, i*x) + + def test_json(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + o = dict(a=10,b=list(range(10))) + o2 = self.ping_pong_json(s1, s2, o) + + def test_pyobj(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + o = dict(a=10,b=range(10)) + o2 = self.ping_pong_pyobj(s1, s2, o) + +if have_gevent: + class TestReqRepGreen(GreenTest, TestPair): + pass + diff --git a/src/console/zmq/tests/test_poll.py b/src/console/zmq/tests/test_poll.py new file mode 100755 index 00000000..57346c89 --- /dev/null +++ b/src/console/zmq/tests/test_poll.py @@ -0,0 +1,229 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import time +from unittest import TestCase + +import zmq + +from zmq.tests import PollZMQTestCase, have_gevent, GreenTest + +def wait(): + time.sleep(.25) + + +class TestPoll(PollZMQTestCase): + + Poller = zmq.Poller + + # This test is failing due to this issue: + # http://github.com/sustrik/zeromq2/issues#issue/26 + def test_pair(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + # Sleep to allow sockets to connect. + wait() + + poller = self.Poller() + poller.register(s1, zmq.POLLIN|zmq.POLLOUT) + poller.register(s2, zmq.POLLIN|zmq.POLLOUT) + # Poll result should contain both sockets + socks = dict(poller.poll()) + # Now make sure that both are send ready. + self.assertEqual(socks[s1], zmq.POLLOUT) + self.assertEqual(socks[s2], zmq.POLLOUT) + # Now do a send on both, wait and test for zmq.POLLOUT|zmq.POLLIN + s1.send(b'msg1') + s2.send(b'msg2') + wait() + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT|zmq.POLLIN) + self.assertEqual(socks[s2], zmq.POLLOUT|zmq.POLLIN) + # Make sure that both are in POLLOUT after recv. + s1.recv() + s2.recv() + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT) + self.assertEqual(socks[s2], zmq.POLLOUT) + + poller.unregister(s1) + poller.unregister(s2) + + # Wait for everything to finish. + wait() + + def test_reqrep(self): + s1, s2 = self.create_bound_pair(zmq.REP, zmq.REQ) + + # Sleep to allow sockets to connect. + wait() + + poller = self.Poller() + poller.register(s1, zmq.POLLIN|zmq.POLLOUT) + poller.register(s2, zmq.POLLIN|zmq.POLLOUT) + + # Make sure that s1 is in state 0 and s2 is in POLLOUT + socks = dict(poller.poll()) + self.assertEqual(s1 in socks, 0) + self.assertEqual(socks[s2], zmq.POLLOUT) + + # Make sure that s2 goes immediately into state 0 after send. + s2.send(b'msg1') + socks = dict(poller.poll()) + self.assertEqual(s2 in socks, 0) + + # Make sure that s1 goes into POLLIN state after a time.sleep(). + time.sleep(0.5) + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLIN) + + # Make sure that s1 goes into POLLOUT after recv. + s1.recv() + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT) + + # Make sure s1 goes into state 0 after send. + s1.send(b'msg2') + socks = dict(poller.poll()) + self.assertEqual(s1 in socks, 0) + + # Wait and then see that s2 is in POLLIN. + time.sleep(0.5) + socks = dict(poller.poll()) + self.assertEqual(socks[s2], zmq.POLLIN) + + # Make sure that s2 is in POLLOUT after recv. + s2.recv() + socks = dict(poller.poll()) + self.assertEqual(socks[s2], zmq.POLLOUT) + + poller.unregister(s1) + poller.unregister(s2) + + # Wait for everything to finish. + wait() + + def test_no_events(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + poller = self.Poller() + poller.register(s1, zmq.POLLIN|zmq.POLLOUT) + poller.register(s2, 0) + self.assertTrue(s1 in poller) + self.assertFalse(s2 in poller) + poller.register(s1, 0) + self.assertFalse(s1 in poller) + + def test_pubsub(self): + s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB) + s2.setsockopt(zmq.SUBSCRIBE, b'') + + # Sleep to allow sockets to connect. + wait() + + poller = self.Poller() + poller.register(s1, zmq.POLLIN|zmq.POLLOUT) + poller.register(s2, zmq.POLLIN) + + # Now make sure that both are send ready. + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT) + self.assertEqual(s2 in socks, 0) + # Make sure that s1 stays in POLLOUT after a send. + s1.send(b'msg1') + socks = dict(poller.poll()) + self.assertEqual(socks[s1], zmq.POLLOUT) + + # Make sure that s2 is POLLIN after waiting. + wait() + socks = dict(poller.poll()) + self.assertEqual(socks[s2], zmq.POLLIN) + + # Make sure that s2 goes into 0 after recv. + s2.recv() + socks = dict(poller.poll()) + self.assertEqual(s2 in socks, 0) + + poller.unregister(s1) + poller.unregister(s2) + + # Wait for everything to finish. + wait() + def test_timeout(self): + """make sure Poller.poll timeout has the right units (milliseconds).""" + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + poller = self.Poller() + poller.register(s1, zmq.POLLIN) + tic = time.time() + evt = poller.poll(.005) + toc = time.time() + self.assertTrue(toc-tic < 0.1) + tic = time.time() + evt = poller.poll(5) + toc = time.time() + self.assertTrue(toc-tic < 0.1) + self.assertTrue(toc-tic > .001) + tic = time.time() + evt = poller.poll(500) + toc = time.time() + self.assertTrue(toc-tic < 1) + self.assertTrue(toc-tic > 0.1) + +class TestSelect(PollZMQTestCase): + + def test_pair(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + # Sleep to allow sockets to connect. + wait() + + rlist, wlist, xlist = zmq.select([s1, s2], [s1, s2], [s1, s2]) + self.assert_(s1 in wlist) + self.assert_(s2 in wlist) + self.assert_(s1 not in rlist) + self.assert_(s2 not in rlist) + + def test_timeout(self): + """make sure select timeout has the right units (seconds).""" + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + tic = time.time() + r,w,x = zmq.select([s1,s2],[],[],.005) + toc = time.time() + self.assertTrue(toc-tic < 1) + self.assertTrue(toc-tic > 0.001) + tic = time.time() + r,w,x = zmq.select([s1,s2],[],[],.25) + toc = time.time() + self.assertTrue(toc-tic < 1) + self.assertTrue(toc-tic > 0.1) + + +if have_gevent: + import gevent + from zmq import green as gzmq + + class TestPollGreen(GreenTest, TestPoll): + Poller = gzmq.Poller + + def test_wakeup(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + poller = self.Poller() + poller.register(s2, zmq.POLLIN) + + tic = time.time() + r = gevent.spawn(lambda: poller.poll(10000)) + s = gevent.spawn(lambda: s1.send(b'msg1')) + r.join() + toc = time.time() + self.assertTrue(toc-tic < 1) + + def test_socket_poll(self): + s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + + tic = time.time() + r = gevent.spawn(lambda: s2.poll(10000)) + s = gevent.spawn(lambda: s1.send(b'msg1')) + r.join() + toc = time.time() + self.assertTrue(toc-tic < 1) + diff --git a/src/console/zmq/tests/test_pubsub.py b/src/console/zmq/tests/test_pubsub.py new file mode 100755 index 00000000..a3ee22aa --- /dev/null +++ b/src/console/zmq/tests/test_pubsub.py @@ -0,0 +1,41 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import time +from unittest import TestCase + +import zmq + +from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest + + +class TestPubSub(BaseZMQTestCase): + + pass + + # We are disabling this test while an issue is being resolved. + def test_basic(self): + s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB) + s2.setsockopt(zmq.SUBSCRIBE,b'') + time.sleep(0.1) + msg1 = b'message' + s1.send(msg1) + msg2 = s2.recv() # This is blocking! + self.assertEqual(msg1, msg2) + + def test_topic(self): + s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB) + s2.setsockopt(zmq.SUBSCRIBE, b'x') + time.sleep(0.1) + msg1 = b'message' + s1.send(msg1) + self.assertRaisesErrno(zmq.EAGAIN, s2.recv, zmq.NOBLOCK) + msg1 = b'xmessage' + s1.send(msg1) + msg2 = s2.recv() + self.assertEqual(msg1, msg2) + +if have_gevent: + class TestPubSubGreen(GreenTest, TestPubSub): + pass diff --git a/src/console/zmq/tests/test_reqrep.py b/src/console/zmq/tests/test_reqrep.py new file mode 100755 index 00000000..de17f2b3 --- /dev/null +++ b/src/console/zmq/tests/test_reqrep.py @@ -0,0 +1,62 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from unittest import TestCase + +import zmq +from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest + + +class TestReqRep(BaseZMQTestCase): + + def test_basic(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + + msg1 = b'message 1' + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + + def test_multiple(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + + for i in range(10): + msg1 = i*b' ' + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + + def test_bad_send_recv(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + + if zmq.zmq_version() != '2.1.8': + # this doesn't work on 2.1.8 + for copy in (True,False): + self.assertRaisesErrno(zmq.EFSM, s1.recv, copy=copy) + self.assertRaisesErrno(zmq.EFSM, s2.send, b'asdf', copy=copy) + + # I have to have this or we die on an Abort trap. + msg1 = b'asdf' + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + + def test_json(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + o = dict(a=10,b=list(range(10))) + o2 = self.ping_pong_json(s1, s2, o) + + def test_pyobj(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + o = dict(a=10,b=range(10)) + o2 = self.ping_pong_pyobj(s1, s2, o) + + def test_large_msg(self): + s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + msg1 = 10000*b'X' + + for i in range(10): + msg2 = self.ping_pong(s1, s2, msg1) + self.assertEqual(msg1, msg2) + +if have_gevent: + class TestReqRepGreen(GreenTest, TestReqRep): + pass diff --git a/src/console/zmq/tests/test_security.py b/src/console/zmq/tests/test_security.py new file mode 100755 index 00000000..687b7e0f --- /dev/null +++ b/src/console/zmq/tests/test_security.py @@ -0,0 +1,212 @@ +"""Test libzmq security (libzmq >= 3.3.0)""" +# -*- coding: utf8 -*- + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import os +from threading import Thread + +import zmq +from zmq.tests import ( + BaseZMQTestCase, SkipTest, PYPY +) +from zmq.utils import z85 + + +USER = b"admin" +PASS = b"password" + +class TestSecurity(BaseZMQTestCase): + + def setUp(self): + if zmq.zmq_version_info() < (4,0): + raise SkipTest("security is new in libzmq 4.0") + try: + zmq.curve_keypair() + except zmq.ZMQError: + raise SkipTest("security requires libzmq to be linked against libsodium") + super(TestSecurity, self).setUp() + + + def zap_handler(self): + socket = self.context.socket(zmq.REP) + socket.bind("inproc://zeromq.zap.01") + try: + msg = self.recv_multipart(socket) + + version, sequence, domain, address, identity, mechanism = msg[:6] + if mechanism == b'PLAIN': + username, password = msg[6:] + elif mechanism == b'CURVE': + key = msg[6] + + self.assertEqual(version, b"1.0") + self.assertEqual(identity, b"IDENT") + reply = [version, sequence] + if mechanism == b'CURVE' or \ + (mechanism == b'PLAIN' and username == USER and password == PASS) or \ + (mechanism == b'NULL'): + reply.extend([ + b"200", + b"OK", + b"anonymous", + b"\5Hello\0\0\0\5World", + ]) + else: + reply.extend([ + b"400", + b"Invalid username or password", + b"", + b"", + ]) + socket.send_multipart(reply) + finally: + socket.close() + + def start_zap(self): + self.zap_thread = Thread(target=self.zap_handler) + self.zap_thread.start() + + def stop_zap(self): + self.zap_thread.join() + + def bounce(self, server, client, test_metadata=True): + msg = [os.urandom(64), os.urandom(64)] + client.send_multipart(msg) + frames = self.recv_multipart(server, copy=False) + recvd = list(map(lambda x: x.bytes, frames)) + + try: + if test_metadata and not PYPY: + for frame in frames: + self.assertEqual(frame.get('User-Id'), 'anonymous') + self.assertEqual(frame.get('Hello'), 'World') + self.assertEqual(frame['Socket-Type'], 'DEALER') + except zmq.ZMQVersionError: + pass + + self.assertEqual(recvd, msg) + server.send_multipart(recvd) + msg2 = self.recv_multipart(client) + self.assertEqual(msg2, msg) + + def test_null(self): + """test NULL (default) security""" + server = self.socket(zmq.DEALER) + client = self.socket(zmq.DEALER) + self.assertEqual(client.MECHANISM, zmq.NULL) + self.assertEqual(server.mechanism, zmq.NULL) + self.assertEqual(client.plain_server, 0) + self.assertEqual(server.plain_server, 0) + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + self.bounce(server, client, False) + + def test_plain(self): + """test PLAIN authentication""" + server = self.socket(zmq.DEALER) + server.identity = b'IDENT' + client = self.socket(zmq.DEALER) + self.assertEqual(client.plain_username, b'') + self.assertEqual(client.plain_password, b'') + client.plain_username = USER + client.plain_password = PASS + self.assertEqual(client.getsockopt(zmq.PLAIN_USERNAME), USER) + self.assertEqual(client.getsockopt(zmq.PLAIN_PASSWORD), PASS) + self.assertEqual(client.plain_server, 0) + self.assertEqual(server.plain_server, 0) + server.plain_server = True + self.assertEqual(server.mechanism, zmq.PLAIN) + self.assertEqual(client.mechanism, zmq.PLAIN) + + assert not client.plain_server + assert server.plain_server + + self.start_zap() + + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + self.bounce(server, client) + self.stop_zap() + + def skip_plain_inauth(self): + """test PLAIN failed authentication""" + server = self.socket(zmq.DEALER) + server.identity = b'IDENT' + client = self.socket(zmq.DEALER) + self.sockets.extend([server, client]) + client.plain_username = USER + client.plain_password = b'incorrect' + server.plain_server = True + self.assertEqual(server.mechanism, zmq.PLAIN) + self.assertEqual(client.mechanism, zmq.PLAIN) + + self.start_zap() + + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + client.send(b'ping') + server.rcvtimeo = 250 + self.assertRaisesErrno(zmq.EAGAIN, server.recv) + self.stop_zap() + + def test_keypair(self): + """test curve_keypair""" + try: + public, secret = zmq.curve_keypair() + except zmq.ZMQError: + raise SkipTest("CURVE unsupported") + + self.assertEqual(type(secret), bytes) + self.assertEqual(type(public), bytes) + self.assertEqual(len(secret), 40) + self.assertEqual(len(public), 40) + + # verify that it is indeed Z85 + bsecret, bpublic = [ z85.decode(key) for key in (public, secret) ] + self.assertEqual(type(bsecret), bytes) + self.assertEqual(type(bpublic), bytes) + self.assertEqual(len(bsecret), 32) + self.assertEqual(len(bpublic), 32) + + + def test_curve(self): + """test CURVE encryption""" + server = self.socket(zmq.DEALER) + server.identity = b'IDENT' + client = self.socket(zmq.DEALER) + self.sockets.extend([server, client]) + try: + server.curve_server = True + except zmq.ZMQError as e: + # will raise EINVAL if not linked against libsodium + if e.errno == zmq.EINVAL: + raise SkipTest("CURVE unsupported") + + server_public, server_secret = zmq.curve_keypair() + client_public, client_secret = zmq.curve_keypair() + + server.curve_secretkey = server_secret + server.curve_publickey = server_public + client.curve_serverkey = server_public + client.curve_publickey = client_public + client.curve_secretkey = client_secret + + self.assertEqual(server.mechanism, zmq.CURVE) + self.assertEqual(client.mechanism, zmq.CURVE) + + self.assertEqual(server.get(zmq.CURVE_SERVER), True) + self.assertEqual(client.get(zmq.CURVE_SERVER), False) + + self.start_zap() + + iface = 'tcp://127.0.0.1' + port = server.bind_to_random_port(iface) + client.connect("%s:%i" % (iface, port)) + self.bounce(server, client) + self.stop_zap() + diff --git a/src/console/zmq/tests/test_socket.py b/src/console/zmq/tests/test_socket.py new file mode 100755 index 00000000..5c842edc --- /dev/null +++ b/src/console/zmq/tests/test_socket.py @@ -0,0 +1,450 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +import warnings + +import zmq +from zmq.tests import ( + BaseZMQTestCase, SkipTest, have_gevent, GreenTest, skip_pypy, skip_if +) +from zmq.utils.strtypes import bytes, unicode + + +class TestSocket(BaseZMQTestCase): + + def test_create(self): + ctx = self.Context() + s = ctx.socket(zmq.PUB) + # Superluminal protocol not yet implemented + self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a') + self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a') + self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://') + s.close() + del ctx + + def test_context_manager(self): + url = 'inproc://a' + with self.Context() as ctx: + with ctx.socket(zmq.PUSH) as a: + a.bind(url) + with ctx.socket(zmq.PULL) as b: + b.connect(url) + msg = b'hi' + a.send(msg) + rcvd = self.recv(b) + self.assertEqual(rcvd, msg) + self.assertEqual(b.closed, True) + self.assertEqual(a.closed, True) + self.assertEqual(ctx.closed, True) + + def test_dir(self): + ctx = self.Context() + s = ctx.socket(zmq.PUB) + self.assertTrue('send' in dir(s)) + self.assertTrue('IDENTITY' in dir(s)) + self.assertTrue('AFFINITY' in dir(s)) + self.assertTrue('FD' in dir(s)) + s.close() + ctx.term() + + def test_bind_unicode(self): + s = self.socket(zmq.PUB) + p = s.bind_to_random_port(unicode("tcp://*")) + + def test_connect_unicode(self): + s = self.socket(zmq.PUB) + s.connect(unicode("tcp://127.0.0.1:5555")) + + def test_bind_to_random_port(self): + # Check that bind_to_random_port do not hide usefull exception + ctx = self.Context() + c = ctx.socket(zmq.PUB) + # Invalid format + try: + c.bind_to_random_port('tcp:*') + except zmq.ZMQError as e: + self.assertEqual(e.errno, zmq.EINVAL) + # Invalid protocol + try: + c.bind_to_random_port('rand://*') + except zmq.ZMQError as e: + self.assertEqual(e.errno, zmq.EPROTONOSUPPORT) + + def test_identity(self): + s = self.context.socket(zmq.PULL) + self.sockets.append(s) + ident = b'identity\0\0' + s.identity = ident + self.assertEqual(s.get(zmq.IDENTITY), ident) + + def test_unicode_sockopts(self): + """test setting/getting sockopts with unicode strings""" + topic = "tést" + if str is not unicode: + topic = topic.decode('utf8') + p,s = self.create_bound_pair(zmq.PUB, zmq.SUB) + self.assertEqual(s.send_unicode, s.send_unicode) + self.assertEqual(p.recv_unicode, p.recv_unicode) + self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic) + self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic) + s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16') + self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic) + s.setsockopt_unicode(zmq.SUBSCRIBE, topic) + self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY) + self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE) + + identb = s.getsockopt(zmq.IDENTITY) + identu = identb.decode('utf16') + identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16') + self.assertEqual(identu, identu2) + time.sleep(0.1) # wait for connection/subscription + p.send_unicode(topic,zmq.SNDMORE) + p.send_unicode(topic*2, encoding='latin-1') + self.assertEqual(topic, s.recv_unicode()) + self.assertEqual(topic*2, s.recv_unicode(encoding='latin-1')) + + def test_int_sockopts(self): + "test integer sockopts" + v = zmq.zmq_version_info() + if v < (3,0): + default_hwm = 0 + else: + default_hwm = 1000 + p,s = self.create_bound_pair(zmq.PUB, zmq.SUB) + p.setsockopt(zmq.LINGER, 0) + self.assertEqual(p.getsockopt(zmq.LINGER), 0) + p.setsockopt(zmq.LINGER, -1) + self.assertEqual(p.getsockopt(zmq.LINGER), -1) + self.assertEqual(p.hwm, default_hwm) + p.hwm = 11 + self.assertEqual(p.hwm, 11) + # p.setsockopt(zmq.EVENTS, zmq.POLLIN) + self.assertEqual(p.getsockopt(zmq.EVENTS), zmq.POLLOUT) + self.assertRaisesErrno(zmq.EINVAL, p.setsockopt,zmq.EVENTS, 2**7-1) + self.assertEqual(p.getsockopt(zmq.TYPE), p.socket_type) + self.assertEqual(p.getsockopt(zmq.TYPE), zmq.PUB) + self.assertEqual(s.getsockopt(zmq.TYPE), s.socket_type) + self.assertEqual(s.getsockopt(zmq.TYPE), zmq.SUB) + + # check for overflow / wrong type: + errors = [] + backref = {} + constants = zmq.constants + for name in constants.__all__: + value = getattr(constants, name) + if isinstance(value, int): + backref[value] = name + for opt in zmq.constants.int_sockopts.union(zmq.constants.int64_sockopts): + sopt = backref[opt] + if sopt.startswith(( + 'ROUTER', 'XPUB', 'TCP', 'FAIL', + 'REQ_', 'CURVE_', 'PROBE_ROUTER', + 'IPC_FILTER', 'GSSAPI', + )): + # some sockopts are write-only + continue + try: + n = p.getsockopt(opt) + except zmq.ZMQError as e: + errors.append("getsockopt(zmq.%s) raised '%s'."%(sopt, e)) + else: + if n > 2**31: + errors.append("getsockopt(zmq.%s) returned a ridiculous value." + " It is probably the wrong type."%sopt) + if errors: + self.fail('\n'.join([''] + errors)) + + def test_bad_sockopts(self): + """Test that appropriate errors are raised on bad socket options""" + s = self.context.socket(zmq.PUB) + self.sockets.append(s) + s.setsockopt(zmq.LINGER, 0) + # unrecognized int sockopts pass through to libzmq, and should raise EINVAL + self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5) + self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999) + # but only int sockopts are allowed through this way, otherwise raise a TypeError + self.assertRaises(TypeError, s.setsockopt, 9999, b"5") + # some sockopts are valid in general, but not on every socket: + self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi') + + def test_sockopt_roundtrip(self): + "test set/getsockopt roundtrip." + p = self.context.socket(zmq.PUB) + self.sockets.append(p) + p.setsockopt(zmq.LINGER, 11) + self.assertEqual(p.getsockopt(zmq.LINGER), 11) + + def test_send_unicode(self): + "test sending unicode objects" + a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + self.sockets.extend([a,b]) + u = "çπ§" + if str is not unicode: + u = u.decode('utf8') + self.assertRaises(TypeError, a.send, u,copy=False) + self.assertRaises(TypeError, a.send, u,copy=True) + a.send_unicode(u) + s = b.recv() + self.assertEqual(s,u.encode('utf8')) + self.assertEqual(s.decode('utf8'),u) + a.send_unicode(u,encoding='utf16') + s = b.recv_unicode(encoding='utf16') + self.assertEqual(s,u) + + @skip_pypy + def test_tracker(self): + "test the MessageTracker object for tracking when zmq is done with a buffer" + addr = 'tcp://127.0.0.1' + a = self.context.socket(zmq.PUB) + port = a.bind_to_random_port(addr) + a.close() + iface = "%s:%i"%(addr,port) + a = self.context.socket(zmq.PAIR) + # a.setsockopt(zmq.IDENTITY, b"a") + b = self.context.socket(zmq.PAIR) + self.sockets.extend([a,b]) + a.connect(iface) + time.sleep(0.1) + p1 = a.send(b'something', copy=False, track=True) + self.assertTrue(isinstance(p1, zmq.MessageTracker)) + self.assertFalse(p1.done) + p2 = a.send_multipart([b'something', b'else'], copy=False, track=True) + self.assert_(isinstance(p2, zmq.MessageTracker)) + self.assertEqual(p2.done, False) + self.assertEqual(p1.done, False) + + b.bind(iface) + msg = b.recv_multipart() + for i in range(10): + if p1.done: + break + time.sleep(0.1) + self.assertEqual(p1.done, True) + self.assertEqual(msg, [b'something']) + msg = b.recv_multipart() + for i in range(10): + if p2.done: + break + time.sleep(0.1) + self.assertEqual(p2.done, True) + self.assertEqual(msg, [b'something', b'else']) + m = zmq.Frame(b"again", track=True) + self.assertEqual(m.tracker.done, False) + p1 = a.send(m, copy=False) + p2 = a.send(m, copy=False) + self.assertEqual(m.tracker.done, False) + self.assertEqual(p1.done, False) + self.assertEqual(p2.done, False) + msg = b.recv_multipart() + self.assertEqual(m.tracker.done, False) + self.assertEqual(msg, [b'again']) + msg = b.recv_multipart() + self.assertEqual(m.tracker.done, False) + self.assertEqual(msg, [b'again']) + self.assertEqual(p1.done, False) + self.assertEqual(p2.done, False) + pm = m.tracker + del m + for i in range(10): + if p1.done: + break + time.sleep(0.1) + self.assertEqual(p1.done, True) + self.assertEqual(p2.done, True) + m = zmq.Frame(b'something', track=False) + self.assertRaises(ValueError, a.send, m, copy=False, track=True) + + + def test_close(self): + ctx = self.Context() + s = ctx.socket(zmq.PUB) + s.close() + self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'') + self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'') + self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'') + self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf') + self.assertRaisesErrno(zmq.ENOTSOCK, s.recv) + del ctx + + def test_attr(self): + """set setting/getting sockopts as attributes""" + s = self.context.socket(zmq.DEALER) + self.sockets.append(s) + linger = 10 + s.linger = linger + self.assertEqual(linger, s.linger) + self.assertEqual(linger, s.getsockopt(zmq.LINGER)) + self.assertEqual(s.fd, s.getsockopt(zmq.FD)) + + def test_bad_attr(self): + s = self.context.socket(zmq.DEALER) + self.sockets.append(s) + try: + s.apple='foo' + except AttributeError: + pass + else: + self.fail("bad setattr should have raised AttributeError") + try: + s.apple + except AttributeError: + pass + else: + self.fail("bad getattr should have raised AttributeError") + + def test_subclass(self): + """subclasses can assign attributes""" + class S(zmq.Socket): + a = None + def __init__(self, *a, **kw): + self.a=-1 + super(S, self).__init__(*a, **kw) + + s = S(self.context, zmq.REP) + self.sockets.append(s) + self.assertEqual(s.a, -1) + s.a=1 + self.assertEqual(s.a, 1) + a=s.a + self.assertEqual(a, 1) + + def test_recv_multipart(self): + a,b = self.create_bound_pair() + msg = b'hi' + for i in range(3): + a.send(msg) + time.sleep(0.1) + for i in range(3): + self.assertEqual(b.recv_multipart(), [msg]) + + def test_close_after_destroy(self): + """s.close() after ctx.destroy() should be fine""" + ctx = self.Context() + s = ctx.socket(zmq.REP) + ctx.destroy() + # reaper is not instantaneous + time.sleep(1e-2) + s.close() + self.assertTrue(s.closed) + + def test_poll(self): + a,b = self.create_bound_pair() + tic = time.time() + evt = a.poll(50) + self.assertEqual(evt, 0) + evt = a.poll(50, zmq.POLLOUT) + self.assertEqual(evt, zmq.POLLOUT) + msg = b'hi' + a.send(msg) + evt = b.poll(50) + self.assertEqual(evt, zmq.POLLIN) + msg2 = self.recv(b) + evt = b.poll(50) + self.assertEqual(evt, 0) + self.assertEqual(msg2, msg) + + def test_ipc_path_max_length(self): + """IPC_PATH_MAX_LEN is a sensible value""" + if zmq.IPC_PATH_MAX_LEN == 0: + raise SkipTest("IPC_PATH_MAX_LEN undefined") + + msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN + self.assertTrue(zmq.IPC_PATH_MAX_LEN > 30, msg) + self.assertTrue(zmq.IPC_PATH_MAX_LEN < 1025, msg) + + def test_ipc_path_max_length_msg(self): + if zmq.IPC_PATH_MAX_LEN == 0: + raise SkipTest("IPC_PATH_MAX_LEN undefined") + + s = self.context.socket(zmq.PUB) + self.sockets.append(s) + try: + s.bind('ipc://{0}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1))) + except zmq.ZMQError as e: + self.assertTrue(str(zmq.IPC_PATH_MAX_LEN) in e.strerror) + + def test_hwm(self): + zmq3 = zmq.zmq_version_info()[0] >= 3 + for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER): + s = self.context.socket(stype) + s.hwm = 100 + self.assertEqual(s.hwm, 100) + if zmq3: + try: + self.assertEqual(s.sndhwm, 100) + except AttributeError: + pass + try: + self.assertEqual(s.rcvhwm, 100) + except AttributeError: + pass + s.close() + + def test_shadow(self): + p = self.socket(zmq.PUSH) + p.bind("tcp://127.0.0.1:5555") + p2 = zmq.Socket.shadow(p.underlying) + self.assertEqual(p.underlying, p2.underlying) + s = self.socket(zmq.PULL) + s2 = zmq.Socket.shadow(s.underlying) + self.assertNotEqual(s.underlying, p.underlying) + self.assertEqual(s.underlying, s2.underlying) + s2.connect("tcp://127.0.0.1:5555") + sent = b'hi' + p2.send(sent) + rcvd = self.recv(s2) + self.assertEqual(rcvd, sent) + + def test_shadow_pyczmq(self): + try: + from pyczmq import zctx, zsocket + except Exception: + raise SkipTest("Requires pyczmq") + + ctx = zctx.new() + ca = zsocket.new(ctx, zmq.PUSH) + cb = zsocket.new(ctx, zmq.PULL) + a = zmq.Socket.shadow(ca) + b = zmq.Socket.shadow(cb) + a.bind("inproc://a") + b.connect("inproc://a") + a.send(b'hi') + rcvd = self.recv(b) + self.assertEqual(rcvd, b'hi') + + +if have_gevent: + import gevent + + class TestSocketGreen(GreenTest, TestSocket): + test_bad_attr = GreenTest.skip_green + test_close_after_destroy = GreenTest.skip_green + + def test_timeout(self): + a,b = self.create_bound_pair() + g = gevent.spawn_later(0.5, lambda: a.send(b'hi')) + timeout = gevent.Timeout(0.1) + timeout.start() + self.assertRaises(gevent.Timeout, b.recv) + g.kill() + + @skip_if(not hasattr(zmq, 'RCVTIMEO')) + def test_warn_set_timeo(self): + s = self.context.socket(zmq.REQ) + with warnings.catch_warnings(record=True) as w: + s.rcvtimeo = 5 + s.close() + self.assertEqual(len(w), 1) + self.assertEqual(w[0].category, UserWarning) + + + @skip_if(not hasattr(zmq, 'SNDTIMEO')) + def test_warn_get_timeo(self): + s = self.context.socket(zmq.REQ) + with warnings.catch_warnings(record=True) as w: + s.sndtimeo + s.close() + self.assertEqual(len(w), 1) + self.assertEqual(w[0].category, UserWarning) diff --git a/src/console/zmq/tests/test_stopwatch.py b/src/console/zmq/tests/test_stopwatch.py new file mode 100755 index 00000000..49fb79f2 --- /dev/null +++ b/src/console/zmq/tests/test_stopwatch.py @@ -0,0 +1,42 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import sys +import time + +from unittest import TestCase + +from zmq import Stopwatch, ZMQError + +if sys.version_info[0] >= 3: + long = int + +class TestStopWatch(TestCase): + + def test_stop_long(self): + """Ensure stop returns a long int.""" + watch = Stopwatch() + watch.start() + us = watch.stop() + self.assertTrue(isinstance(us, long)) + + def test_stop_microseconds(self): + """Test that stop/sleep have right units.""" + watch = Stopwatch() + watch.start() + tic = time.time() + watch.sleep(1) + us = watch.stop() + toc = time.time() + self.assertAlmostEqual(us/1e6,(toc-tic),places=0) + + def test_double_stop(self): + """Test error raised on multiple calls to stop.""" + watch = Stopwatch() + watch.start() + watch.stop() + self.assertRaises(ZMQError, watch.stop) + self.assertRaises(ZMQError, watch.stop) + diff --git a/src/console/zmq/tests/test_version.py b/src/console/zmq/tests/test_version.py new file mode 100755 index 00000000..6ebebf30 --- /dev/null +++ b/src/console/zmq/tests/test_version.py @@ -0,0 +1,44 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from unittest import TestCase +import zmq +from zmq.sugar import version + + +class TestVersion(TestCase): + + def test_pyzmq_version(self): + vs = zmq.pyzmq_version() + vs2 = zmq.__version__ + self.assertTrue(isinstance(vs, str)) + if zmq.__revision__: + self.assertEqual(vs, '@'.join(vs2, zmq.__revision__)) + else: + self.assertEqual(vs, vs2) + if version.VERSION_EXTRA: + self.assertTrue(version.VERSION_EXTRA in vs) + self.assertTrue(version.VERSION_EXTRA in vs2) + + def test_pyzmq_version_info(self): + info = zmq.pyzmq_version_info() + self.assertTrue(isinstance(info, tuple)) + for n in info[:3]: + self.assertTrue(isinstance(n, int)) + if version.VERSION_EXTRA: + self.assertEqual(len(info), 4) + self.assertEqual(info[-1], float('inf')) + else: + self.assertEqual(len(info), 3) + + def test_zmq_version_info(self): + info = zmq.zmq_version_info() + self.assertTrue(isinstance(info, tuple)) + for n in info[:3]: + self.assertTrue(isinstance(n, int)) + + def test_zmq_version(self): + v = zmq.zmq_version() + self.assertTrue(isinstance(v, str)) + diff --git a/src/console/zmq/tests/test_win32_shim.py b/src/console/zmq/tests/test_win32_shim.py new file mode 100755 index 00000000..55657bda --- /dev/null +++ b/src/console/zmq/tests/test_win32_shim.py @@ -0,0 +1,56 @@ +from __future__ import print_function + +import os + +from functools import wraps +from zmq.tests import BaseZMQTestCase +from zmq.utils.win32 import allow_interrupt + + +def count_calls(f): + @wraps(f) + def _(*args, **kwds): + try: + return f(*args, **kwds) + finally: + _.__calls__ += 1 + _.__calls__ = 0 + return _ + + +class TestWindowsConsoleControlHandler(BaseZMQTestCase): + + def test_handler(self): + @count_calls + def interrupt_polling(): + print('Caught CTRL-C!') + + if os.name == 'nt': + from ctypes import windll + from ctypes.wintypes import BOOL, DWORD + + kernel32 = windll.LoadLibrary('kernel32') + + # <http://msdn.microsoft.com/en-us/library/ms683155.aspx> + GenerateConsoleCtrlEvent = kernel32.GenerateConsoleCtrlEvent + GenerateConsoleCtrlEvent.argtypes = (DWORD, DWORD) + GenerateConsoleCtrlEvent.restype = BOOL + + try: + # Simulate CTRL-C event while handler is active. + with allow_interrupt(interrupt_polling): + result = GenerateConsoleCtrlEvent(0, 0) + if result == 0: + raise WindowsError + except KeyboardInterrupt: + pass + else: + self.fail('Expecting `KeyboardInterrupt` exception!') + + # Make sure our handler was called. + self.assertEqual(interrupt_polling.__calls__, 1) + else: + # On non-Windows systems, this utility is just a no-op! + with allow_interrupt(interrupt_polling): + pass + self.assertEqual(interrupt_polling.__calls__, 0) diff --git a/src/console/zmq/tests/test_z85.py b/src/console/zmq/tests/test_z85.py new file mode 100755 index 00000000..8a73cb4d --- /dev/null +++ b/src/console/zmq/tests/test_z85.py @@ -0,0 +1,63 @@ +# -*- coding: utf8 -*- +"""Test Z85 encoding + +confirm values and roundtrip with test values from the reference implementation. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from unittest import TestCase +from zmq.utils import z85 + + +class TestZ85(TestCase): + + def test_client_public(self): + client_public = \ + b"\xBB\x88\x47\x1D\x65\xE2\x65\x9B" \ + b"\x30\xC5\x5A\x53\x21\xCE\xBB\x5A" \ + b"\xAB\x2B\x70\xA3\x98\x64\x5C\x26" \ + b"\xDC\xA2\xB2\xFC\xB4\x3F\xC5\x18" + encoded = z85.encode(client_public) + + self.assertEqual(encoded, b"Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID") + decoded = z85.decode(encoded) + self.assertEqual(decoded, client_public) + + def test_client_secret(self): + client_secret = \ + b"\x7B\xB8\x64\xB4\x89\xAF\xA3\x67" \ + b"\x1F\xBE\x69\x10\x1F\x94\xB3\x89" \ + b"\x72\xF2\x48\x16\xDF\xB0\x1B\x51" \ + b"\x65\x6B\x3F\xEC\x8D\xFD\x08\x88" + encoded = z85.encode(client_secret) + + self.assertEqual(encoded, b"D:)Q[IlAW!ahhC2ac:9*A}h:p?([4%wOTJ%JR%cs") + decoded = z85.decode(encoded) + self.assertEqual(decoded, client_secret) + + def test_server_public(self): + server_public = \ + b"\x54\xFC\xBA\x24\xE9\x32\x49\x96" \ + b"\x93\x16\xFB\x61\x7C\x87\x2B\xB0" \ + b"\xC1\xD1\xFF\x14\x80\x04\x27\xC5" \ + b"\x94\xCB\xFA\xCF\x1B\xC2\xD6\x52" + encoded = z85.encode(server_public) + + self.assertEqual(encoded, b"rq:rM>}U?@Lns47E1%kR.o@n%FcmmsL/@{H8]yf7") + decoded = z85.decode(encoded) + self.assertEqual(decoded, server_public) + + def test_server_secret(self): + server_secret = \ + b"\x8E\x0B\xDD\x69\x76\x28\xB9\x1D" \ + b"\x8F\x24\x55\x87\xEE\x95\xC5\xB0" \ + b"\x4D\x48\x96\x3F\x79\x25\x98\x77" \ + b"\xB4\x9C\xD9\x06\x3A\xEA\xD3\xB7" + encoded = z85.encode(server_secret) + + self.assertEqual(encoded, b"JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6") + decoded = z85.decode(encoded) + self.assertEqual(decoded, server_secret) + diff --git a/src/console/zmq/tests/test_zmqstream.py b/src/console/zmq/tests/test_zmqstream.py new file mode 100755 index 00000000..cdb3a171 --- /dev/null +++ b/src/console/zmq/tests/test_zmqstream.py @@ -0,0 +1,34 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import sys +import time + +from unittest import TestCase + +import zmq +from zmq.eventloop import ioloop, zmqstream + +class TestZMQStream(TestCase): + + def setUp(self): + self.context = zmq.Context() + self.socket = self.context.socket(zmq.REP) + self.loop = ioloop.IOLoop.instance() + self.stream = zmqstream.ZMQStream(self.socket) + + def tearDown(self): + self.socket.close() + self.context.term() + + def test_callable_check(self): + """Ensure callable check works (py3k).""" + + self.stream.on_send(lambda *args: None) + self.stream.on_recv(lambda *args: None) + self.assertRaises(AssertionError, self.stream.on_recv, 1) + self.assertRaises(AssertionError, self.stream.on_send, 1) + self.assertRaises(AssertionError, self.stream.on_recv, zmq) + diff --git a/src/console/zmq/utils/__init__.py b/src/console/zmq/utils/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/console/zmq/utils/__init__.py diff --git a/src/console/zmq/utils/buffers.pxd b/src/console/zmq/utils/buffers.pxd new file mode 100644 index 00000000..998aa551 --- /dev/null +++ b/src/console/zmq/utils/buffers.pxd @@ -0,0 +1,313 @@ +"""Python version-independent methods for C/Python buffers. + +This file was copied and adapted from mpi4py. + +Authors +------- +* MinRK +""" + +#----------------------------------------------------------------------------- +# Copyright (c) 2010 Lisandro Dalcin +# All rights reserved. +# Used under BSD License: http://www.opensource.org/licenses/bsd-license.php +# +# Retrieval: +# Jul 23, 2010 18:00 PST (r539) +# http://code.google.com/p/mpi4py/source/browse/trunk/src/MPI/asbuffer.pxi +# +# Modifications from original: +# Copyright (c) 2010-2012 Brian Granger, Min Ragan-Kelley +# +# Distributed under the terms of the New BSD License. The full license is in +# the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + + +#----------------------------------------------------------------------------- +# Python includes. +#----------------------------------------------------------------------------- + +# get version-independent aliases: +cdef extern from "pyversion_compat.h": + pass + +# Python 3 buffer interface (PEP 3118) +cdef extern from "Python.h": + int PY_MAJOR_VERSION + int PY_MINOR_VERSION + ctypedef int Py_ssize_t + ctypedef struct PyMemoryViewObject: + pass + ctypedef struct Py_buffer: + void *buf + Py_ssize_t len + int readonly + char *format + int ndim + Py_ssize_t *shape + Py_ssize_t *strides + Py_ssize_t *suboffsets + Py_ssize_t itemsize + void *internal + cdef enum: + PyBUF_SIMPLE + PyBUF_WRITABLE + PyBUF_FORMAT + PyBUF_ANY_CONTIGUOUS + int PyObject_CheckBuffer(object) + int PyObject_GetBuffer(object, Py_buffer *, int) except -1 + void PyBuffer_Release(Py_buffer *) + + int PyBuffer_FillInfo(Py_buffer *view, object obj, void *buf, + Py_ssize_t len, int readonly, int infoflags) except -1 + object PyMemoryView_FromBuffer(Py_buffer *info) + + object PyMemoryView_FromObject(object) + +# Python 2 buffer interface (legacy) +cdef extern from "Python.h": + ctypedef void const_void "const void" + Py_ssize_t Py_END_OF_BUFFER + int PyObject_CheckReadBuffer(object) + int PyObject_AsReadBuffer (object, const_void **, Py_ssize_t *) except -1 + int PyObject_AsWriteBuffer(object, void **, Py_ssize_t *) except -1 + + object PyBuffer_FromMemory(void *ptr, Py_ssize_t s) + object PyBuffer_FromReadWriteMemory(void *ptr, Py_ssize_t s) + + object PyBuffer_FromObject(object, Py_ssize_t offset, Py_ssize_t size) + object PyBuffer_FromReadWriteObject(object, Py_ssize_t offset, Py_ssize_t size) + + +#----------------------------------------------------------------------------- +# asbuffer: C buffer from python object +#----------------------------------------------------------------------------- + + +cdef inline int memoryview_available(): + return PY_MAJOR_VERSION >= 3 or (PY_MAJOR_VERSION >=2 and PY_MINOR_VERSION >= 7) + +cdef inline int oldstyle_available(): + return PY_MAJOR_VERSION < 3 + + +cdef inline int check_buffer(object ob): + """Version independent check for whether an object is a buffer. + + Parameters + ---------- + object : object + Any Python object + + Returns + ------- + int : 0 if no buffer interface, 3 if newstyle buffer interface, 2 if oldstyle. + """ + if PyObject_CheckBuffer(ob): + return 3 + if oldstyle_available(): + return PyObject_CheckReadBuffer(ob) and 2 + return 0 + + +cdef inline object asbuffer(object ob, int writable, int format, + void **base, Py_ssize_t *size, + Py_ssize_t *itemsize): + """Turn an object into a C buffer in a Python version-independent way. + + Parameters + ---------- + ob : object + The object to be turned into a buffer. + Must provide a Python Buffer interface + writable : int + Whether the resulting buffer should be allowed to write + to the object. + format : int + The format of the buffer. See Python buffer docs. + base : void ** + The pointer that will be used to store the resulting C buffer. + size : Py_ssize_t * + The size of the buffer(s). + itemsize : Py_ssize_t * + The size of an item, if the buffer is non-contiguous. + + Returns + ------- + An object describing the buffer format. Generally a str, such as 'B'. + """ + + cdef void *bptr = NULL + cdef Py_ssize_t blen = 0, bitemlen = 0 + cdef Py_buffer view + cdef int flags = PyBUF_SIMPLE + cdef int mode = 0 + + bfmt = None + + mode = check_buffer(ob) + if mode == 0: + raise TypeError("%r does not provide a buffer interface."%ob) + + if mode == 3: + flags = PyBUF_ANY_CONTIGUOUS + if writable: + flags |= PyBUF_WRITABLE + if format: + flags |= PyBUF_FORMAT + PyObject_GetBuffer(ob, &view, flags) + bptr = view.buf + blen = view.len + if format: + if view.format != NULL: + bfmt = view.format + bitemlen = view.itemsize + PyBuffer_Release(&view) + else: # oldstyle + if writable: + PyObject_AsWriteBuffer(ob, &bptr, &blen) + else: + PyObject_AsReadBuffer(ob, <const_void **>&bptr, &blen) + if format: + try: # numpy.ndarray + dtype = ob.dtype + bfmt = dtype.char + bitemlen = dtype.itemsize + except AttributeError: + try: # array.array + bfmt = ob.typecode + bitemlen = ob.itemsize + except AttributeError: + if isinstance(ob, bytes): + bfmt = b"B" + bitemlen = 1 + else: + # nothing found + bfmt = None + bitemlen = 0 + if base: base[0] = <void *>bptr + if size: size[0] = <Py_ssize_t>blen + if itemsize: itemsize[0] = <Py_ssize_t>bitemlen + + if PY_MAJOR_VERSION >= 3 and bfmt is not None: + return bfmt.decode('ascii') + return bfmt + + +cdef inline object asbuffer_r(object ob, void **base, Py_ssize_t *size): + """Wrapper for standard calls to asbuffer with a readonly buffer.""" + asbuffer(ob, 0, 0, base, size, NULL) + return ob + + +cdef inline object asbuffer_w(object ob, void **base, Py_ssize_t *size): + """Wrapper for standard calls to asbuffer with a writable buffer.""" + asbuffer(ob, 1, 0, base, size, NULL) + return ob + +#------------------------------------------------------------------------------ +# frombuffer: python buffer/view from C buffer +#------------------------------------------------------------------------------ + + +cdef inline object frombuffer_3(void *ptr, Py_ssize_t s, int readonly): + """Python 3 version of frombuffer. + + This is the Python 3 model, but will work on Python >= 2.6. Currently, + we use it only on >= 3.0. + """ + cdef Py_buffer pybuf + cdef Py_ssize_t *shape = [s] + cdef str astr="" + PyBuffer_FillInfo(&pybuf, astr, ptr, s, readonly, PyBUF_SIMPLE) + pybuf.format = "B" + pybuf.shape = shape + return PyMemoryView_FromBuffer(&pybuf) + + +cdef inline object frombuffer_2(void *ptr, Py_ssize_t s, int readonly): + """Python 2 version of frombuffer. + + This must be used for Python <= 2.6, but we use it for all Python < 3. + """ + + if oldstyle_available(): + if readonly: + return PyBuffer_FromMemory(ptr, s) + else: + return PyBuffer_FromReadWriteMemory(ptr, s) + else: + raise NotImplementedError("Old style buffers not available.") + + +cdef inline object frombuffer(void *ptr, Py_ssize_t s, int readonly): + """Create a Python Buffer/View of a C array. + + Parameters + ---------- + ptr : void * + Pointer to the array to be copied. + s : size_t + Length of the buffer. + readonly : int + whether the resulting object should be allowed to write to the buffer. + + Returns + ------- + Python Buffer/View of the C buffer. + """ + # oldstyle first priority for now + if oldstyle_available(): + return frombuffer_2(ptr, s, readonly) + else: + return frombuffer_3(ptr, s, readonly) + + +cdef inline object frombuffer_r(void *ptr, Py_ssize_t s): + """Wrapper for readonly view frombuffer.""" + return frombuffer(ptr, s, 1) + + +cdef inline object frombuffer_w(void *ptr, Py_ssize_t s): + """Wrapper for writable view frombuffer.""" + return frombuffer(ptr, s, 0) + +#------------------------------------------------------------------------------ +# viewfromobject: python buffer/view from python object, refcounts intact +# frombuffer(asbuffer(obj)) would lose track of refs +#------------------------------------------------------------------------------ + +cdef inline object viewfromobject(object obj, int readonly): + """Construct a Python Buffer/View object from another Python object. + + This work in a Python version independent manner. + + Parameters + ---------- + obj : object + The input object to be cast as a buffer + readonly : int + Whether the result should be prevented from overwriting the original. + + Returns + ------- + Buffer/View of the original object. + """ + if not memoryview_available(): + if readonly: + return PyBuffer_FromObject(obj, 0, Py_END_OF_BUFFER) + else: + return PyBuffer_FromReadWriteObject(obj, 0, Py_END_OF_BUFFER) + else: + return PyMemoryView_FromObject(obj) + + +cdef inline object viewfromobject_r(object obj): + """Wrapper for readonly viewfromobject.""" + return viewfromobject(obj, 1) + + +cdef inline object viewfromobject_w(object obj): + """Wrapper for writable viewfromobject.""" + return viewfromobject(obj, 0) diff --git a/src/console/zmq/utils/compiler.json b/src/console/zmq/utils/compiler.json new file mode 100644 index 00000000..e58fc130 --- /dev/null +++ b/src/console/zmq/utils/compiler.json @@ -0,0 +1,24 @@ +{ + "extra_link_args": [], + "define_macros": [ + [ + "HAVE_SYS_UN_H", + 1 + ] + ], + "runtime_library_dirs": [ + "$ORIGIN/.." + ], + "libraries": [ + "zmq" + ], + "library_dirs": [ + "zmq" + ], + "include_dirs": [ + "/auto/srg-sce-swinfra-usr/emb/users/hhaim/work/depot/asr1k/emb/private/bpsim/main/src/zmq/include", + "zmq/utils", + "zmq/backend/cython", + "zmq/devices" + ] +}
\ No newline at end of file diff --git a/src/console/zmq/utils/config.json b/src/console/zmq/utils/config.json new file mode 100644 index 00000000..1e4611f9 --- /dev/null +++ b/src/console/zmq/utils/config.json @@ -0,0 +1,10 @@ +{ + "have_sys_un_h": true, + "zmq_prefix": "/auto/srg-sce-swinfra-usr/emb/users/hhaim/work/depot/asr1k/emb/private/bpsim/main/src/zmq", + "no_libzmq_extension": true, + "libzmq_extension": false, + "easy_install": {}, + "bdist_egg": {}, + "skip_check_zmq": false, + "build_ext": {} +}
\ No newline at end of file diff --git a/src/console/zmq/utils/constant_names.py b/src/console/zmq/utils/constant_names.py new file mode 100755 index 00000000..47da9dc2 --- /dev/null +++ b/src/console/zmq/utils/constant_names.py @@ -0,0 +1,365 @@ +"""0MQ Constant names""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +# dictionaries of constants new or removed in particular versions + +new_in = { + (2,2,0) : [ + 'RCVTIMEO', + 'SNDTIMEO', + ], + (3,2,2) : [ + # errnos + 'EMSGSIZE', + 'EAFNOSUPPORT', + 'ENETUNREACH', + 'ECONNABORTED', + 'ECONNRESET', + 'ENOTCONN', + 'ETIMEDOUT', + 'EHOSTUNREACH', + 'ENETRESET', + + # ctx opts + 'IO_THREADS', + 'MAX_SOCKETS', + 'IO_THREADS_DFLT', + 'MAX_SOCKETS_DFLT', + + # socket opts + 'ROUTER_BEHAVIOR', + 'ROUTER_MANDATORY', + 'FAIL_UNROUTABLE', + 'TCP_KEEPALIVE', + 'TCP_KEEPALIVE_CNT', + 'TCP_KEEPALIVE_IDLE', + 'TCP_KEEPALIVE_INTVL', + 'DELAY_ATTACH_ON_CONNECT', + 'XPUB_VERBOSE', + + # msg opts + 'MORE', + + 'EVENT_CONNECTED', + 'EVENT_CONNECT_DELAYED', + 'EVENT_CONNECT_RETRIED', + 'EVENT_LISTENING', + 'EVENT_BIND_FAILED', + 'EVENT_ACCEPTED', + 'EVENT_ACCEPT_FAILED', + 'EVENT_CLOSED', + 'EVENT_CLOSE_FAILED', + 'EVENT_DISCONNECTED', + 'EVENT_ALL', + ], + (4,0,0) : [ + # socket types + 'STREAM', + + # socket opts + 'IMMEDIATE', + 'ROUTER_RAW', + 'IPV6', + 'MECHANISM', + 'PLAIN_SERVER', + 'PLAIN_USERNAME', + 'PLAIN_PASSWORD', + 'CURVE_SERVER', + 'CURVE_PUBLICKEY', + 'CURVE_SECRETKEY', + 'CURVE_SERVERKEY', + 'PROBE_ROUTER', + 'REQ_RELAXED', + 'REQ_CORRELATE', + 'CONFLATE', + 'ZAP_DOMAIN', + + # security + 'NULL', + 'PLAIN', + 'CURVE', + + # events + 'EVENT_MONITOR_STOPPED', + ], + (4,1,0) : [ + # ctx opts + 'SOCKET_LIMIT', + 'THREAD_PRIORITY', + 'THREAD_PRIORITY_DFLT', + 'THREAD_SCHED_POLICY', + 'THREAD_SCHED_POLICY_DFLT', + + # socket opts + 'ROUTER_HANDOVER', + 'TOS', + 'IPC_FILTER_PID', + 'IPC_FILTER_UID', + 'IPC_FILTER_GID', + 'CONNECT_RID', + 'GSSAPI_SERVER', + 'GSSAPI_PRINCIPAL', + 'GSSAPI_SERVICE_PRINCIPAL', + 'GSSAPI_PLAINTEXT', + 'HANDSHAKE_IVL', + 'IDENTITY_FD', + 'XPUB_NODROP', + 'SOCKS_PROXY', + + # msg opts + 'SRCFD', + 'SHARED', + + # security + 'GSSAPI', + + ], +} + + +removed_in = { + (3,2,2) : [ + 'UPSTREAM', + 'DOWNSTREAM', + + 'HWM', + 'SWAP', + 'MCAST_LOOP', + 'RECOVERY_IVL_MSEC', + ] +} + +# collections of zmq constant names based on their role +# base names have no specific use +# opt names are validated in get/set methods of various objects + +base_names = [ + # base + 'VERSION', + 'VERSION_MAJOR', + 'VERSION_MINOR', + 'VERSION_PATCH', + 'NOBLOCK', + 'DONTWAIT', + + 'POLLIN', + 'POLLOUT', + 'POLLERR', + + 'SNDMORE', + + 'STREAMER', + 'FORWARDER', + 'QUEUE', + + 'IO_THREADS_DFLT', + 'MAX_SOCKETS_DFLT', + 'POLLITEMS_DFLT', + 'THREAD_PRIORITY_DFLT', + 'THREAD_SCHED_POLICY_DFLT', + + # socktypes + 'PAIR', + 'PUB', + 'SUB', + 'REQ', + 'REP', + 'DEALER', + 'ROUTER', + 'XREQ', + 'XREP', + 'PULL', + 'PUSH', + 'XPUB', + 'XSUB', + 'UPSTREAM', + 'DOWNSTREAM', + 'STREAM', + + # events + 'EVENT_CONNECTED', + 'EVENT_CONNECT_DELAYED', + 'EVENT_CONNECT_RETRIED', + 'EVENT_LISTENING', + 'EVENT_BIND_FAILED', + 'EVENT_ACCEPTED', + 'EVENT_ACCEPT_FAILED', + 'EVENT_CLOSED', + 'EVENT_CLOSE_FAILED', + 'EVENT_DISCONNECTED', + 'EVENT_ALL', + 'EVENT_MONITOR_STOPPED', + + # security + 'NULL', + 'PLAIN', + 'CURVE', + 'GSSAPI', + + ## ERRNO + # Often used (these are alse in errno.) + 'EAGAIN', + 'EINVAL', + 'EFAULT', + 'ENOMEM', + 'ENODEV', + 'EMSGSIZE', + 'EAFNOSUPPORT', + 'ENETUNREACH', + 'ECONNABORTED', + 'ECONNRESET', + 'ENOTCONN', + 'ETIMEDOUT', + 'EHOSTUNREACH', + 'ENETRESET', + + # For Windows compatability + 'HAUSNUMERO', + 'ENOTSUP', + 'EPROTONOSUPPORT', + 'ENOBUFS', + 'ENETDOWN', + 'EADDRINUSE', + 'EADDRNOTAVAIL', + 'ECONNREFUSED', + 'EINPROGRESS', + 'ENOTSOCK', + + # 0MQ Native + 'EFSM', + 'ENOCOMPATPROTO', + 'ETERM', + 'EMTHREAD', +] + +int64_sockopt_names = [ + 'AFFINITY', + 'MAXMSGSIZE', + + # sockopts removed in 3.0.0 + 'HWM', + 'SWAP', + 'MCAST_LOOP', + 'RECOVERY_IVL_MSEC', +] + +bytes_sockopt_names = [ + 'IDENTITY', + 'SUBSCRIBE', + 'UNSUBSCRIBE', + 'LAST_ENDPOINT', + 'TCP_ACCEPT_FILTER', + + 'PLAIN_USERNAME', + 'PLAIN_PASSWORD', + + 'CURVE_PUBLICKEY', + 'CURVE_SECRETKEY', + 'CURVE_SERVERKEY', + 'ZAP_DOMAIN', + 'CONNECT_RID', + 'GSSAPI_PRINCIPAL', + 'GSSAPI_SERVICE_PRINCIPAL', + 'SOCKS_PROXY', +] + +fd_sockopt_names = [ + 'FD', + 'IDENTITY_FD', +] + +int_sockopt_names = [ + # sockopts + 'RECONNECT_IVL_MAX', + + # sockopts new in 2.2.0 + 'SNDTIMEO', + 'RCVTIMEO', + + # new in 3.x + 'SNDHWM', + 'RCVHWM', + 'MULTICAST_HOPS', + 'IPV4ONLY', + + 'ROUTER_BEHAVIOR', + 'TCP_KEEPALIVE', + 'TCP_KEEPALIVE_CNT', + 'TCP_KEEPALIVE_IDLE', + 'TCP_KEEPALIVE_INTVL', + 'DELAY_ATTACH_ON_CONNECT', + 'XPUB_VERBOSE', + + 'EVENTS', + 'TYPE', + 'LINGER', + 'RECONNECT_IVL', + 'BACKLOG', + + 'ROUTER_MANDATORY', + 'FAIL_UNROUTABLE', + + 'ROUTER_RAW', + 'IMMEDIATE', + 'IPV6', + 'MECHANISM', + 'PLAIN_SERVER', + 'CURVE_SERVER', + 'PROBE_ROUTER', + 'REQ_RELAXED', + 'REQ_CORRELATE', + 'CONFLATE', + 'ROUTER_HANDOVER', + 'TOS', + 'IPC_FILTER_PID', + 'IPC_FILTER_UID', + 'IPC_FILTER_GID', + 'GSSAPI_SERVER', + 'GSSAPI_PLAINTEXT', + 'HANDSHAKE_IVL', + 'XPUB_NODROP', +] + +switched_sockopt_names = [ + 'RATE', + 'RECOVERY_IVL', + 'SNDBUF', + 'RCVBUF', + 'RCVMORE', +] + +ctx_opt_names = [ + 'IO_THREADS', + 'MAX_SOCKETS', + 'SOCKET_LIMIT', + 'THREAD_PRIORITY', + 'THREAD_SCHED_POLICY', +] + +msg_opt_names = [ + 'MORE', + 'SRCFD', + 'SHARED', +] + +from itertools import chain + +all_names = list(chain( + base_names, + ctx_opt_names, + bytes_sockopt_names, + fd_sockopt_names, + int_sockopt_names, + int64_sockopt_names, + switched_sockopt_names, + msg_opt_names, +)) + +del chain + +def no_prefix(name): + """does the given constant have a ZMQ_ prefix?""" + return name.startswith('E') and not name.startswith('EVENT') + diff --git a/src/console/zmq/utils/garbage.py b/src/console/zmq/utils/garbage.py new file mode 100755 index 00000000..80a8725a --- /dev/null +++ b/src/console/zmq/utils/garbage.py @@ -0,0 +1,180 @@ +"""Garbage collection thread for representing zmq refcount of Python objects +used in zero-copy sends. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import atexit +import struct + +from os import getpid +from collections import namedtuple +from threading import Thread, Event, Lock +import warnings + +import zmq + + +gcref = namedtuple('gcref', ['obj', 'event']) + +class GarbageCollectorThread(Thread): + """Thread in which garbage collection actually happens.""" + def __init__(self, gc): + super(GarbageCollectorThread, self).__init__() + self.gc = gc + self.daemon = True + self.pid = getpid() + self.ready = Event() + + def run(self): + # detect fork at begining of the thread + if getpid is None or getpid() != self.pid: + self.ready.set() + return + try: + s = self.gc.context.socket(zmq.PULL) + s.linger = 0 + s.bind(self.gc.url) + finally: + self.ready.set() + + while True: + # detect fork + if getpid is None or getpid() != self.pid: + return + msg = s.recv() + if msg == b'DIE': + break + fmt = 'L' if len(msg) == 4 else 'Q' + key = struct.unpack(fmt, msg)[0] + tup = self.gc.refs.pop(key, None) + if tup and tup.event: + tup.event.set() + del tup + s.close() + + +class GarbageCollector(object): + """PyZMQ Garbage Collector + + Used for representing the reference held by libzmq during zero-copy sends. + This object holds a dictionary, keyed by Python id, + of the Python objects whose memory are currently in use by zeromq. + + When zeromq is done with the memory, it sends a message on an inproc PUSH socket + containing the packed size_t (32 or 64-bit unsigned int), + which is the key in the dict. + When the PULL socket in the gc thread receives that message, + the reference is popped from the dict, + and any tracker events that should be signaled fire. + """ + + refs = None + _context = None + _lock = None + url = "inproc://pyzmq.gc.01" + + def __init__(self, context=None): + super(GarbageCollector, self).__init__() + self.refs = {} + self.pid = None + self.thread = None + self._context = context + self._lock = Lock() + self._stay_down = False + atexit.register(self._atexit) + + @property + def context(self): + if self._context is None: + self._context = zmq.Context() + return self._context + + @context.setter + def context(self, ctx): + if self.is_alive(): + if self.refs: + warnings.warn("Replacing gc context while gc is running", RuntimeWarning) + self.stop() + self._context = ctx + + def _atexit(self): + """atexit callback + + sets _stay_down flag so that gc doesn't try to start up again in other atexit handlers + """ + self._stay_down = True + self.stop() + + def stop(self): + """stop the garbage-collection thread""" + if not self.is_alive(): + return + self._stop() + + def _stop(self): + push = self.context.socket(zmq.PUSH) + push.connect(self.url) + push.send(b'DIE') + push.close() + self.thread.join() + self.context.term() + self.refs.clear() + self.context = None + + def start(self): + """Start a new garbage collection thread. + + Creates a new zmq Context used for garbage collection. + Under most circumstances, this will only be called once per process. + """ + if self.thread is not None and self.pid != getpid(): + # It's re-starting, must free earlier thread's context + # since a fork probably broke it + self._stop() + self.pid = getpid() + self.refs = {} + self.thread = GarbageCollectorThread(self) + self.thread.start() + self.thread.ready.wait() + + def is_alive(self): + """Is the garbage collection thread currently running? + + Includes checks for process shutdown or fork. + """ + if (getpid is None or + getpid() != self.pid or + self.thread is None or + not self.thread.is_alive() + ): + return False + return True + + def store(self, obj, event=None): + """store an object and (optionally) event for zero-copy""" + if not self.is_alive(): + if self._stay_down: + return 0 + # safely start the gc thread + # use lock and double check, + # so we don't start multiple threads + with self._lock: + if not self.is_alive(): + self.start() + tup = gcref(obj, event) + theid = id(tup) + self.refs[theid] = tup + return theid + + def __del__(self): + if not self.is_alive(): + return + try: + self.stop() + except Exception as e: + raise (e) + +gc = GarbageCollector() diff --git a/src/console/zmq/utils/getpid_compat.h b/src/console/zmq/utils/getpid_compat.h new file mode 100644 index 00000000..47ce90fa --- /dev/null +++ b/src/console/zmq/utils/getpid_compat.h @@ -0,0 +1,6 @@ +#ifdef _WIN32 + #include <process.h> + #define getpid _getpid +#else + #include <unistd.h> +#endif diff --git a/src/console/zmq/utils/interop.py b/src/console/zmq/utils/interop.py new file mode 100755 index 00000000..26c01969 --- /dev/null +++ b/src/console/zmq/utils/interop.py @@ -0,0 +1,33 @@ +"""Utils for interoperability with other libraries. + +Just CFFI pointer casting for now. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +try: + long +except NameError: + long = int # Python 3 + + +def cast_int_addr(n): + """Cast an address to a Python int + + This could be a Python integer or a CFFI pointer + """ + if isinstance(n, (int, long)): + return n + try: + import cffi + except ImportError: + pass + else: + # from pyzmq, this is an FFI void * + ffi = cffi.FFI() + if isinstance(n, ffi.CData): + return int(ffi.cast("size_t", n)) + + raise ValueError("Cannot cast %r to int" % n) diff --git a/src/console/zmq/utils/ipcmaxlen.h b/src/console/zmq/utils/ipcmaxlen.h new file mode 100644 index 00000000..7218db78 --- /dev/null +++ b/src/console/zmq/utils/ipcmaxlen.h @@ -0,0 +1,21 @@ +/* + +Platform-independant detection of IPC path max length + +Copyright (c) 2012 Godefroid Chapelle + +Distributed under the terms of the New BSD License. The full license is in +the file COPYING.BSD, distributed as part of this software. + */ + +#if defined(HAVE_SYS_UN_H) +#include "sys/un.h" +int get_ipc_path_max_len(void) { + struct sockaddr_un *dummy; + return sizeof(dummy->sun_path) - 1; +} +#else +int get_ipc_path_max_len(void) { + return 0; +} +#endif diff --git a/src/console/zmq/utils/jsonapi.py b/src/console/zmq/utils/jsonapi.py new file mode 100755 index 00000000..865ca6d5 --- /dev/null +++ b/src/console/zmq/utils/jsonapi.py @@ -0,0 +1,59 @@ +"""Priority based json library imports. + +Always serializes to bytes instead of unicode for zeromq compatibility +on Python 2 and 3. + +Use ``jsonapi.loads()`` and ``jsonapi.dumps()`` for guaranteed symmetry. + +Priority: ``simplejson`` > ``jsonlib2`` > stdlib ``json`` + +``jsonapi.loads/dumps`` provide kwarg-compatibility with stdlib json. + +``jsonapi.jsonmod`` will be the module of the actual underlying implementation. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.utils.strtypes import bytes, unicode + +jsonmod = None + +priority = ['simplejson', 'jsonlib2', 'json'] +for mod in priority: + try: + jsonmod = __import__(mod) + except ImportError: + pass + else: + break + +def dumps(o, **kwargs): + """Serialize object to JSON bytes (utf-8). + + See jsonapi.jsonmod.dumps for details on kwargs. + """ + + if 'separators' not in kwargs: + kwargs['separators'] = (',', ':') + + s = jsonmod.dumps(o, **kwargs) + + if isinstance(s, unicode): + s = s.encode('utf8') + + return s + +def loads(s, **kwargs): + """Load object from JSON bytes (utf-8). + + See jsonapi.jsonmod.loads for details on kwargs. + """ + + if str is unicode and isinstance(s, bytes): + s = s.decode('utf8') + + return jsonmod.loads(s, **kwargs) + +__all__ = ['jsonmod', 'dumps', 'loads'] + diff --git a/src/console/zmq/utils/monitor.py b/src/console/zmq/utils/monitor.py new file mode 100755 index 00000000..734d54b1 --- /dev/null +++ b/src/console/zmq/utils/monitor.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +"""Module holding utility and convenience functions for zmq event monitoring.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import struct +import zmq +from zmq.error import _check_version + +def parse_monitor_message(msg): + """decode zmq_monitor event messages. + + Parameters + ---------- + msg : list(bytes) + zmq multipart message that has arrived on a monitor PAIR socket. + + First frame is:: + + 16 bit event id + 32 bit event value + no padding + + Second frame is the endpoint as a bytestring + + Returns + ------- + event : dict + event description as dict with the keys `event`, `value`, and `endpoint`. + """ + + if len(msg) != 2 or len(msg[0]) != 6: + raise RuntimeError("Invalid event message format: %s" % msg) + event = {} + event['event'], event['value'] = struct.unpack("=hi", msg[0]) + event['endpoint'] = msg[1] + return event + +def recv_monitor_message(socket, flags=0): + """Receive and decode the given raw message from the monitoring socket and return a dict. + + Requires libzmq ≥ 4.0 + + The returned dict will have the following entries: + event : int, the event id as described in libzmq.zmq_socket_monitor + value : int, the event value associated with the event, see libzmq.zmq_socket_monitor + endpoint : string, the affected endpoint + + Parameters + ---------- + socket : zmq PAIR socket + The PAIR socket (created by other.get_monitor_socket()) on which to recv the message + flags : bitfield (int) + standard zmq recv flags + + Returns + ------- + event : dict + event description as dict with the keys `event`, `value`, and `endpoint`. + """ + _check_version((4,0), 'libzmq event API') + # will always return a list + msg = socket.recv_multipart(flags) + # 4.0-style event API + return parse_monitor_message(msg) + +__all__ = ['parse_monitor_message', 'recv_monitor_message'] diff --git a/src/console/zmq/utils/pyversion_compat.h b/src/console/zmq/utils/pyversion_compat.h new file mode 100644 index 00000000..fac09046 --- /dev/null +++ b/src/console/zmq/utils/pyversion_compat.h @@ -0,0 +1,25 @@ +#include "Python.h" + +#if PY_VERSION_HEX < 0x02070000 + #define PyMemoryView_FromBuffer(info) (PyErr_SetString(PyExc_NotImplementedError, \ + "new buffer interface is not available"), (PyObject *)NULL) + #define PyMemoryView_FromObject(object) (PyErr_SetString(PyExc_NotImplementedError, \ + "new buffer interface is not available"), (PyObject *)NULL) +#endif + +#if PY_VERSION_HEX >= 0x03000000 + // for buffers + #define Py_END_OF_BUFFER ((Py_ssize_t) 0) + + #define PyObject_CheckReadBuffer(object) (0) + + #define PyBuffer_FromMemory(ptr, s) (PyErr_SetString(PyExc_NotImplementedError, \ + "old buffer interface is not available"), (PyObject *)NULL) + #define PyBuffer_FromReadWriteMemory(ptr, s) (PyErr_SetString(PyExc_NotImplementedError, \ + "old buffer interface is not available"), (PyObject *)NULL) + #define PyBuffer_FromObject(object, offset, size) (PyErr_SetString(PyExc_NotImplementedError, \ + "old buffer interface is not available"), (PyObject *)NULL) + #define PyBuffer_FromReadWriteObject(object, offset, size) (PyErr_SetString(PyExc_NotImplementedError, \ + "old buffer interface is not available"), (PyObject *)NULL) + +#endif diff --git a/src/console/zmq/utils/sixcerpt.py b/src/console/zmq/utils/sixcerpt.py new file mode 100755 index 00000000..5492fd59 --- /dev/null +++ b/src/console/zmq/utils/sixcerpt.py @@ -0,0 +1,52 @@ +"""Excerpts of six.py""" + +# Copyright (C) 2010-2014 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import sys + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 + +if PY3: + + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + +else: + def exec_(_code_, _globs_=None, _locs_=None): + """Execute code in a namespace.""" + if _globs_ is None: + frame = sys._getframe(1) + _globs_ = frame.f_globals + if _locs_ is None: + _locs_ = frame.f_locals + del frame + elif _locs_ is None: + _locs_ = _globs_ + exec("""exec _code_ in _globs_, _locs_""") + + + exec_("""def reraise(tp, value, tb=None): + raise tp, value, tb +""") diff --git a/src/console/zmq/utils/strtypes.py b/src/console/zmq/utils/strtypes.py new file mode 100755 index 00000000..548410dc --- /dev/null +++ b/src/console/zmq/utils/strtypes.py @@ -0,0 +1,45 @@ +"""Declare basic string types unambiguously for various Python versions. + +Authors +------- +* MinRK +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys + +if sys.version_info[0] >= 3: + bytes = bytes + unicode = str + basestring = (bytes, unicode) +else: + unicode = unicode + bytes = str + basestring = basestring + +def cast_bytes(s, encoding='utf8', errors='strict'): + """cast unicode or bytes to bytes""" + if isinstance(s, bytes): + return s + elif isinstance(s, unicode): + return s.encode(encoding, errors) + else: + raise TypeError("Expected unicode or bytes, got %r" % s) + +def cast_unicode(s, encoding='utf8', errors='strict'): + """cast bytes or unicode to unicode""" + if isinstance(s, bytes): + return s.decode(encoding, errors) + elif isinstance(s, unicode): + return s + else: + raise TypeError("Expected unicode or bytes, got %r" % s) + +# give short 'b' alias for cast_bytes, so that we can use fake b('stuff') +# to simulate b'stuff' +b = asbytes = cast_bytes +u = cast_unicode + +__all__ = ['asbytes', 'bytes', 'unicode', 'basestring', 'b', 'u', 'cast_bytes', 'cast_unicode'] diff --git a/src/console/zmq/utils/win32.py b/src/console/zmq/utils/win32.py new file mode 100755 index 00000000..ea758299 --- /dev/null +++ b/src/console/zmq/utils/win32.py @@ -0,0 +1,132 @@ +"""Win32 compatibility utilities.""" + +#----------------------------------------------------------------------------- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +#----------------------------------------------------------------------------- + +import os + +# No-op implementation for other platforms. +class _allow_interrupt(object): + """Utility for fixing CTRL-C events on Windows. + + On Windows, the Python interpreter intercepts CTRL-C events in order to + translate them into ``KeyboardInterrupt`` exceptions. It (presumably) + does this by setting a flag in its "control control handler" and + checking it later at a convenient location in the interpreter. + + However, when the Python interpreter is blocked waiting for the ZMQ + poll operation to complete, it must wait for ZMQ's ``select()`` + operation to complete before translating the CTRL-C event into the + ``KeyboardInterrupt`` exception. + + The only way to fix this seems to be to add our own "console control + handler" and perform some application-defined operation that will + unblock the ZMQ polling operation in order to force ZMQ to pass control + back to the Python interpreter. + + This context manager performs all that Windows-y stuff, providing you + with a hook that is called when a CTRL-C event is intercepted. This + hook allows you to unblock your ZMQ poll operation immediately, which + will then result in the expected ``KeyboardInterrupt`` exception. + + Without this context manager, your ZMQ-based application will not + respond normally to CTRL-C events on Windows. If a CTRL-C event occurs + while blocked on ZMQ socket polling, the translation to a + ``KeyboardInterrupt`` exception will be delayed until the I/O completes + and control returns to the Python interpreter (this may never happen if + you use an infinite timeout). + + A no-op implementation is provided on non-Win32 systems to avoid the + application from having to conditionally use it. + + Example usage: + + .. sourcecode:: python + + def stop_my_application(): + # ... + + with allow_interrupt(stop_my_application): + # main polling loop. + + In a typical ZMQ application, you would use the "self pipe trick" to + send message to a ``PAIR`` socket in order to interrupt your blocking + socket polling operation. + + In a Tornado event loop, you can use the ``IOLoop.stop`` method to + unblock your I/O loop. + """ + + def __init__(self, action=None): + """Translate ``action`` into a CTRL-C handler. + + ``action`` is a callable that takes no arguments and returns no + value (returned value is ignored). It must *NEVER* raise an + exception. + + If unspecified, a no-op will be used. + """ + self._init_action(action) + + def _init_action(self, action): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + return + +if os.name == 'nt': + from ctypes import WINFUNCTYPE, windll + from ctypes.wintypes import BOOL, DWORD + + kernel32 = windll.LoadLibrary('kernel32') + + # <http://msdn.microsoft.com/en-us/library/ms686016.aspx> + PHANDLER_ROUTINE = WINFUNCTYPE(BOOL, DWORD) + SetConsoleCtrlHandler = kernel32.SetConsoleCtrlHandler + SetConsoleCtrlHandler.argtypes = (PHANDLER_ROUTINE, BOOL) + SetConsoleCtrlHandler.restype = BOOL + + class allow_interrupt(_allow_interrupt): + __doc__ = _allow_interrupt.__doc__ + + def _init_action(self, action): + if action is None: + action = lambda: None + self.action = action + @PHANDLER_ROUTINE + def handle(event): + if event == 0: # CTRL_C_EVENT + action() + # Typical C implementations would return 1 to indicate that + # the event was processed and other control handlers in the + # stack should not be executed. However, that would + # prevent the Python interpreter's handler from translating + # CTRL-C to a `KeyboardInterrupt` exception, so we pretend + # that we didn't handle it. + return 0 + self.handle = handle + + def __enter__(self): + """Install the custom CTRL-C handler.""" + result = SetConsoleCtrlHandler(self.handle, 1) + if result == 0: + # Have standard library automatically call `GetLastError()` and + # `FormatMessage()` into a nice exception object :-) + raise WindowsError() + + def __exit__(self, *args): + """Remove the custom CTRL-C handler.""" + result = SetConsoleCtrlHandler(self.handle, 0) + if result == 0: + # Have standard library automatically call `GetLastError()` and + # `FormatMessage()` into a nice exception object :-) + raise WindowsError() +else: + class allow_interrupt(_allow_interrupt): + __doc__ = _allow_interrupt.__doc__ + pass diff --git a/src/console/zmq/utils/z85.py b/src/console/zmq/utils/z85.py new file mode 100755 index 00000000..1bb1784e --- /dev/null +++ b/src/console/zmq/utils/z85.py @@ -0,0 +1,56 @@ +"""Python implementation of Z85 85-bit encoding + +Z85 encoding is a plaintext encoding for a bytestring interpreted as 32bit integers. +Since the chunks are 32bit, a bytestring must be a multiple of 4 bytes. +See ZMQ RFC 32 for details. + + +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys +import struct + +PY3 = sys.version_info[0] >= 3 +# Z85CHARS is the base 85 symbol table +Z85CHARS = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.-:+=^!/*?&<>()[]{}@%$#" +# Z85MAP maps integers in [0,84] to the appropriate character in Z85CHARS +Z85MAP = dict([(c, idx) for idx, c in enumerate(Z85CHARS)]) + +_85s = [ 85**i for i in range(5) ][::-1] + +def encode(rawbytes): + """encode raw bytes into Z85""" + # Accepts only byte arrays bounded to 4 bytes + if len(rawbytes) % 4: + raise ValueError("length must be multiple of 4, not %i" % len(rawbytes)) + + nvalues = len(rawbytes) / 4 + + values = struct.unpack('>%dI' % nvalues, rawbytes) + encoded = [] + for v in values: + for offset in _85s: + encoded.append(Z85CHARS[(v // offset) % 85]) + + # In Python 3, encoded is a list of integers (obviously?!) + if PY3: + return bytes(encoded) + else: + return b''.join(encoded) + +def decode(z85bytes): + """decode Z85 bytes to raw bytes""" + if len(z85bytes) % 5: + raise ValueError("Z85 length must be multiple of 5, not %i" % len(z85bytes)) + + nvalues = len(z85bytes) / 5 + values = [] + for i in range(0, len(z85bytes), 5): + value = 0 + for j, offset in enumerate(_85s): + value += Z85MAP[z85bytes[i+j]] * offset + values.append(value) + return struct.pack('>%dI' % nvalues, *values) diff --git a/src/console/zmq/utils/zmq_compat.h b/src/console/zmq/utils/zmq_compat.h new file mode 100644 index 00000000..81c57b69 --- /dev/null +++ b/src/console/zmq/utils/zmq_compat.h @@ -0,0 +1,80 @@ +//----------------------------------------------------------------------------- +// Copyright (c) 2010 Brian Granger, Min Ragan-Kelley +// +// Distributed under the terms of the New BSD License. The full license is in +// the file COPYING.BSD, distributed as part of this software. +//----------------------------------------------------------------------------- + +#if defined(_MSC_VER) +#define pyzmq_int64_t __int64 +#else +#include <stdint.h> +#define pyzmq_int64_t int64_t +#endif + + +#include "zmq.h" +// version compatibility for constants: +#include "zmq_constants.h" + +#define _missing (-1) + + +// define fd type (from libzmq's fd.hpp) +#ifdef _WIN32 + #ifdef _MSC_VER && _MSC_VER <= 1400 + #define ZMQ_FD_T UINT_PTR + #else + #define ZMQ_FD_T SOCKET + #endif +#else + #define ZMQ_FD_T int +#endif + +// use unambiguous aliases for zmq_send/recv functions + +#if ZMQ_VERSION_MAJOR >= 4 +// nothing to remove +#else + #define zmq_curve_keypair(z85_public_key, z85_secret_key) _missing +#endif + +#if ZMQ_VERSION_MAJOR >= 4 && ZMQ_VERSION_MINOR >= 1 +// nothing to remove +#else + #define zmq_msg_gets(msg, prop) _missing + #define zmq_has(capability) _missing +#endif + +#if ZMQ_VERSION_MAJOR >= 3 + #define zmq_sendbuf zmq_send + #define zmq_recvbuf zmq_recv + + // 3.x deprecations - these symbols haven't been removed, + // but let's protect against their planned removal + #define zmq_device(device_type, isocket, osocket) _missing + #define zmq_init(io_threads) ((void*)NULL) + #define zmq_term zmq_ctx_destroy +#else + #define zmq_ctx_set(ctx, opt, val) _missing + #define zmq_ctx_get(ctx, opt) _missing + #define zmq_ctx_destroy zmq_term + #define zmq_ctx_new() ((void*)NULL) + + #define zmq_proxy(a,b,c) _missing + + #define zmq_disconnect(s, addr) _missing + #define zmq_unbind(s, addr) _missing + + #define zmq_msg_more(msg) _missing + #define zmq_msg_get(msg, opt) _missing + #define zmq_msg_set(msg, opt, val) _missing + #define zmq_msg_send(msg, s, flags) zmq_send(s, msg, flags) + #define zmq_msg_recv(msg, s, flags) zmq_recv(s, msg, flags) + + #define zmq_sendbuf(s, buf, len, flags) _missing + #define zmq_recvbuf(s, buf, len, flags) _missing + + #define zmq_socket_monitor(s, addr, flags) _missing + +#endif diff --git a/src/console/zmq/utils/zmq_constants.h b/src/console/zmq/utils/zmq_constants.h new file mode 100644 index 00000000..97683022 --- /dev/null +++ b/src/console/zmq/utils/zmq_constants.h @@ -0,0 +1,622 @@ +#ifndef _PYZMQ_CONSTANT_DEFS +#define _PYZMQ_CONSTANT_DEFS + +#define _PYZMQ_UNDEFINED (-9999) +#ifndef ZMQ_VERSION + #define ZMQ_VERSION (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_VERSION_MAJOR + #define ZMQ_VERSION_MAJOR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_VERSION_MINOR + #define ZMQ_VERSION_MINOR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_VERSION_PATCH + #define ZMQ_VERSION_PATCH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_NOBLOCK + #define ZMQ_NOBLOCK (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DONTWAIT + #define ZMQ_DONTWAIT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLIN + #define ZMQ_POLLIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLOUT + #define ZMQ_POLLOUT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLERR + #define ZMQ_POLLERR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDMORE + #define ZMQ_SNDMORE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_STREAMER + #define ZMQ_STREAMER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_FORWARDER + #define ZMQ_FORWARDER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_QUEUE + #define ZMQ_QUEUE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IO_THREADS_DFLT + #define ZMQ_IO_THREADS_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MAX_SOCKETS_DFLT + #define ZMQ_MAX_SOCKETS_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLITEMS_DFLT + #define ZMQ_POLLITEMS_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_PRIORITY_DFLT + #define ZMQ_THREAD_PRIORITY_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_SCHED_POLICY_DFLT + #define ZMQ_THREAD_SCHED_POLICY_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PAIR + #define ZMQ_PAIR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PUB + #define ZMQ_PUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SUB + #define ZMQ_SUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REQ + #define ZMQ_REQ (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REP + #define ZMQ_REP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DEALER + #define ZMQ_DEALER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER + #define ZMQ_ROUTER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XREQ + #define ZMQ_XREQ (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XREP + #define ZMQ_XREP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PULL + #define ZMQ_PULL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PUSH + #define ZMQ_PUSH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XPUB + #define ZMQ_XPUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XSUB + #define ZMQ_XSUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_UPSTREAM + #define ZMQ_UPSTREAM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DOWNSTREAM + #define ZMQ_DOWNSTREAM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_STREAM + #define ZMQ_STREAM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CONNECTED + #define ZMQ_EVENT_CONNECTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CONNECT_DELAYED + #define ZMQ_EVENT_CONNECT_DELAYED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CONNECT_RETRIED + #define ZMQ_EVENT_CONNECT_RETRIED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_LISTENING + #define ZMQ_EVENT_LISTENING (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_BIND_FAILED + #define ZMQ_EVENT_BIND_FAILED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_ACCEPTED + #define ZMQ_EVENT_ACCEPTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_ACCEPT_FAILED + #define ZMQ_EVENT_ACCEPT_FAILED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CLOSED + #define ZMQ_EVENT_CLOSED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CLOSE_FAILED + #define ZMQ_EVENT_CLOSE_FAILED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_DISCONNECTED + #define ZMQ_EVENT_DISCONNECTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_ALL + #define ZMQ_EVENT_ALL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_MONITOR_STOPPED + #define ZMQ_EVENT_MONITOR_STOPPED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_NULL + #define ZMQ_NULL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN + #define ZMQ_PLAIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE + #define ZMQ_CURVE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI + #define ZMQ_GSSAPI (_PYZMQ_UNDEFINED) +#endif + +#ifndef EAGAIN + #define EAGAIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef EINVAL + #define EINVAL (_PYZMQ_UNDEFINED) +#endif + +#ifndef EFAULT + #define EFAULT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOMEM + #define ENOMEM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENODEV + #define ENODEV (_PYZMQ_UNDEFINED) +#endif + +#ifndef EMSGSIZE + #define EMSGSIZE (_PYZMQ_UNDEFINED) +#endif + +#ifndef EAFNOSUPPORT + #define EAFNOSUPPORT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENETUNREACH + #define ENETUNREACH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ECONNABORTED + #define ECONNABORTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ECONNRESET + #define ECONNRESET (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOTCONN + #define ENOTCONN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ETIMEDOUT + #define ETIMEDOUT (_PYZMQ_UNDEFINED) +#endif + +#ifndef EHOSTUNREACH + #define EHOSTUNREACH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENETRESET + #define ENETRESET (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_HAUSNUMERO + #define ZMQ_HAUSNUMERO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOTSUP + #define ENOTSUP (_PYZMQ_UNDEFINED) +#endif + +#ifndef EPROTONOSUPPORT + #define EPROTONOSUPPORT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOBUFS + #define ENOBUFS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENETDOWN + #define ENETDOWN (_PYZMQ_UNDEFINED) +#endif + +#ifndef EADDRINUSE + #define EADDRINUSE (_PYZMQ_UNDEFINED) +#endif + +#ifndef EADDRNOTAVAIL + #define EADDRNOTAVAIL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ECONNREFUSED + #define ECONNREFUSED (_PYZMQ_UNDEFINED) +#endif + +#ifndef EINPROGRESS + #define EINPROGRESS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOTSOCK + #define ENOTSOCK (_PYZMQ_UNDEFINED) +#endif + +#ifndef EFSM + #define EFSM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOCOMPATPROTO + #define ENOCOMPATPROTO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ETERM + #define ETERM (_PYZMQ_UNDEFINED) +#endif + +#ifndef EMTHREAD + #define EMTHREAD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IO_THREADS + #define ZMQ_IO_THREADS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MAX_SOCKETS + #define ZMQ_MAX_SOCKETS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SOCKET_LIMIT + #define ZMQ_SOCKET_LIMIT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_PRIORITY + #define ZMQ_THREAD_PRIORITY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_SCHED_POLICY + #define ZMQ_THREAD_SCHED_POLICY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IDENTITY + #define ZMQ_IDENTITY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SUBSCRIBE + #define ZMQ_SUBSCRIBE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_UNSUBSCRIBE + #define ZMQ_UNSUBSCRIBE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_LAST_ENDPOINT + #define ZMQ_LAST_ENDPOINT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_ACCEPT_FILTER + #define ZMQ_TCP_ACCEPT_FILTER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN_USERNAME + #define ZMQ_PLAIN_USERNAME (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN_PASSWORD + #define ZMQ_PLAIN_PASSWORD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_PUBLICKEY + #define ZMQ_CURVE_PUBLICKEY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_SECRETKEY + #define ZMQ_CURVE_SECRETKEY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_SERVERKEY + #define ZMQ_CURVE_SERVERKEY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ZAP_DOMAIN + #define ZMQ_ZAP_DOMAIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CONNECT_RID + #define ZMQ_CONNECT_RID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_PRINCIPAL + #define ZMQ_GSSAPI_PRINCIPAL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_SERVICE_PRINCIPAL + #define ZMQ_GSSAPI_SERVICE_PRINCIPAL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SOCKS_PROXY + #define ZMQ_SOCKS_PROXY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_FD + #define ZMQ_FD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IDENTITY_FD + #define ZMQ_IDENTITY_FD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECONNECT_IVL_MAX + #define ZMQ_RECONNECT_IVL_MAX (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDTIMEO + #define ZMQ_SNDTIMEO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVTIMEO + #define ZMQ_RCVTIMEO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDHWM + #define ZMQ_SNDHWM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVHWM + #define ZMQ_RCVHWM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MULTICAST_HOPS + #define ZMQ_MULTICAST_HOPS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPV4ONLY + #define ZMQ_IPV4ONLY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_BEHAVIOR + #define ZMQ_ROUTER_BEHAVIOR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE + #define ZMQ_TCP_KEEPALIVE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE_CNT + #define ZMQ_TCP_KEEPALIVE_CNT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE_IDLE + #define ZMQ_TCP_KEEPALIVE_IDLE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE_INTVL + #define ZMQ_TCP_KEEPALIVE_INTVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DELAY_ATTACH_ON_CONNECT + #define ZMQ_DELAY_ATTACH_ON_CONNECT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XPUB_VERBOSE + #define ZMQ_XPUB_VERBOSE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENTS + #define ZMQ_EVENTS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TYPE + #define ZMQ_TYPE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_LINGER + #define ZMQ_LINGER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECONNECT_IVL + #define ZMQ_RECONNECT_IVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_BACKLOG + #define ZMQ_BACKLOG (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_MANDATORY + #define ZMQ_ROUTER_MANDATORY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_FAIL_UNROUTABLE + #define ZMQ_FAIL_UNROUTABLE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_RAW + #define ZMQ_ROUTER_RAW (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IMMEDIATE + #define ZMQ_IMMEDIATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPV6 + #define ZMQ_IPV6 (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MECHANISM + #define ZMQ_MECHANISM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN_SERVER + #define ZMQ_PLAIN_SERVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_SERVER + #define ZMQ_CURVE_SERVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PROBE_ROUTER + #define ZMQ_PROBE_ROUTER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REQ_RELAXED + #define ZMQ_REQ_RELAXED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REQ_CORRELATE + #define ZMQ_REQ_CORRELATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CONFLATE + #define ZMQ_CONFLATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_HANDOVER + #define ZMQ_ROUTER_HANDOVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TOS + #define ZMQ_TOS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPC_FILTER_PID + #define ZMQ_IPC_FILTER_PID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPC_FILTER_UID + #define ZMQ_IPC_FILTER_UID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPC_FILTER_GID + #define ZMQ_IPC_FILTER_GID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_SERVER + #define ZMQ_GSSAPI_SERVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_PLAINTEXT + #define ZMQ_GSSAPI_PLAINTEXT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_HANDSHAKE_IVL + #define ZMQ_HANDSHAKE_IVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XPUB_NODROP + #define ZMQ_XPUB_NODROP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_AFFINITY + #define ZMQ_AFFINITY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MAXMSGSIZE + #define ZMQ_MAXMSGSIZE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_HWM + #define ZMQ_HWM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SWAP + #define ZMQ_SWAP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MCAST_LOOP + #define ZMQ_MCAST_LOOP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECOVERY_IVL_MSEC + #define ZMQ_RECOVERY_IVL_MSEC (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RATE + #define ZMQ_RATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECOVERY_IVL + #define ZMQ_RECOVERY_IVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDBUF + #define ZMQ_SNDBUF (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVBUF + #define ZMQ_RCVBUF (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVMORE + #define ZMQ_RCVMORE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MORE + #define ZMQ_MORE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SRCFD + #define ZMQ_SRCFD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SHARED + #define ZMQ_SHARED (_PYZMQ_UNDEFINED) +#endif + + +#endif // ifndef _PYZMQ_CONSTANT_DEFS diff --git a/src/gtest/rpc_test.cpp b/src/gtest/rpc_test.cpp new file mode 100644 index 00000000..068457f3 --- /dev/null +++ b/src/gtest/rpc_test.cpp @@ -0,0 +1,226 @@ +/* + Itay Marom + + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include <common/gtest.h> +#include <trex_rpc_server_api.h> +#include <zmq.h> +#include <json/json.h> +#include <sstream> + +using namespace std; + +class RpcTest : public testing::Test { + + virtual void SetUp() { + TrexRpcServerConfig cfg = TrexRpcServerConfig(TrexRpcServerConfig::RPC_PROT_TCP, 5050); + + m_rpc = new TrexRpcServer(cfg); + m_rpc->start(); + + m_context = zmq_ctx_new (); + m_socket = zmq_socket (m_context, ZMQ_REQ); + zmq_connect (m_socket, "tcp://localhost:5050"); + } + + virtual void TearDown() { + m_rpc->stop(); + + delete m_rpc; + zmq_close(m_socket); + zmq_term(m_context); + } + +public: + string send_msg(const string &msg) { + char buffer[512]; + + zmq_send (m_socket, msg.c_str(), msg.size(), 0); + int len = zmq_recv(m_socket, buffer, sizeof(buffer), 0); + + return string(buffer, len); + } + + TrexRpcServer *m_rpc; + void *m_context; + void *m_socket; +}; + +TEST_F(RpcTest, basic_rpc_test) { + Json::Value request; + Json::Value response; + Json::Reader reader; + + string req_str; + string resp_str; + + // check bad JSON format + req_str = "bad format message"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_EQ(response["jsonrpc"], "2.0"); + EXPECT_EQ(response["id"], Json::Value::null); + EXPECT_EQ(response["error"]["code"], -32700); + + // check bad version + req_str = "{\"jsonrpc\": \"1.5\", \"method\": \"foobar\", \"id\": \"1\"}"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_EQ(response["jsonrpc"], "2.0"); + EXPECT_EQ(response["id"], "1"); + EXPECT_EQ(response["error"]["code"], -32600); + + // no method name present + req_str = "{\"jsonrpc\": \"2.0\", \"id\": 482}"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_EQ(response["jsonrpc"], "2.0"); + EXPECT_EQ(response["id"], 482); + EXPECT_EQ(response["error"]["code"], -32600); + + /* method does not exist */ + req_str = "{\"jsonrpc\": \"2.0\", \"method\": \"jfgldjlfds\", \"id\": 482}"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_EQ(response["jsonrpc"], "2.0"); + EXPECT_EQ(response["id"], 482); + EXPECT_EQ(response["error"]["code"], -32601); + + /* error but as notification */ + req_str = "{\"jsonrpc\": \"2.0\", \"method\": \"jfgldjlfds\"}"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_TRUE(response == Json::Value::null); + + +} + +TEST_F(RpcTest, test_add_command) { + Json::Value request; + Json::Value response; + Json::Reader reader; + + string req_str; + string resp_str; + + /* simple add - missing paramters */ + req_str = "{\"jsonrpc\": \"2.0\", \"method\": \"rpc_test_add\", \"id\": 488}"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_EQ(response["jsonrpc"], "2.0"); + EXPECT_EQ(response["id"], 488); + EXPECT_EQ(response["error"]["code"], -32602); + + /* simple add that works */ + req_str = "{\"jsonrpc\": \"2.0\", \"method\": \"rpc_test_add\", \"params\": {\"x\": 17, \"y\": -13} , \"id\": \"itay\"}"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_EQ(response["jsonrpc"], "2.0"); + EXPECT_EQ(response["id"], "itay"); + EXPECT_EQ(response["result"], 4); + + /* add with bad paratemers types */ + req_str = "{\"jsonrpc\": \"2.0\", \"method\": \"rpc_test_add\", \"params\": {\"x\": \"blah\", \"y\": -13} , \"id\": 17}"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_EQ(response["jsonrpc"], "2.0"); + EXPECT_EQ(response["id"], 17); + EXPECT_EQ(response["error"]["code"], -32602); + + /* add with invalid count of parameters */ + req_str = "{\"jsonrpc\": \"2.0\", \"method\": \"rpc_test_add\", \"params\": {\"y\": -13} , \"id\": 17}"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_EQ(response["jsonrpc"], "2.0"); + EXPECT_EQ(response["id"], 17); + EXPECT_EQ(response["error"]["code"], -32602); + + + /* big numbers */ + req_str = "{\"jsonrpc\": \"2.0\", \"method\": \"rpc_test_add\", \"params\": {\"x\": 4827371, \"y\": -39181273} , \"id\": \"itay\"}"; + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_EQ(response["jsonrpc"], "2.0"); + EXPECT_EQ(response["id"], "itay"); + EXPECT_EQ(response["result"], -34353902); + +} + +TEST_F(RpcTest, batch_rpc_test) { + Json::Value request; + Json::Value response; + Json::Reader reader; + + string req_str; + string resp_str; + + req_str = "[ \ + {\"jsonrpc\": \"2.0\", \"method\": \"rpc_test_add\", \"params\": {\"x\": 22, \"y\": 17}, \"id\": \"1\"}, \ + {\"jsonrpc\": \"2.0\", \"method\": \"rpc_test_sub\", \"params\": {\"x\": 22, \"y\": 17}, \"id\": \"2\"}, \ + {\"jsonrpc\": \"2.0\", \"method\": \"rpc_test_add\", \"params\": {\"x\": 22, \"y\": \"itay\"}, \"id\": \"2\"}, \ + {\"foo\": \"boo\"}, \ + {\"jsonrpc\": \"2.0\", \"method\": \"test_rpc_sheker\", \"params\": {\"name\": \"myself\"}, \"id\": 5}, \ + {\"jsonrpc\": \"2.0\", \"method\": \"rpc_test_add\", \"params\": {\"x\": 22, \"y\": 17} } \ + ]"; + + resp_str = send_msg(req_str); + + EXPECT_TRUE(reader.parse(resp_str, response, false)); + EXPECT_TRUE(response.isArray()); + + // message 1 + EXPECT_TRUE(response[0]["jsonrpc"] == "2.0"); + EXPECT_TRUE(response[0]["id"] == "1"); + EXPECT_TRUE(response[0]["result"] == 39); + + // message 2 + EXPECT_TRUE(response[1]["jsonrpc"] == "2.0"); + EXPECT_TRUE(response[1]["id"] == "2"); + EXPECT_TRUE(response[1]["result"] == 5); + + // message 3 + EXPECT_TRUE(response[2]["jsonrpc"] == "2.0"); + EXPECT_TRUE(response[2]["id"] == "2"); + EXPECT_TRUE(response[2]["error"]["code"] == -32602); + + // message 4 + EXPECT_TRUE(response[3] == Json::Value::null); + + // message 5 + EXPECT_TRUE(response[4]["jsonrpc"] == "2.0"); + EXPECT_TRUE(response[4]["id"] == 5); + EXPECT_TRUE(response[4]["error"]["code"] == -32601); + + // message 6 - no ID but a valid command + EXPECT_TRUE(response[5] == Json::Value::null); + + return; +} diff --git a/src/main_dpdk.cpp b/src/main_dpdk.cpp index a0258ef1..a748178d 100755 --- a/src/main_dpdk.cpp +++ b/src/main_dpdk.cpp @@ -71,8 +71,6 @@ extern "C" { #include "utl_term_io.h" #include "msg_manager.h" #include "platform_cfg.h" - -#define VERSION "1.73" #define RX_CHECK_MIX_SAMPLE_RATE 8 @@ -4394,7 +4392,7 @@ int main_test(int argc , char * argv[]){ int ret; unsigned lcore_id; - printf("Starting TRex %s please wait ... \n",VERSION); + printf("Starting TRex %s please wait ... \n",VERSION_BUILD_NUM); CGlobalInfo::m_options.preview.clean(); @@ -4575,7 +4573,6 @@ void CTRexExtendedDriverBase1G::update_global_config_fdir(port_cfg_t * cfg){ cfg->update_global_config_fdir_10g_1g(); } - int CTRexExtendedDriverBase1G::configure_rx_filter_rules(CPhyEthIF * _if){ uint16_t hops = get_rx_check_hops(); @@ -4587,15 +4584,15 @@ int CTRexExtendedDriverBase1G::configure_rx_filter_rules(CPhyEthIF * _if){ int i; // IPv4: bytes being compared are {TTL, Protocol} uint16_t ff_rules_v4[4]={ - 0xFF06 - v4_hops, - 0xFE11 - v4_hops, - 0xFF11 - v4_hops, - 0xFE06 - v4_hops, + (uint16_t)(0xFF06 - v4_hops), + (uint16_t)(0xFE11 - v4_hops), + (uint16_t)(0xFF11 - v4_hops), + (uint16_t)(0xFE06 - v4_hops), } ; // IPv6: bytes being compared are {NextHdr, HopLimit} uint16_t ff_rules_v6[2]={ - 0x3CFF - hops, - 0x3CFE - hops, + (uint16_t)(0x3CFF - hops), + (uint16_t)(0x3CFE - hops), } ; uint16_t *ff_rules; uint16_t num_rules; @@ -4733,17 +4730,17 @@ int CTRexExtendedDriverBase10G::configure_rx_filter_rules(CPhyEthIF * _if){ // IPv4: bytes being compared are {TTL, Protocol} uint16_t ff_rules_v4[4]={ - 0xFF11 - v4_hops, - 0xFE11 - v4_hops, - 0xFF06 - v4_hops, - 0xFE06 - v4_hops, + (uint16_t)(0xFF11 - v4_hops), + (uint16_t)(0xFE11 - v4_hops), + (uint16_t)(0xFF06 - v4_hops), + (uint16_t)(0xFE06 - v4_hops), } ; // IPv6: bytes being compared are {NextHdr, HopLimit} uint16_t ff_rules_v6[4]={ - 0x3CFF - hops, - 0x3CFE - hops, - 0x3CFF - hops, - 0x3CFE - hops, + (uint16_t)(0x3CFF - hops), + (uint16_t)(0x3CFE - hops), + (uint16_t)(0x3CFF - hops), + (uint16_t)(0x3CFE - hops), } ; const rte_l4type ff_rules_type[4]={ RTE_FDIR_L4TYPE_UDP, diff --git a/src/pal/linux_dpdk/dpdk180/rte_config.h b/src/pal/linux_dpdk/dpdk180/rte_config.h index 68dd7a7b..0603ed06 100755 --- a/src/pal/linux_dpdk/dpdk180/rte_config.h +++ b/src/pal/linux_dpdk/dpdk180/rte_config.h @@ -1,5 +1,8 @@ #ifndef __RTE_CONFIG_H #define __RTE_CONFIG_H + +#define typeof __typeof__ + #undef RTE_EXEC_ENV #define RTE_EXEC_ENV "linuxapp" #undef RTE_EXEC_ENV_LINUXAPP diff --git a/src/rpc-server/include/trex_rpc_cmd_api.h b/src/rpc-server/include/trex_rpc_cmd_api.h new file mode 100644 index 00000000..c773b15f --- /dev/null +++ b/src/rpc-server/include/trex_rpc_cmd_api.h @@ -0,0 +1,90 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef __TREX_RPC_CMD_API_H__ +#define __TREX_RPC_CMD_API_H__ + +#include <string> +#include <vector> +#include <json/json.h> + +/** + * interface for RPC command + * + * @author imarom (13-Aug-15) + */ +class TrexRpcCommand { +public: + + /** + * describe different types of rc for run() + */ + enum rpc_cmd_rc_e { + RPC_CMD_OK, + RPC_CMD_PARAM_COUNT_ERR = 1, + RPC_CMD_PARAM_PARSE_ERR, + RPC_CMD_INTERNAL_ERR + }; + + /** + * method name and params + */ + TrexRpcCommand(const std::string &method_name) : m_name(method_name) { + + } + + rpc_cmd_rc_e run(const Json::Value ¶ms, Json::Value &result) { + return _run(params, result); + } + + const std::string &get_name() { + return m_name; + } + + virtual ~TrexRpcCommand() {} + +protected: + + /** + * implemented by the dervied class + * + */ + virtual rpc_cmd_rc_e _run(const Json::Value ¶ms, Json::Value &result) = 0; + + /** + * error generating functions + * + */ + void genernate_err(Json::Value &result, const std::string &msg) { + result["specific_err"] = msg; + } + + void generate_err_param_count(Json::Value &result, int expected, int provided) { + std::stringstream ss; + ss << "method expects '" << expected << "' paramteres, '" << provided << "' provided"; + genernate_err(result, ss.str()); + } + + std::string m_name; +}; + +#endif /* __TREX_RPC_CMD_API_H__ */ + diff --git a/src/rpc-server/include/trex_rpc_cmds_table.h b/src/rpc-server/include/trex_rpc_cmds_table.h new file mode 100644 index 00000000..a41944f1 --- /dev/null +++ b/src/rpc-server/include/trex_rpc_cmds_table.h @@ -0,0 +1,79 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef __TREX_RPC_CMDS_TABLE_H__ +#define __TREX_RPC_CMDS_TABLE_H__ + +#include <unordered_map> +#include <string> +#include <vector> +#include <json/json.h> + +class TrexRpcCommand; + +/** + * holds all the commands registered + * + * @author imarom (13-Aug-15) + */ +class TrexRpcCommandsTable { + +public: + + static TrexRpcCommandsTable& get_instance() { + static TrexRpcCommandsTable instance; + return instance; + } + + /** + * register a new command + * + */ + void register_command(TrexRpcCommand *command); + + /** + * lookup for a command + * + */ + TrexRpcCommand * lookup(const std::string &method_name); + + /** + * query all commands registered + * + */ + void query(std::vector<std::string> &cmds); + +private: + TrexRpcCommandsTable(); + ~TrexRpcCommandsTable(); + + /* c++ 2011 style singleton */ + TrexRpcCommandsTable(TrexRpcCommandsTable const&) = delete; + void operator=(TrexRpcCommandsTable const&) = delete; + + /** + * holds all the registered RPC commands + * + */ + std::unordered_map<std::string, TrexRpcCommand *> m_rpc_cmd_table; +}; + +#endif /* __TREX_RPC_CMDS_TABLE_H__ */ diff --git a/src/rpc-server/include/trex_rpc_exception_api.h b/src/rpc-server/include/trex_rpc_exception_api.h new file mode 100644 index 00000000..8783c219 --- /dev/null +++ b/src/rpc-server/include/trex_rpc_exception_api.h @@ -0,0 +1,39 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef __TREX_RPC_EXCEPTION_API_H__ +#define __TREX_RPC_EXCEPTION_API_H__ + +#include <string> +#include <stdexcept> + +/** + * generic exception for RPC errors + * + */ +class TrexRpcException : public std::runtime_error +{ +public: + TrexRpcException(const std::string &what) : std::runtime_error(what) { + } +}; + +#endif /* __TREX_RPC_EXCEPTION_API_H__ */ diff --git a/src/rpc-server/include/trex_rpc_jsonrpc_v2_parser.h b/src/rpc-server/include/trex_rpc_jsonrpc_v2_parser.h new file mode 100644 index 00000000..3367ad6a --- /dev/null +++ b/src/rpc-server/include/trex_rpc_jsonrpc_v2_parser.h @@ -0,0 +1,94 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef __TREX_RPC_JSONRPC_V2_PARSER_H__ +#define __TREX_RPC_JSONRPC_V2_PARSER_H__ + +#include <string> +#include <vector> +#include <json/json.h> + +/** + * JSON RPC V2 parsed object + * + * @author imarom (12-Aug-15) + */ +class TrexJsonRpcV2ParsedObject { +public: + + TrexJsonRpcV2ParsedObject(const Json::Value &msg_id, bool force); + virtual ~TrexJsonRpcV2ParsedObject() {} + + /** + * main function to execute the command + * + */ + void execute(Json::Value &response); + +protected: + + /** + * instance private implementation + */ + virtual void _execute(Json::Value &response) = 0; + + Json::Value m_msg_id; + bool m_respond; +}; + +/** + * JSON RPC V2 parser + * + * @author imarom (12-Aug-15) + */ +class TrexJsonRpcV2Parser { + +public: + + /** + * creates a JSON-RPC object from a string + * + * @author imarom (12-Aug-15) + * + * @param msg + */ + TrexJsonRpcV2Parser(const std::string &msg); + + /** + * parses the string to a executable commands vector + * + * @author imarom (12-Aug-15) + */ + void parse(std::vector<TrexJsonRpcV2ParsedObject *> &commands); + +private: + + /** + * handle a single request + * + */ + void parse_single_request(Json::Value &request, std::vector<TrexJsonRpcV2ParsedObject *> &commands); + + std::string m_msg; +}; + +#endif /* __TREX_RPC_JSONRPC_V2_PARSER_H__ */ + diff --git a/src/rpc-server/include/trex_rpc_req_resp_server.h b/src/rpc-server/include/trex_rpc_req_resp_server.h new file mode 100644 index 00000000..f12d0540 --- /dev/null +++ b/src/rpc-server/include/trex_rpc_req_resp_server.h @@ -0,0 +1,51 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef __TREX_RPC_REQ_RESP_API_H__ +#define __TREX_RPC_REQ_RESP_API_H__ + +#include <trex_rpc_server_api.h> + +/** + * request-response RPC server + * + * @author imarom (11-Aug-15) + */ +class TrexRpcServerReqRes : public TrexRpcServerInterface { +public: + + TrexRpcServerReqRes(const TrexRpcServerConfig &cfg); + +protected: + void _rpc_thread_cb(); + void _stop_rpc_thread(); + +private: + void handle_request(const std::string &request); + + static const int RPC_MAX_MSG_SIZE = 2048; + void *m_context; + void *m_socket; + uint8_t m_msg_buffer[RPC_MAX_MSG_SIZE]; +}; + + +#endif /* __TREX_RPC_REQ_RESP_API_H__ */ diff --git a/src/rpc-server/include/trex_rpc_server_api.h b/src/rpc-server/include/trex_rpc_server_api.h new file mode 100644 index 00000000..6bb81c73 --- /dev/null +++ b/src/rpc-server/include/trex_rpc_server_api.h @@ -0,0 +1,165 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef __TREX_RPC_SERVER_API_H__ +#define __TREX_RPC_SERVER_API_H__ + +#include <stdint.h> +#include <vector> +#include <thread> +#include <string> +#include <stdexcept> +#include <trex_rpc_exception_api.h> + +class TrexRpcServerInterface; + +/** + * defines a configuration of generic RPC server + * + * @author imarom (17-Aug-15) + */ +class TrexRpcServerConfig { +public: + + enum rpc_prot_e { + RPC_PROT_TCP + }; + + TrexRpcServerConfig(rpc_prot_e protocol, uint16_t port) : m_protocol(protocol), m_port(port) { + + } + + uint16_t get_port() { + return m_port; + } + + rpc_prot_e get_protocol() { + return m_protocol; + } + +private: + rpc_prot_e m_protocol; + uint16_t m_port; +}; + +/** + * generic type RPC server instance + * + * @author imarom (12-Aug-15) + */ +class TrexRpcServerInterface { +public: + + TrexRpcServerInterface(const TrexRpcServerConfig &cfg, const std::string &name); + virtual ~TrexRpcServerInterface(); + + /** + * starts the server + * + */ + void start(); + + /** + * stops the server + * + */ + void stop(); + + /** + * set verbose on or off + * + */ + void set_verbose(bool verbose); + + /** + * return TRUE if server is active + * + */ + bool is_running(); + + /** + * is the server verbose or not + * + */ + bool is_verbose(); + +protected: + /** + * instances implement this + * + */ + virtual void _rpc_thread_cb() = 0; + virtual void _stop_rpc_thread() = 0; + + /** + * prints a verbosed message (if enabled) + * + */ + void verbose_msg(const std::string &msg); + + TrexRpcServerConfig m_cfg; + bool m_is_running; + bool m_is_verbose; + std::thread *m_thread; + std::string m_name; +}; + +/** + * TREX RPC server + * may contain serveral types of RPC servers + * (request response, async and etc.) + * + * @author imarom (12-Aug-15) + */ +class TrexRpcServer { +public: + + /* currently only request response server config is required */ + TrexRpcServer(const TrexRpcServerConfig &req_resp_cfg); + ~TrexRpcServer(); + + /** + * starts the RPC server + * + * @author imarom (19-Aug-15) + */ + void start(); + + /** + * stops the RPC server + * + * @author imarom (19-Aug-15) + */ + void stop(); + + void set_verbose(bool verbose); + + static const std::string &get_server_uptime() { + return s_server_uptime; + } + +private: + std::vector<TrexRpcServerInterface *> m_servers; + bool m_verbose; + static const std::string s_server_uptime; +}; + +#endif /* __TREX_RPC_SERVER_API_H__ */ diff --git a/src/rpc-server/src/commands/trex_rpc_cmd_general.cpp b/src/rpc-server/src/commands/trex_rpc_cmd_general.cpp new file mode 100644 index 00000000..193ce8db --- /dev/null +++ b/src/rpc-server/src/commands/trex_rpc_cmd_general.cpp @@ -0,0 +1,49 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "trex_rpc_cmds.h" +#include <../linux_dpdk/version.h> +#include <trex_rpc_server_api.h> + +using namespace std; + +/** + * get status + * + */ +TrexRpcCommand::rpc_cmd_rc_e +TrexRpcCmdGetStatus::_run(const Json::Value ¶ms, Json::Value &result) { + + /* validate count */ + if (params.size() != 0) { + generate_err_param_count(result, 0, params.size()); + return (TrexRpcCommand::RPC_CMD_PARAM_COUNT_ERR); + } + + Json::Value §ion = result["result"]; + + section["general"]["version"] = VERSION_BUILD_NUM; + section["general"]["build_date"] = get_build_date(); + section["general"]["build_time"] = get_build_time(); + section["general"]["version_user"] = VERSION_USER; + section["general"]["uptime"] = TrexRpcServer::get_server_uptime(); + return (RPC_CMD_OK); +} + diff --git a/src/rpc-server/src/commands/trex_rpc_cmd_test.cpp b/src/rpc-server/src/commands/trex_rpc_cmd_test.cpp new file mode 100644 index 00000000..e2dc8959 --- /dev/null +++ b/src/rpc-server/src/commands/trex_rpc_cmd_test.cpp @@ -0,0 +1,126 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "trex_rpc_cmds.h" +#include <iostream> +#include <sstream> +#include <trex_rpc_cmds_table.h> + +using namespace std; + +/** + * add command + * + */ +TrexRpcCommand::rpc_cmd_rc_e +TrexRpcCmdTestAdd::_run(const Json::Value ¶ms, Json::Value &result) { + + const Json::Value &x = params["x"]; + const Json::Value &y = params["y"]; + + /* validate count */ + if (params.size() != 2) { + generate_err_param_count(result, 2, params.size()); + return (TrexRpcCommand::RPC_CMD_PARAM_COUNT_ERR); + } + + /* check we have all the required paramters */ + if (!x.isInt()) { + genernate_err(result, "'x' is either missing or not an integer"); + return (TrexRpcCommand::RPC_CMD_PARAM_PARSE_ERR); + } + + if (!y.isInt()) { + genernate_err(result, "'y' is either missing or not an integer"); + return (TrexRpcCommand::RPC_CMD_PARAM_PARSE_ERR); + } + + result["result"] = x.asInt() + y.asInt(); + return (RPC_CMD_OK); +} + +/** + * sub command + * + * @author imarom (16-Aug-15) + */ +TrexRpcCommand::rpc_cmd_rc_e +TrexRpcCmdTestSub::_run(const Json::Value ¶ms, Json::Value &result) { + + const Json::Value &x = params["x"]; + const Json::Value &y = params["y"]; + + /* validate count */ + if (params.size() != 2) { + generate_err_param_count(result, 2, params.size()); + return (TrexRpcCommand::RPC_CMD_PARAM_COUNT_ERR); + } + + /* check we have all the required paramters */ + if (!x.isInt() || !y.isInt()) { + return (TrexRpcCommand::RPC_CMD_PARAM_PARSE_ERR); + } + + result["result"] = x.asInt() - y.asInt(); + return (RPC_CMD_OK); +} + +/** + * ping command + */ +TrexRpcCommand::rpc_cmd_rc_e +TrexRpcCmdPing::_run(const Json::Value ¶ms, Json::Value &result) { + + /* validate count */ + if (params.size() != 0) { + generate_err_param_count(result, 0, params.size()); + return (TrexRpcCommand::RPC_CMD_PARAM_COUNT_ERR); + } + + result["result"] = "ACK"; + return (RPC_CMD_OK); +} + +/** + * query command + */ +TrexRpcCommand::rpc_cmd_rc_e +TrexRpcCmdGetReg::_run(const Json::Value ¶ms, Json::Value &result) { + vector<string> cmds; + + /* validate count */ + if (params.size() != 0) { + generate_err_param_count(result, 0, params.size()); + return (TrexRpcCommand::RPC_CMD_PARAM_COUNT_ERR); + } + + + TrexRpcCommandsTable::get_instance().query(cmds); + + Json::Value test = Json::arrayValue; + for (auto cmd : cmds) { + test.append(cmd); + } + + result["result"] = test; + + return (RPC_CMD_OK); +} + diff --git a/src/rpc-server/src/commands/trex_rpc_cmds.h b/src/rpc-server/src/commands/trex_rpc_cmds.h new file mode 100644 index 00000000..e37e1cda --- /dev/null +++ b/src/rpc-server/src/commands/trex_rpc_cmds.h @@ -0,0 +1,89 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef __TREX_RPC_CMD_H__ +#define __TREX_RPC_CMD_H__ + +#include <trex_rpc_cmd_api.h> +#include <json/json.h> + +/* all the RPC commands decl. goes here */ + +/******************* test section ************/ + +/** + * add + * + */ +class TrexRpcCmdTestAdd : public TrexRpcCommand { +public: + TrexRpcCmdTestAdd() : TrexRpcCommand("test_add") {} +protected: + virtual rpc_cmd_rc_e _run(const Json::Value ¶ms, Json::Value &result); +}; + +/** + * sub + * + */ +class TrexRpcCmdTestSub : public TrexRpcCommand { +public: + TrexRpcCmdTestSub() : TrexRpcCommand("test_sub") {} ; +protected: + virtual rpc_cmd_rc_e _run(const Json::Value ¶ms, Json::Value &result); +}; + +/** + * ping + * + */ +class TrexRpcCmdPing : public TrexRpcCommand { +public: + TrexRpcCmdPing() : TrexRpcCommand("ping") {}; +protected: + virtual rpc_cmd_rc_e _run(const Json::Value ¶ms, Json::Value &result); +}; + +/** + * get all registered commands + * + */ +class TrexRpcCmdGetReg : public TrexRpcCommand { +public: + TrexRpcCmdGetReg() : TrexRpcCommand("get_reg_cmds") {}; +protected: + virtual rpc_cmd_rc_e _run(const Json::Value ¶ms, Json::Value &result); +}; + +/** + * get status + * + */ +class TrexRpcCmdGetStatus : public TrexRpcCommand { +public: + TrexRpcCmdGetStatus() : TrexRpcCommand("get_status") {}; +protected: + virtual rpc_cmd_rc_e _run(const Json::Value ¶ms, Json::Value &result); +}; + + +/**************** test section end *************/ +#endif /* __TREX_RPC_CMD_H__ */ diff --git a/src/rpc-server/src/trex_rpc_cmds_table.cpp b/src/rpc-server/src/trex_rpc_cmds_table.cpp new file mode 100644 index 00000000..04a56389 --- /dev/null +++ b/src/rpc-server/src/trex_rpc_cmds_table.cpp @@ -0,0 +1,65 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include <trex_rpc_cmds_table.h> +#include <iostream> + +#include "commands/trex_rpc_cmds.h" + +using namespace std; + +/************* table related methods ***********/ +TrexRpcCommandsTable::TrexRpcCommandsTable() { + /* add the test command (for gtest) */ + register_command(new TrexRpcCmdTestAdd()); + register_command(new TrexRpcCmdTestSub()); + register_command(new TrexRpcCmdPing()); + register_command(new TrexRpcCmdGetReg()); + register_command(new TrexRpcCmdGetStatus()); +} + +TrexRpcCommandsTable::~TrexRpcCommandsTable() { + for (auto cmd : m_rpc_cmd_table) { + delete cmd.second; + } +} + +TrexRpcCommand * TrexRpcCommandsTable::lookup(const string &method_name) { + auto search = m_rpc_cmd_table.find(method_name); + + if (search != m_rpc_cmd_table.end()) { + return search->second; + } else { + return NULL; + } +} + + +void TrexRpcCommandsTable::register_command(TrexRpcCommand *command) { + + m_rpc_cmd_table[command->get_name()] = command; +} + +void TrexRpcCommandsTable::query(vector<string> &cmds) { + for (auto cmd : m_rpc_cmd_table) { + cmds.push_back(cmd.first); + } +} + diff --git a/src/rpc-server/src/trex_rpc_jsonrpc_v2_parser.cpp b/src/rpc-server/src/trex_rpc_jsonrpc_v2_parser.cpp new file mode 100644 index 00000000..be1eb2f8 --- /dev/null +++ b/src/rpc-server/src/trex_rpc_jsonrpc_v2_parser.cpp @@ -0,0 +1,194 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include <trex_rpc_exception_api.h> +#include <trex_rpc_jsonrpc_v2_parser.h> +#include <trex_rpc_cmd_api.h> +#include <trex_rpc_cmds_table.h> + +#include <json/json.h> + +#include <iostream> + +/** + * error as described in the RFC + * http://www.jsonrpc.org/specification + */ +enum { + JSONRPC_V2_ERR_PARSE = -32700, + JSONRPC_V2_ERR_INVALID_REQ = -32600, + JSONRPC_V2_ERR_METHOD_NOT_FOUND = -32601, + JSONRPC_V2_ERR_INVALID_PARAMS = -32602, + JSONRPC_V2_ERR_INTERNAL_ERROR = -32603 +}; + + +/*************** JSON RPC parsed object base type ************/ + +TrexJsonRpcV2ParsedObject::TrexJsonRpcV2ParsedObject(const Json::Value &msg_id, bool force = false) : m_msg_id(msg_id) { + /* if we have msg_id or a force was issued - write resposne */ + m_respond = (msg_id != Json::Value::null) || force; +} + +void TrexJsonRpcV2ParsedObject::execute(Json::Value &response) { + + /* common fields */ + if (m_respond) { + response["jsonrpc"] = "2.0"; + response["id"] = m_msg_id; + _execute(response); + } else { + Json::Value dummy; + _execute(dummy); + } +} + +/****************** valid method return value **************/ +class JsonRpcMethod : public TrexJsonRpcV2ParsedObject { +public: + JsonRpcMethod(const Json::Value &msg_id, TrexRpcCommand *cmd, const Json::Value ¶ms) : TrexJsonRpcV2ParsedObject(msg_id), m_cmd(cmd), m_params(params) { + + } + + virtual void _execute(Json::Value &response) { + Json::Value result; + + TrexRpcCommand::rpc_cmd_rc_e rc = m_cmd->run(m_params, result); + + switch (rc) { + case TrexRpcCommand::RPC_CMD_OK: + response["result"] = result["result"]; + break; + + case TrexRpcCommand::RPC_CMD_PARAM_COUNT_ERR: + case TrexRpcCommand::RPC_CMD_PARAM_PARSE_ERR: + response["error"]["code"] = JSONRPC_V2_ERR_INVALID_PARAMS; + response["error"]["message"] = "Bad paramters for method"; + response["error"]["specific_err"] = result["specific_err"]; + break; + + case TrexRpcCommand::RPC_CMD_INTERNAL_ERR: + response["error"]["code"] = JSONRPC_V2_ERR_INTERNAL_ERROR; + response["error"]["message"] = "Internal Server Error"; + response["error"]["specific_err"] = result["specific_err"]; + break; + } + + } + +private: + TrexRpcCommand *m_cmd; + Json::Value m_params; +}; + +/******************* RPC error **************/ + +/** + * describes the parser error + * + */ +class JsonRpcError : public TrexJsonRpcV2ParsedObject { +public: + + JsonRpcError(const Json::Value &msg_id, int code, const std::string &msg, bool force = false) : TrexJsonRpcV2ParsedObject(msg_id, force), m_code(code), m_msg(msg) { + + } + + virtual void _execute(Json::Value &response) { + response["error"]["code"] = m_code; + response["error"]["message"] = m_msg; + } + +private: + int m_code; + std::string m_msg; +}; + + +/************** JSON RPC V2 parser implementation *************/ + +TrexJsonRpcV2Parser::TrexJsonRpcV2Parser(const std::string &msg) : m_msg(msg) { + +} + +/** + * parse a batch of commands + * + * @author imarom (17-Aug-15) + * + * @param commands + */ +void TrexJsonRpcV2Parser::parse(std::vector<TrexJsonRpcV2ParsedObject *> &commands) { + + Json::Reader reader; + Json::Value request; + + /* basic JSON parsing */ + bool rc = reader.parse(m_msg, request, false); + if (!rc) { + commands.push_back(new JsonRpcError(Json::Value::null, JSONRPC_V2_ERR_PARSE, "Bad JSON Format", true)); + return; + } + + /* request can be an array of requests */ + if (request.isArray()) { + /* handle each command */ + for (auto single_request : request) { + parse_single_request(single_request, commands); + } + } else { + /* handle single command */ + parse_single_request(request, commands); + } + + +} + + +void TrexJsonRpcV2Parser::parse_single_request(Json::Value &request, + std::vector<TrexJsonRpcV2ParsedObject *> &commands) { + + Json::Value msg_id = request["id"]; + + /* check version */ + if (request["jsonrpc"] != "2.0") { + commands.push_back(new JsonRpcError(msg_id, JSONRPC_V2_ERR_INVALID_REQ, "Invalid JSONRPC Version")); + return; + } + + /* check method name */ + std::string method_name = request["method"].asString(); + if (method_name == "") { + commands.push_back(new JsonRpcError(msg_id, JSONRPC_V2_ERR_INVALID_REQ, "Missing Method Name")); + return; + } + + /* lookup the method in the DB */ + TrexRpcCommand * rpc_cmd = TrexRpcCommandsTable::get_instance().lookup(method_name); + if (!rpc_cmd) { + commands.push_back(new JsonRpcError(msg_id, JSONRPC_V2_ERR_METHOD_NOT_FOUND, "Method not registered")); + return; + } + + /* create a method object */ + commands.push_back(new JsonRpcMethod(msg_id, rpc_cmd, request["params"])); +} + diff --git a/src/rpc-server/src/trex_rpc_req_resp_server.cpp b/src/rpc-server/src/trex_rpc_req_resp_server.cpp new file mode 100644 index 00000000..7484758d --- /dev/null +++ b/src/rpc-server/src/trex_rpc_req_resp_server.cpp @@ -0,0 +1,146 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include <trex_rpc_server_api.h> +#include <trex_rpc_req_resp_server.h> +#include <trex_rpc_jsonrpc_v2_parser.h> + +#include <unistd.h> +#include <sstream> +#include <iostream> + +#include <zmq.h> +#include <json/json.h> + +/** + * ZMQ based request-response server + * + */ +TrexRpcServerReqRes::TrexRpcServerReqRes(const TrexRpcServerConfig &cfg) : TrexRpcServerInterface(cfg, "req resp") { + /* ZMQ is not thread safe - this should be outside */ + m_context = zmq_ctx_new(); +} + +/** + * main entry point for the server + * this function will be created on a different thread + * + * @author imarom (17-Aug-15) + */ +void TrexRpcServerReqRes::_rpc_thread_cb() { + std::stringstream ss; + + /* create a socket based on the configuration */ + + m_socket = zmq_socket (m_context, ZMQ_REP); + + switch (m_cfg.get_protocol()) { + case TrexRpcServerConfig::RPC_PROT_TCP: + ss << "tcp://*:"; + break; + default: + throw TrexRpcException("unknown protocol for RPC"); + } + + ss << m_cfg.get_port(); + + /* bind the scoket */ + int rc = zmq_bind (m_socket, ss.str().c_str()); + if (rc != 0) { + throw TrexRpcException("Unable to start ZMQ server at: " + ss.str()); + } + + /* server main loop */ + while (m_is_running) { + int msg_size = zmq_recv (m_socket, m_msg_buffer, sizeof(m_msg_buffer), 0); + + /* msg_size of -1 is an error - decode it */ + if (msg_size == -1) { + /* normal shutdown and zmq_term was called */ + if (errno == ETERM) { + break; + } else { + throw TrexRpcException("Unhandled error of zmq_recv"); + } + } + + /* transform it to a string */ + std::string request((const char *)m_msg_buffer, msg_size); + + verbose_msg("Server Received: " + request); + + handle_request(request); + } + + /* must be done from the same thread */ + zmq_close(m_socket); +} + +/** + * stops the ZMQ based RPC server + * + */ +void TrexRpcServerReqRes::_stop_rpc_thread() { + /* by calling zmq_term we signal the blocked thread to exit */ + zmq_term(m_context); + +} + +/** + * handles a request given to the server + * respondes to the request + */ +void TrexRpcServerReqRes::handle_request(const std::string &request) { + std::vector<TrexJsonRpcV2ParsedObject *> commands; + Json::FastWriter writer; + Json::Value response; + + /* first parse the request using JSON RPC V2 parser */ + TrexJsonRpcV2Parser rpc_request(request); + rpc_request.parse(commands); + + int index = 0; + + /* for every command parsed - launch it */ + for (auto command : commands) { + Json::Value single_response; + + command->execute(single_response); + delete command; + + response[index++] = single_response; + + } + + /* write the JSON to string and sever on ZMQ */ + std::string response_str; + + if (response.size() == 1) { + response_str = writer.write(response[0]); + } else { + response_str = writer.write(response); + } + + verbose_msg("Server Replied: " + response_str); + + zmq_send(m_socket, response_str.c_str(), response_str.size(), 0); + +} diff --git a/src/rpc-server/src/trex_rpc_server.cpp b/src/rpc-server/src/trex_rpc_server.cpp new file mode 100644 index 00000000..366bfc5b --- /dev/null +++ b/src/rpc-server/src/trex_rpc_server.cpp @@ -0,0 +1,153 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include <trex_rpc_server_api.h> +#include <trex_rpc_req_resp_server.h> +#include <unistd.h> +#include <zmq.h> +#include <sstream> +#include <iostream> + +/************** RPC server interface ***************/ + +TrexRpcServerInterface::TrexRpcServerInterface(const TrexRpcServerConfig &cfg, const std::string &name) : m_cfg(cfg), m_name(name) { + m_is_running = false; + m_is_verbose = false; +} + +TrexRpcServerInterface::~TrexRpcServerInterface() { + if (m_is_running) { + stop(); + } +} + +void TrexRpcServerInterface::verbose_msg(const std::string &msg) { + if (!m_is_verbose) { + return; + } + + std::cout << "[verbose][" << m_name << "] " << msg << "\n"; +} + +/** + * starts a RPC specific server + * + * @author imarom (17-Aug-15) + */ +void TrexRpcServerInterface::start() { + m_is_running = true; + + verbose_msg("Starting RPC Server"); + + m_thread = new std::thread(&TrexRpcServerInterface::_rpc_thread_cb, this); + if (!m_thread) { + throw TrexRpcException("unable to create RPC thread"); + } +} + +void TrexRpcServerInterface::stop() { + m_is_running = false; + + verbose_msg("Attempting To Stop RPC Server"); + + /* call the dynamic type class stop */ + _stop_rpc_thread(); + + /* hold until thread has joined */ + m_thread->join(); + + verbose_msg("Server Stopped"); + + delete m_thread; +} + +void TrexRpcServerInterface::set_verbose(bool verbose) { + m_is_verbose = verbose; +} + +bool TrexRpcServerInterface::is_verbose() { + return m_is_verbose; +} + +bool TrexRpcServerInterface::is_running() { + return m_is_running; +} + + +/************** RPC server *************/ + +static const std::string +get_current_date_time() { + time_t now = time(0); + struct tm tstruct; + char buf[80]; + tstruct = *localtime(&now); + strftime(buf, sizeof(buf), "%b %d %Y @ %X", &tstruct); + + return buf; +} + +const std::string TrexRpcServer::s_server_uptime = get_current_date_time(); + +TrexRpcServer::TrexRpcServer(const TrexRpcServerConfig &req_resp_cfg) { + + /* add the request response server */ + m_servers.push_back(new TrexRpcServerReqRes(req_resp_cfg)); +} + +TrexRpcServer::~TrexRpcServer() { + + /* make sure they are all stopped */ + stop(); + + for (auto server : m_servers) { + delete server; + } +} + +/** + * start the server array + * + */ +void TrexRpcServer::start() { + for (auto server : m_servers) { + server->start(); + } +} + +/** + * stop the server array + * + */ +void TrexRpcServer::stop() { + for (auto server : m_servers) { + if (server->is_running()) { + server->stop(); + } + } +} + +void TrexRpcServer::set_verbose(bool verbose) { + for (auto server : m_servers) { + server->set_verbose(verbose); + } +} + diff --git a/src/rpc-server/src/trex_rpc_server_mock.cpp b/src/rpc-server/src/trex_rpc_server_mock.cpp new file mode 100644 index 00000000..fd4f051c --- /dev/null +++ b/src/rpc-server/src/trex_rpc_server_mock.cpp @@ -0,0 +1,75 @@ +/* + Itay Marom + Cisco Systems, Inc. +*/ + +/* +Copyright (c) 2015-2015 Cisco Systems, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include <trex_rpc_server_api.h> +#include <iostream> +#include <unistd.h> + +using namespace std; + +/** + * on simulation this is not rebuild every version + * (improved stub) + * + */ +extern "C" const char * get_build_date(void){ + return (__DATE__); +} + +extern "C" const char * get_build_time(void){ + return (__TIME__ ); +} + +int gtest_main(int argc, char **argv); + +int main(int argc, char *argv[]) { + + // gtest ? + if (argc > 1) { + if ( (string(argv[1]) != "--ut") || (argc != 2) ) { + cout << "\n[Usage] " << argv[0] << ": " << " [--ut]\n\n"; + exit(-1); + } + return gtest_main(argc, argv); + } + + cout << "\n-= Starting RPC Server Mock =-\n\n"; + cout << "Listening on tcp://localhost:5050 [ZMQ]\n\n"; + + TrexRpcServerConfig rpc_cfg(TrexRpcServerConfig::RPC_PROT_TCP, 5050); + TrexRpcServer rpc(rpc_cfg); + + /* init the RPC server */ + rpc.start(); + + cout << "Setting Server To Full Verbose\n\n"; + rpc.set_verbose(true); + + cout << "Server Started\n\n"; + + while (true) { + sleep(1); + } + + rpc.stop(); + + +} |