# -*- coding: utf-8 -*-
"""Unit tests for the connection wrappers in pgbouncemgr.postgres.

Every test drives the wrappers through the psycopg2 stub from
tests.stub_psycopg2, so no real PostgreSQL or pgbouncer server is needed.
"""

import time
import unittest

import psycopg2
# Imported explicitly: the replication tests reference
# psycopg2.extras.LogicalReplicationConnection, and a bare "import psycopg2"
# does not guarantee that the extras submodule has been loaded.
import psycopg2.extras

from pgbouncemgr.postgres import (
    PgBouncerConsoleConnection,
    PgConnection,
    PgConnectionFailed,
    PgConnectionViaPgbouncer,
    PgReplicationConnection,
    RetrievingPgReplicationStatusFailed,
)
from pgbouncemgr.logger import Logger
from pgbouncemgr.node_config import NodeConfig
from tests.stub_psycopg2 import StubConnection, StubCursor, StubPsycopg2


class PgConnectionTests(unittest.TestCase):
    """Connection setup, reuse, reconnect and disconnect of PgConnection."""

    def test_GivenAuthenticationFailure_PgConnection_RaisesException(self):
        self._test_connect_exception(
            PgConnection, StubPsycopg2().add_auth_failure(),
            PgConnectionFailed, "authentication failed")

    def test_GivenConnectionTimeout_PgConnection_RaisesException(self):
        self._test_connect_exception(
            PgConnection, StubPsycopg2().add_conn_timeout(),
            PgConnectionFailed, "timeout expired")

    def test_GivenConnectionFailure_PgConnection_RaisesException(self):
        self._test_connect_exception(
            PgConnection, StubPsycopg2().add_conn_failure(),
            PgConnectionFailed, "could not connect")

    def test_GivenPostgresStartingUp_PgConnection_RaisesException(self):
        self._test_connect_exception(
            PgConnection, StubPsycopg2().add_admin_startup(),
            PgConnectionFailed, "system is starting up")

    def test_GivenPostgresShuttingDown_PgConnection_RaisesException(self):
        self._test_connect_exception(
            PgConnection, StubPsycopg2().add_admin_shutdown(),
            PgConnectionFailed, "AdminShutdown")

    def _test_connect_exception(self, test_class, stub_psycopg2, err, msg):
        """Assert that connect() on test_class raises err containing msg.

        Not collected by unittest itself (no "test" prefix); it is a shared
        helper for the connection-failure tests above.
        """
        pg = test_class(NodeConfig(1), Logger(), stub_psycopg2)
        with self.assertRaises(err) as context:
            pg.connect()
        self.assertIn(msg, str(context.exception))

    def test_NodeConfig_IsAppliedToConnParams(self):
        node_config = NodeConfig(1)
        node_config.host = "1.1.1.1"
        node_config.port = 9999
        # No psycopg2 module is needed: only the constructor is exercised.
        pg = PgConnection(node_config, Logger(), None)

        self.assertEqual(1, pg.node_id)
        self.assertEqual("1.1.1.1", pg.conn_params["host"])
        self.assertEqual(9999, pg.conn_params["port"])

    def test_GivenNoneValueInNodeConfig_ValueIsOmittedInConnParams(self):
        node_config = NodeConfig(2)
        node_config.host = None
        node_config.port = None
        stub_psycopg2 = StubPsycopg2().add_connection(StubConnection())
        pg = PgConnection(node_config, Logger(), stub_psycopg2)
        pg.connect()

        self.assertEqual(2, pg.node_id)
        self.assertNotIn("host", pg.conn_params)
        self.assertNotIn("port", pg.conn_params)

    def test_FirstConnect_SetsUpConnection(self):
        stub_psycopg2 = StubPsycopg2().add_connection(StubConnection())
        pg = PgConnection(NodeConfig('a'), Logger(), stub_psycopg2)

        result = pg.connect()

        self.assertEqual("CONN_CONNECTED", result)

    def test_SecondConnect_PingsAndReusesConnection(self):
        cursor = StubCursor(("SELECT", [[1]]))
        conn = StubConnection(cursor)
        stub_psycopg2 = StubPsycopg2().add_connection(conn)
        pg = PgConnection(NodeConfig('b'), Logger(), stub_psycopg2)

        result1 = pg.connect()
        # The first connect must not run any query on the cursor...
        self.assertIsNone(cursor.query)
        result2 = pg.connect()
        # ...while the second connect pings the existing connection.
        self.assertEqual("SELECT 1", cursor.query)
        self.assertEqual("CONN_CONNECTED", result1)
        self.assertEqual("CONN_REUSED", result2)

    def test_SecondConnectPingFails_SetsUpNewConnection(self):
        conn1 = StubConnection(StubCursor(psycopg2.OperationalError()))
        conn2 = StubConnection()
        stub_psycopg2 = StubPsycopg2(conn1, conn2)
        pg = PgConnection(NodeConfig('b'), Logger(), stub_psycopg2)

        result1 = pg.connect()  # Connection OK
        result2 = pg.connect()  # Ping fails, reconnect

        self.assertEqual("CONN_CONNECTED", result1)
        self.assertEqual("CONN_RECONNECTED", result2)

    def test_Disconnect_ClosesConnection(self):
        conn = StubConnection()
        stub_psycopg2 = StubPsycopg2(conn)
        pg = PgConnection(NodeConfig('disco'), Logger(), stub_psycopg2)
        pg.connect()
        self.assertTrue(conn.connected)

        pg.disconnect()

        self.assertFalse(conn.connected)

    def test_BigConnectionFlow(self):
        conn1 = StubConnection(
            StubCursor(psycopg2.OperationalError()))
        conn2 = StubConnection(
            StubCursor(("SELECT", [])),
            StubCursor(("SELECT", [])))
        conn3 = StubConnection(
            StubCursor(("SELECT", [])))
        stub_psycopg2 = StubPsycopg2(conn1, conn2, conn3)
        pg = PgConnection(NodeConfig('b'), Logger(), stub_psycopg2)

        result1 = pg.connect()  # Connection 1 OK
        result2 = pg.connect()  # Ping fails, new connection 2 OK
        result3 = pg.connect()  # Ping 1 success
        result4 = pg.connect()  # Ping 2 success
        pg.disconnect()         # Explicit disconnect
        result5 = pg.connect()  # Connection 3 OK
        result6 = pg.connect()  # Ping success

        self.assertEqual("CONN_CONNECTED", result1)
        self.assertEqual("CONN_RECONNECTED", result2)
        self.assertEqual("CONN_REUSED", result3)
        self.assertEqual("CONN_REUSED", result4)
        self.assertEqual("CONN_CONNECTED", result5)
        self.assertEqual("CONN_REUSED", result6)


class PgReplicationConnectionTests(unittest.TestCase):
    """Replication-connection setup and get_replication_status() results."""

    def test_LogicalReplicationConnection_IsUsed(self):
        conn = StubConnection()
        stub_psycopg2 = StubPsycopg2().add_connection(conn)
        pg = PgReplicationConnection(
            NodeConfig("foo"), Logger(), stub_psycopg2)
        pg.connect()

        self.assertEqual(
            psycopg2.extras.LogicalReplicationConnection,
            pg.conn_params["connection_factory"])

    def test_GivenFailingConnection_ReplicationStatusIsOffline(self):
        stub_psycopg2 = StubPsycopg2().add_auth_failure()
        pg = PgReplicationConnection(
            NodeConfig("foo"), Logger(), stub_psycopg2)

        status = pg.get_replication_status()

        self.assertEqual({
            "status": "NODE_OFFLINE",
            "system_id": None,
            "timeline_id": None}, status)

    def test_GivenFailingStandbyQuery_ReplicationStatusRaisesException(self):
        conn = StubConnection(StubCursor(psycopg2.OperationalError()))
        pg = PgReplicationConnection(
            NodeConfig("foo"), Logger(), StubPsycopg2(conn))

        with self.assertRaises(RetrievingPgReplicationStatusFailed) as context:
            pg.get_replication_status()
        self.assertIn("pg_is_in_recovery() failed", str(context.exception))

    def test_GivenFailingIdentifySystemQuery_ReplicationStatusRaisesException(self):
        conn = StubConnection(
            StubCursor(("SELECT", [[False]])),
            psycopg2.OperationalError())
        pg = PgReplicationConnection(
            NodeConfig("foo"), Logger(), StubPsycopg2(conn))

        with self.assertRaises(RetrievingPgReplicationStatusFailed) as context:
            pg.get_replication_status()
        self.assertIn("IDENTIFY_SYSTEM failed", str(context.exception))

    def test_GivenConnectionToPrimaryNode_ReplicationStatusIsPrimary(self):
        conn = StubConnection(
            StubCursor(("SELECT", [[False]])),
            StubCursor(("IDENTIFY_SYSTEM", [["id", 1234, "other", "fields"]])))
        stub_psycopg2 = StubPsycopg2(conn)
        pg = PgReplicationConnection(
            NodeConfig("foo"), Logger(), stub_psycopg2)

        status = pg.get_replication_status()

        self.assertEqual({
            "status": "NODE_PRIMARY",
            "system_id": "id",
            "timeline_id": 1234}, status)

    def test_GivenConnectionToStandbyNode_ReplicationStatusIsStandby(self):
        conn = StubConnection(
            StubCursor(("SELECT", [[True]])),
            StubCursor(("IDENTIFY_SYSTEM", [["towel", 42, "other", "fields"]])))
        stub_psycopg2 = StubPsycopg2(conn)
        pg = PgReplicationConnection(
            NodeConfig("foo"), Logger(), stub_psycopg2)

        status = pg.get_replication_status()

        self.assertEqual({
            "status": "NODE_STANDBY",
            "system_id": "towel",
            "timeline_id": 42}, status)


class PgConnectionViaPgbouncerTests(unittest.TestCase):
    """Connection parameters and verify_connection() of the pgbouncer wrapper."""

    def test_NodeConfigAndPgBouncerConfig_AreMergedInConnParams(self):
        node_config = NodeConfig(777)
        node_config.host = "1.1.1.1"
        node_config.port = 9999
        pgbouncer_config = {"host": "192.168.0.1", "port": 7654}
        pg = PgConnectionViaPgbouncer(
            node_config, pgbouncer_config, Logger(), None)

        self.assertEqual(777, pg.node_id)
        self.assertEqual("template1", pg.conn_params["database"])
        # The pgbouncer host/port take precedence over the node's host/port.
        self.assertEqual("192.168.0.1", pg.conn_params["host"])
        self.assertEqual(7654, pg.conn_params["port"])

    def test_WhenVerifyQueryCannotConnect_ExceptionIsRaised(self):
        """A connection failure is reported via the (ok, err, cursor) tuple."""
        stub_psycopg2 = StubPsycopg2().add_auth_failure()
        node_config = NodeConfig('xyz123')
        pgbouncer_config = {"host": "192.168.0.1", "port": 7654}
        pg = PgConnectionViaPgbouncer(
            node_config, pgbouncer_config, Logger(), stub_psycopg2)

        ok, err, cursor = pg.verify_connection()

        self.assertFalse(ok)
        self.assertIsNone(cursor)
        self.assertIn("password authentication failed", str(err))

    def test_WhenVerifyQueryTimesOut_ExceptionIsRaised(self):
        """A query slower than connect_timeout is reported as a timeout."""
        def sleeping_query(query):
            # Far longer than the connect_timeout configured below.
            time.sleep(1)

        conn = StubConnection(StubCursor(sleeping_query))
        stub_psycopg2 = StubPsycopg2(conn)
        node_config = NodeConfig('xyz123')
        node_config.connect_timeout = 0.01
        pgbouncer_config = {"host": "192.168.0.1", "port": 7654}
        pg = PgConnectionViaPgbouncer(
            node_config, pgbouncer_config, Logger(), stub_psycopg2)

        ok, err, cursor = pg.verify_connection()

        self.assertFalse(ok)
        self.assertIsNone(cursor)
        self.assertIn("Connection attempt timed out", str(err))

    def test_WhenVerifyQueryRuns_NodeHostAndPortAreQueried(self):
        conn = StubConnection(StubCursor(('SELECT', [[None]])))
        stub_psycopg2 = StubPsycopg2(conn)
        node_config = NodeConfig('uh')
        node_config.host = "111.222.111.222"
        node_config.port = 1122
        pgbouncer_config = {"host": "192.168.0.1", "port": 7654}
        pg = PgConnectionViaPgbouncer(
            node_config, pgbouncer_config, Logger(), stub_psycopg2)

        _, _, cursor = pg.verify_connection()

        # The node's host/port must be passed as query parameters rather
        # than being interpolated into the SQL text.
        self.assertIn("inet_server_addr()", cursor.query)
        self.assertIn("%(host)s", cursor.query)
        self.assertIn("inet_server_port()", cursor.query)
        self.assertIn("%(port)s", cursor.query)
        self.assertEqual({
            "host": "111.222.111.222",
            "port": 1122
        }, cursor.params)

    def test_WhenVerifyQueryReturnsIssue_IssueIsReported(self):
        cursor = StubCursor(('SELECT', [['Bokito has escaped']]))
        conn = StubConnection(cursor)
        stub_psycopg2 = StubPsycopg2(conn)
        node_config = NodeConfig('xyz123')
        node_config.host = "111.222.111.222"
        node_config.port = 1122
        pgbouncer_config = {"host": "192.168.0.1", "port": 7654}
        pg = PgConnectionViaPgbouncer(
            node_config, pgbouncer_config, Logger(), stub_psycopg2)

        ok, err, _ = pg.verify_connection()

        self.assertFalse(ok)
        self.assertIn("not connected to the expected PostgreSQL backend", err)
        self.assertIn("Bokito has escaped", err)

    def test_WhenVerifyQueryReturnsNoIssue_OkIsReported(self):
        conn = StubConnection(StubCursor(('SELECT', [[None]])))
        stub_psycopg2 = StubPsycopg2(conn)
        node_config = NodeConfig('xyz123')
        pgbouncer_config = {"host": "192.168.0.1", "port": 7654}
        pg = PgConnectionViaPgbouncer(
            node_config, pgbouncer_config, Logger(), stub_psycopg2)

        ok, err, _ = pg.verify_connection()

        self.assertTrue(ok)
        self.assertIsNone(err)


class PgBouncerConsoleConnectionTests(unittest.TestCase):
    """Connection behaviour of the pgbouncer console wrapper."""

    def test_OnConnect_AutoCommitIsEnabled(self):
        conn = StubConnection()
        stub_psycopg2 = StubPsycopg2().add_connection(conn)
        pg = PgBouncerConsoleConnection(
            NodeConfig("bob"), Logger(), stub_psycopg2)

        self.assertFalse(conn.autocommit)
        pg.connect()
        self.assertTrue(conn.autocommit)


if __name__ == "__main__":
    unittest.main()