summaryrefslogtreecommitdiffstats
path: root/scripts/external_libs/elasticsearch/elasticsearch/helpers/test.py
blob: 6fdd2ec0558896c47ace9fd5196b110d4a12925a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import time
import os
try:
    # python 2.6
    from unittest2 import TestCase, SkipTest
except ImportError:
    from unittest import TestCase, SkipTest

from elasticsearch import Elasticsearch
from elasticsearch.exceptions import ConnectionError

def get_test_client(nowait=False, **kwargs):
    """
    Build an ``Elasticsearch`` client for the test suite and wait for the
    cluster to answer a health check.

    Client options start from ``{'timeout': 30}``. If the environment
    variable ``TEST_ES_CONNECTION`` is set, the attribute of that name on
    ``elasticsearch.connection`` is used as ``connection_class``. Any
    ``kwargs`` given by the caller override both. The single host entry is
    taken from ``TEST_ES_SERVER`` (an empty dict when unset).

    :arg nowait: when true, only one health check is attempted instead of
        up to 100
    :arg kwargs: extra keyword arguments forwarded to ``Elasticsearch``
    :returns: the client, once ``cluster.health(wait_for_status='yellow')``
        returns without a ``ConnectionError``
    :raises SkipTest: if every attempt ended in a ``ConnectionError``
    """
    # defaults first; environment and caller-supplied values layer on top
    options = {'timeout': 30}
    connection_name = os.environ.get('TEST_ES_CONNECTION')
    if connection_name is not None:
        # imported lazily: only needed when a connection class is requested
        from elasticsearch import connection
        options['connection_class'] = getattr(connection, connection_name)
    options.update(kwargs)

    hosts = [os.environ.get('TEST_ES_SERVER', {})]
    client = Elasticsearch(hosts, **options)

    # poll until the cluster reports at least yellow status
    attempts_left = 1 if nowait else 100
    while attempts_left > 0:
        attempts_left -= 1
        try:
            client.cluster.health(wait_for_status='yellow')
        except ConnectionError:
            # server not reachable yet; pause briefly before the next try
            time.sleep(.1)
        else:
            return client

    # every attempt failed to connect: skip rather than fail the tests
    raise SkipTest("Elasticsearch failed to start.")

def _get_version(version_string):
    if '.' not in version_string:
        return ()
    version = version_string.strip().split('.')
    return tuple(int(v) if v.isdigit() else 999 for v in version)

class ElasticsearchTestCase(TestCase):
    """
    Base ``TestCase`` that creates one client per test class (stored as
    ``cls.client``) and deletes all indices and index templates after
    every test.
    """

    @staticmethod
    def _get_client():
        """Return a client built by ``get_test_client()`` with its defaults."""
        client = get_test_client()
        return client

    @classmethod
    def setUpClass(cls):
        """Run the parent class setup, then attach the client to the class."""
        super(ElasticsearchTestCase, cls).setUpClass()
        cls.client = cls._get_client()

    def tearDown(self):
        """Run the parent teardown, then delete every index and template."""
        super(ElasticsearchTestCase, self).tearDown()
        indices = self.client.indices
        # ignore=404 is forwarded to the client on both calls
        indices.delete(index='*', ignore=404)
        indices.delete_template(name='*', ignore=404)

    @property
    def es_version(self):
        """
        Version of the connected server as a tuple (see ``_get_version``).

        The value is read from ``client.info()['version']['number']`` on
        first access and cached on the instance as ``_es_version``.
        """
        try:
            return self._es_version
        except AttributeError:
            # not cached yet on this instance; compute it below
            pass
        number = self.client.info()['version']['number']
        self._es_version = _get_version(number)
        return self._es_version