diff --git a/pydruid/db/api.py b/pydruid/db/api.py index 9f3d5f97..6b87af38 100644 --- a/pydruid/db/api.py +++ b/pydruid/db/api.py @@ -29,6 +29,7 @@ def connect( password=None, context=None, header=False, + ssl_verify_cert=True, ): # noqa: E125 """ Constructor for creating a connection to the database. @@ -38,7 +39,17 @@ def connect( """ context = context or {} - return Connection(host, port, path, scheme, user, password, context, header) + return Connection( + host, + port, + path, + scheme, + user, + password, + context, + header, + ssl_verify_cert, + ) def check_closed(f): @@ -118,6 +129,7 @@ def __init__( password=None, context=None, header=False, + ssl_verify_cert=True, ): netloc = "{host}:{port}".format(host=host, port=port) self.url = parse.urlunparse((scheme, netloc, path, None, None, None)) @@ -127,6 +139,7 @@ def __init__( self.header = header self.user = user self.password = password + self.ssl_verify_cert = ssl_verify_cert @check_closed def close(self): @@ -150,7 +163,14 @@ def commit(self): @check_closed def cursor(self): """Return a new Cursor Object using the connection.""" - cursor = Cursor(self.url, self.user, self.password, self.context, self.header) + cursor = Cursor( + self.url, + self.user, + self.password, + self.context, + self.header, + self.ssl_verify_cert, + ) self.cursors.append(cursor) return cursor @@ -171,12 +191,21 @@ class Cursor(object): """Connection cursor.""" - def __init__(self, url, user=None, password=None, context=None, header=False): + def __init__( + self, + url, + user=None, + password=None, + context=None, + header=False, + ssl_verify_cert=True, + ): self.url = url self.context = context or {} self.header = header self.user = user self.password = password + self.ssl_verify_cert = ssl_verify_cert # This read/write attribute specifies the number of rows to fetch at a # time with .fetchmany(). It defaults to 1 meaning to fetch a single @@ -300,7 +329,12 @@ def _stream_query(self, query): requests.auth.HTTPBasicAuth(self.user, self.password) if self.user else None ) r = requests.post( - self.url, stream=True, headers=headers, json=payload, auth=auth + self.url, + stream=True, + headers=headers, + json=payload, + auth=auth, + verify=self.ssl_verify_cert, ) if r.encoding is None: r.encoding = "utf-8" diff --git a/tests/db/test_cursor.py b/tests/db/test_cursor.py index 222c72e5..c8c798f2 100644 --- a/tests/db/test_cursor.py +++ b/tests/db/test_cursor.py @@ -63,6 +63,7 @@ def test_context(self, requests_post_mock): stream=True, headers={'Content-Type': 'application/json'}, json={'query': query, 'context': context, 'header': False}, + verify=True ) @patch('requests.post')