diff --git a/rethinkdb/net.py b/rethinkdb/net.py index 5a4c8ddc..7b3c774d 100644 --- a/rethinkdb/net.py +++ b/rethinkdb/net.py @@ -25,6 +25,11 @@ import struct import time +try: + from urllib.parse import urlparse, parse_qs +except ImportError: + from urlparse import urlparse, parse_qs + from rethinkdb import ql2_pb2 from rethinkdb.ast import DB, ReQLDecoder, ReQLEncoder, Repl, expr from rethinkdb.errors import ( @@ -703,9 +708,6 @@ def __init__(self, *args, **kwargs): Connection.__init__(self, ConnectionInstance, *args, **kwargs) - - - def make_connection( connection_type, host=None, @@ -716,20 +718,40 @@ def make_connection( password=None, timeout=20, ssl=None, + url=None, _handshake_version=10, **kwargs): - if host is None: - host = 'localhost' - if port is None: - port = DEFAULT_PORT - if user is None: - user = 'admin' - if timeout is None: - timeout = 20 - if ssl is None: - ssl = dict() - if _handshake_version is None: - _handshake_version = 10 + if url: + connection_string = urlparse(url) + query_string = parse_qs(connection_string.query) + + user = connection_string.username + password = connection_string.password + host = connection_string.hostname + port = connection_string.port + + db = connection_string.path.replace("/", "") or None + auth_key = query_string.get("auth_key") + timeout = query_string.get("timeout") + + if auth_key: + auth_key = auth_key[0] + + if timeout: + timeout = int(timeout[0]) + + + host = host or 'localhost' + port = port or DEFAULT_PORT + user = user or 'admin' + timeout = timeout or 20 + ssl = ssl or dict() + _handshake_version = _handshake_version or 10 + + # The internal APIs will wait for none to deal with auth_key and password + # TODO: refactor when we drop python2 + if not password and not password is None: + password = None conn = connection_type(host, port, db, auth_key, user, password, timeout, ssl, _handshake_version, **kwargs) return conn.reconnect(timeout=timeout) diff --git a/tests/helpers.py b/tests/helpers.py index b666050e..91a02574 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,18 +6,21 @@ class IntegrationTestCaseBase(object): + def _create_database(self, conn): + if INTEGRATION_TEST_DB not in self.r.db_list().run(conn): + self.r.db_create(INTEGRATION_TEST_DB).run(conn) + + conn.use(INTEGRATION_TEST_DB) + def setup_method(self): self.r = r - self.rethinkdb_host = os.getenv('RETHINKDB_HOST') + self.rethinkdb_host = os.getenv('RETHINKDB_HOST', '127.0.0.1') self.conn = self.r.connect( host=self.rethinkdb_host ) - if INTEGRATION_TEST_DB not in self.r.db_list().run(self.conn): - self.r.db_create(INTEGRATION_TEST_DB).run(self.conn) - - self.conn.use(INTEGRATION_TEST_DB) + self._create_database(self.conn) def teardown_method(self): self.r.db_drop(INTEGRATION_TEST_DB).run(self.conn) diff --git a/tests/integration/test_connect.py b/tests/integration/test_connect.py new file mode 100644 index 00000000..77213eb7 --- /dev/null +++ b/tests/integration/test_connect.py @@ -0,0 +1,29 @@ +import os +import pytest + +from rethinkdb import r +from tests.helpers import IntegrationTestCaseBase, INTEGRATION_TEST_DB + + +@pytest.mark.integration +class TestConnect(IntegrationTestCaseBase): + def setup_method(self): + super(TestConnect, self).setup_method() + + def test_connect(self): + db_url = "rethinkdb://{host}".format(host=self.rethinkdb_host) + + assert self.r.connect(url=db_url) is not None + + def test_connect_with_username(self): + db_url = "rethinkdb://admin@{host}".format(host=self.rethinkdb_host) + + assert self.r.connect(url=db_url) is not None + + def test_connect_to_db(self): + db_url = "rethinkdb://{host}/{database}".format( + host=self.rethinkdb_host, + database=INTEGRATION_TEST_DB + ) + + assert self.r.connect(url=db_url) is not None diff --git a/tests/integration/test_write_hooks.py b/tests/integration/test_write_hooks.py index 2ef0128c..cf40cd8d 100644 --- a/tests/integration/test_write_hooks.py +++ b/tests/integration/test_write_hooks.py @@ -48,4 +48,4 @@ def test_get_write_hook(self): hook = self.r.table(self.table_name).get_write_hook().run(self.conn) - assert list(hook.keys()) == ['function', 'query'] \ No newline at end of file + assert list(sorted(hook.keys())) == ['function', 'query'] diff --git a/tests/test_net.py b/tests/test_net.py new file mode 100644 index 00000000..76b3027a --- /dev/null +++ b/tests/test_net.py @@ -0,0 +1,200 @@ +import pytest +from mock import Mock, ANY +from rethinkdb.net import make_connection, DefaultConnection, DEFAULT_PORT + + +@pytest.mark.unit +class TestMakeConnection(object): + def setup_method(self): + self.reconnect = Mock() + self.conn_type = Mock() + self.conn_type.return_value.reconnect.return_value = self.reconnect + + self.host = "myhost" + self.port = 1234 + self.db = "mydb" + self.auth_key = None + self.user = "gabor" + self.password = "strongpass" + self.timeout = 20 + + + def test_make_connection(self): + ssl = dict() + _handshake_version = 10 + + conn = make_connection( + self.conn_type, + host=self.host, + port=self.port, + db=self.db, + auth_key=self.auth_key, + user=self.user, + password=self.password, + timeout=self.timeout, + ) + + assert conn == self.reconnect + self.conn_type.assert_called_once_with( + self.host, + self.port, + self.db, + self.auth_key, + self.user, + self.password, + self.timeout, + ssl, + _handshake_version + ) + + + def test_make_connection_db_url(self): + url = "rethinkdb://gabor:strongpass@myhost:1234/mydb?auth_key=mykey&timeout=30" + ssl = dict() + _handshake_version = 10 + + conn = make_connection(self.conn_type, url=url) + + assert conn == self.reconnect + self.conn_type.assert_called_once_with( + self.host, + self.port, + self.db, + "mykey", + self.user, + self.password, + 30, + ssl, + _handshake_version + ) + + + def test_make_connection_no_host(self): + conn = make_connection( + self.conn_type, + port=self.port, + db=self.db, + auth_key=self.auth_key, + user=self.user, + password=self.password, + timeout=self.timeout, + ) + + assert conn == self.reconnect + self.conn_type.assert_called_once_with( + "localhost", + self.port, + self.db, + self.auth_key, + self.user, + self.password, + self.timeout, + ANY, + ANY + ) + + + def test_make_connection_no_port(self): + conn = make_connection( + self.conn_type, + host=self.host, + db=self.db, + auth_key=self.auth_key, + user=self.user, + password=self.password, + timeout=self.timeout, + ) + + assert conn == self.reconnect + self.conn_type.assert_called_once_with( + self.host, + DEFAULT_PORT, + self.db, + self.auth_key, + self.user, + self.password, + self.timeout, + ANY, + ANY + ) + + + def test_make_connection_no_user(self): + conn = make_connection( + self.conn_type, + host=self.host, + port=self.port, + db=self.db, + auth_key=self.auth_key, + password=self.password, + timeout=self.timeout, + ) + + assert conn == self.reconnect + self.conn_type.assert_called_once_with( + self.host, + self.port, + self.db, + self.auth_key, + "admin", + self.password, + self.timeout, + ANY, + ANY + ) + + + def test_make_connection_with_ssl(self): + ssl = dict() + + conn = make_connection( + self.conn_type, + host=self.host, + port=self.port, + db=self.db, + auth_key=self.auth_key, + user=self.user, + password=self.password, + timeout=self.timeout, + ssl=ssl, + ) + + assert conn == self.reconnect + self.conn_type.assert_called_once_with( + self.host, + self.port, + self.db, + self.auth_key, + self.user, + self.password, + self.timeout, + ssl, + ANY + ) + + + def test_make_connection_different_handshake_version(self): + conn = make_connection( + self.conn_type, + host=self.host, + port=self.port, + db=self.db, + auth_key=self.auth_key, + user=self.user, + password=self.password, + timeout=self.timeout, + _handshake_version=20, + ) + + assert conn == self.reconnect + self.conn_type.assert_called_once_with( + self.host, + self.port, + self.db, + self.auth_key, + self.user, + self.password, + self.timeout, + ANY, + 20 + )