summaryrefslogtreecommitdiffstats
path: root/scripts/external_libs/elasticsearch/elasticsearch/helpers/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/external_libs/elasticsearch/elasticsearch/helpers/test.py')
-rw-r--r--scripts/external_libs/elasticsearch/elasticsearch/helpers/test.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/scripts/external_libs/elasticsearch/elasticsearch/helpers/test.py b/scripts/external_libs/elasticsearch/elasticsearch/helpers/test.py
new file mode 100644
index 00000000..6fdd2ec0
--- /dev/null
+++ b/scripts/external_libs/elasticsearch/elasticsearch/helpers/test.py
@@ -0,0 +1,60 @@
+import time
+import os
+try:
+ # python 2.6
+ from unittest2 import TestCase, SkipTest
+except ImportError:
+ from unittest import TestCase, SkipTest
+
+from elasticsearch import Elasticsearch
+from elasticsearch.exceptions import ConnectionError
+
+def get_test_client(nowait=False, **kwargs):
+ # construct kwargs from the environment
+ kw = {'timeout': 30}
+ if 'TEST_ES_CONNECTION' in os.environ:
+ from elasticsearch import connection
+ kw['connection_class'] = getattr(connection, os.environ['TEST_ES_CONNECTION'])
+
+ kw.update(kwargs)
+ client = Elasticsearch([os.environ.get('TEST_ES_SERVER', {})], **kw)
+
+ # wait for yellow status
+ for _ in range(1 if nowait else 100):
+ try:
+ client.cluster.health(wait_for_status='yellow')
+ return client
+ except ConnectionError:
+ time.sleep(.1)
+ else:
+ # timeout
+ raise SkipTest("Elasticsearch failed to start.")
+
+def _get_version(version_string):
+ if '.' not in version_string:
+ return ()
+ version = version_string.strip().split('.')
+ return tuple(int(v) if v.isdigit() else 999 for v in version)
+
+class ElasticsearchTestCase(TestCase):
+ @staticmethod
+ def _get_client():
+ return get_test_client()
+
+ @classmethod
+ def setUpClass(cls):
+ super(ElasticsearchTestCase, cls).setUpClass()
+ cls.client = cls._get_client()
+
+ def tearDown(self):
+ super(ElasticsearchTestCase, self).tearDown()
+ self.client.indices.delete(index='*', ignore=404)
+ self.client.indices.delete_template(name='*', ignore=404)
+
+ @property
+ def es_version(self):
+ if not hasattr(self, '_es_version'):
+ version_string = self.client.info()['version']['number']
+ self._es_version = _get_version(version_string)
+ return self._es_version
+