From f744dbec4f2c9714e00fe860b43d7da90664bfc5 Mon Sep 17 00:00:00 2001 From: Maurice Makaay Date: Thu, 12 Dec 2019 10:16:46 +0100 Subject: [PATCH] Added a separate module for dropping privileges, to clean up the manager code. --- pgbouncemgr/drop_privileges.py | 70 +++++++++ pgbouncemgr/logger.py | 12 +- pgbouncemgr/manager.py | 29 ++-- pgbouncemgr/node_config.py | 8 +- pgbouncemgr/node_poller.py | 13 ++ pgbouncemgr/postgres.py | 271 +++++++++++++++++++++++++++++++++ pgbouncemgr/state.py | 6 +- pgbouncemgr/state_store.py | 18 +-- tests/test_drop_privileges.py | 78 ++++++++++ tests/test_logger.py | 18 +-- tests/test_manager.py | 2 +- tests/test_state_store.py | 11 +- 12 files changed, 492 insertions(+), 44 deletions(-) create mode 100644 pgbouncemgr/drop_privileges.py create mode 100644 pgbouncemgr/node_poller.py create mode 100644 pgbouncemgr/postgres.py create mode 100644 tests/test_drop_privileges.py diff --git a/pgbouncemgr/drop_privileges.py b/pgbouncemgr/drop_privileges.py new file mode 100644 index 0000000..02eb523 --- /dev/null +++ b/pgbouncemgr/drop_privileges.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# pylint: disable=too-few-public-methods,missing-docstring + +"""Functionality to drop privileges to those of a specified user/group id.""" + +from os import setuid, setgid, setuid, setgroups, getuid, geteuid, getgroups +from pwd import getpwnam, getpwuid +from grp import getgrnam, getgrgid +from pgbouncemgr.logger import format_ex + + +class DropPrivilegesException(Exception): + """Used for all exceptions that are raised from + pgbouncemgr.drop_privileges.""" + + +def drop_privileges(user, group): + """Drop privileges to those of the provided system user / group. + When dropping the privileges fails, an exception will be raised. + Otherwise a tuple (user, uid, group, gid) will be returned.""" + user, uid = get_uid(user) + group, gid = get_gid(group) + + try: + # Clear all groups when the current user is root (because only root + # has permission to do so). + if geteuid() == 0: + setgroups([]) + + # When the requested group is not already in the currently + # effective groups, then set the gid. + if gid not in getgroups(): + setgid(gid) + + # When the requested user is not already the currently + # effective user, then set the uid. + if uid != getuid(): + setuid(uid) + except Exception as exception: + raise DropPrivilegesException( + "Could not drop privileges to %s:%s (%d:%d): %s" % + (user, group, uid, gid, format_ex(exception))) + + return (user, uid, group, gid) + + +def get_uid(user): + try: + try: + uid = int(user) + entry = getpwuid(uid) + except ValueError: + entry = getpwnam(user) + return (entry.pw_name, entry.pw_uid) + except Exception as exception: + raise DropPrivilegesException( + "Invalid run user: %s (%s)" % (user, format_ex(exception))) + + +def get_gid(group): + try: + try: + gid = int(group) + entry = getgrgid(gid) + except ValueError: + entry = getgrnam(group) + return (entry.gr_name, entry.gr_gid) + except Exception as exception: + raise DropPrivilegesException( + "Invalid run group: %s (%s)" % (group, format_ex(exception))) diff --git a/pgbouncemgr/logger.py b/pgbouncemgr/logger.py index dbf5b8f..f9ef0b1 100644 --- a/pgbouncemgr/logger.py +++ b/pgbouncemgr/logger.py @@ -8,8 +8,8 @@ import re class LoggerException(Exception): """Used for all exceptions that are raised from pgbouncemgr.logger.""" -class SyslogLogTargetException(Exception): - """Used for all exceptions that are raised from the SyslogLogTarget.""" +class SyslogLogException(Exception): + """Used for all exceptions that are raised from the SyslogLog log target.""" class Logger(list): @@ -41,7 +41,7 @@ def format_ex(exception): return "%s: %s" % (name, error) -class MemoryLogTarget(list): +class MemoryLog(list): """MemoryTarget is used to collect log messages in memory.""" def debug(self, msg): self.append(["DEBUG", msg]) @@ -56,7 +56,7 @@ class MemoryLogTarget(list): self.append(["ERROR", msg]) -class ConsoleLogTarget(): +class ConsoleLog(): """ConsoleTarget is used to send log messages to the console.""" def __init__(self, verbose, debug): self.verbose_enabled = verbose or debug @@ -79,7 +79,7 @@ class ConsoleLogTarget(): print("[ERROR] %s" % msg) -class SyslogLogTarget(): +class SyslogLog(): """Syslogtarget is used to send log messages to syslog.""" def __init__(self, ident, facility): facility = self._resolve_facility(facility) @@ -108,6 +108,6 @@ class SyslogLogTarget(): try: return int(getattr(syslog, str.upper(facility))) except (AttributeError, ValueError): - raise SyslogLogTargetException( + raise SyslogLogException( "Invalid syslog facility provided (facility=%s)" % (facility if facility is None else repr(facility))) diff --git a/pgbouncemgr/manager.py b/pgbouncemgr/manager.py index e22455c..f0dca33 100644 --- a/pgbouncemgr/manager.py +++ b/pgbouncemgr/manager.py @@ -1,12 +1,17 @@ # -*- coding: utf-8 -*- +# pylint: disable=too-few-public-methods,missing-docstring + """The manager implements the main process that keeps track of changes in the PostgreSQL cluster and that reconfigures pgbouncer when needed.""" +from time import sleep from argparse import ArgumentParser -from pgbouncemgr.logger import Logger, ConsoleLogTarget, SyslogLogTarget +from pgbouncemgr.logger import Logger, ConsoleLog, SyslogLog, format_ex from pgbouncemgr.config import Config from pgbouncemgr.state import State +from pgbouncemgr.drop_privileges import drop_privileges from pgbouncemgr.state_store import StateStore +from pgbouncemgr.node_poller import NodePoller DEFAULT_CONFIG = "/etc/pgbouncer/pgbouncemgr.yaml" @@ -19,23 +24,25 @@ class Manager(): self.config = Config(args.config) self._create_logger(args) self._create_state() + self.node_poller = NodePoller(self.state) def _create_logger(self, args): self.log = Logger() - self.log.append(ConsoleLogTarget(args.verbose, args.debug)) + self.log.append(ConsoleLog(args.verbose, args.debug)) if args.log_facility.lower() != 'none': - self.log.append(SyslogLogTarget("pgbouncemgr", args.log_facility)) + self.log.append(SyslogLog("pgbouncemgr", args.log_facility)) def _create_state(self): - self.state = State.fromConfig(self.config, self.log) + self.state = State.from_config(self.config, self.log) self.state_store = StateStore(self.config.state_file, self.state) self.state_store.load() - def start(self): - self.log.info("Not yet!") - self.log.debug("Work in progres...") - self.log.warning("Beware!") - self.log.error("I will crash now") + def run(self): + """Starts the manager.""" + self.drop_privileges(self.config.run_user, self.config.run_group) + while True: + self.node_poller.poll() + sleep(self.config.poll_interval_in_sec) def _parse_arguments(args): @@ -52,7 +59,7 @@ def _parse_arguments(args): "-f", "--log-facility", default=DEFAULT_LOG_FACILITY, help="syslog facility to use or 'none' to disable syslog logging " + - "(default: %s)" % DEFAULT_LOG_FACILITY) + "(default: %s)" % DEFAULT_LOG_FACILITY) parser.add_argument( "--config", default=DEFAULT_CONFIG, @@ -61,4 +68,4 @@ def _parse_arguments(args): if __name__ == "__main__": - Manager(None).start() + Manager(None).run() diff --git a/pgbouncemgr/node_config.py b/pgbouncemgr/node_config.py index 2f65391..c907d3f 100644 --- a/pgbouncemgr/node_config.py +++ b/pgbouncemgr/node_config.py @@ -1,7 +1,13 @@ +# -*- coding: utf-8 -*- +# no-pylint: disable=missing-docstring,too-many-instance-attributes + import os from pgbouncemgr.config import InvalidConfigValue + class NodeConfig(): + """NodeConfig holds the configuration for a single PostgreSQL node + in the PostgreSQL cluster.""" def __init__(self, node_id): self.node_id = node_id self._pgbouncer_config = None @@ -26,7 +32,7 @@ class NodeConfig(): def export(self): """Exports the data for the node configuration, that we want - to end up in the state data.""" + to end up in the state data that is stored in the state store.""" return { "pgbouncer_config": self.pgbouncer_config, "host": self.host, diff --git a/pgbouncemgr/node_poller.py b/pgbouncemgr/node_poller.py new file mode 100644 index 0000000..e8dacfa --- /dev/null +++ b/pgbouncemgr/node_poller.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# no-pylint: disable=missing-docstring + +class NodePoller(): + """The NodePoller is used to poll all the nodes that are available + in the state object, and to update their status according to + the results.""" + def __init__(self, state): + self.state = state + + def poll(self): + for node in self.state.nodes.values(): + print(repr(node.config)) diff --git a/pgbouncemgr/postgres.py b/pgbouncemgr/postgres.py new file mode 100644 index 0000000..0c0c58c --- /dev/null +++ b/pgbouncemgr/postgres.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +# no-pylint: disable=missing-docstring,no-self-use,broad-except + +"""This module provides various classes that uare used for connecting to + a PostgreSQL or pgbouncer server.""" + +import psycopg2 +import multiprocessing +from psycopg2.extras imoprt LogicalReplicationConnection +from pgbouncemgr.logger import format_ex + + +class PgException(Exception): + """Used for all exceptions that are raised from pgbouncemgr.postgres.""" + +class PgConnectionFailed(PgException): + """Raised when connecting to the database server failed.""" + def __init__(self, exception): + super().__init__( + "Could not connect to %s: %s" % (format_ex(exception))) + +class RetrievingPgReplicationStatusFailed(PgException): + """Raised when the replication status cannot be determined.""" + +class ReloadingPgbouncerFailed(PgException): + """Raised when reloading the pgbouncer configuration fails.""" + +class ConnectedToWrongBackend(PgException): + """Raised when the pgbouncer instance is not connected to the + correct PostgreSQL backend service.""" + def __init__(self, msg): + super().__init__( + "The pgbouncer is not connected to the expected PostgreSQL " + + "backend service: %s" % msg) + + +# Return values for the PgConnection.connect() method. +CONNECTED = 'CONNECTED' +REUSED = 'REUSED' +RECONNECTED = 'RECONNECTED' + +class PgConnection(): + """Implements a connection to a PostgreSQL server.""" + def __init__(self, config): + self.conn_params = self._create_conn_params(config) + self.ping_query = "SELECT 1" + self.conn = None + + def _create_conn_params(self, config): + """Use only connection parameters that don't have value None.""" + return dict( + (k, v) for k, v in config.items() + if v is not None) + + def connect(self): + """Connect to the database server. When a connection exists, + then check if it is still oeprational. If yes, then reuse + this connection. If no, or when no connection exists, then + setup a new connection. + Raises an exeption when the database connection cannot be setup. + returns CONNECTED, REUSED or RECONNECTED when the connection + was setup successfully.""" + reconnected = False + if self.conn is not None: + try: + with self.conn.cursor() as cursor: + cursor.execute(self.ping_query) + return REUSED + except psycopg2.OperationalError: + reconnected = True + self.disconnect() + try: + self.conn = psycopg2.connect(**self.conn_params) + return RECONNECTED if reconnected else CONNECTED + except psycopg2.OperationalError as exception: + self.disconnect() + raise PgConnectionFailed(exception) + + def disconnect(self): + """Disconnect from the database server.""" + try: + if self.conn: + self.conn.close() + except Exception: + pass + self.conn = None + + +# Return values for the PgReplicationConnection status. +OFFLINE = "OFFLINE" +PRIMARY = "PRIMARY" +STANDBY = "STANDBY" + +class PgReplicationConnection(PgConnection): + """This PostgresQL connection class is used to setup a replication + connection to a PostgreSQL database server, which can be used + to retrieve the replication status for the server.""" + def __init__(self, node_config): + super().__init__(node_config) + self.conn_params["connection_factory"] = LogicalReplicationConnection + + def get_replication_status(self): + """Returns the replication status for a node. This is an array, + containing the keys "status" (OFFLINE, PRIMARY or STANDBY), + "system_id" and "timeline_id".""" + status = { + "status": None, + "system_id": None, + "timeline_id": None + } + + # Try to connect to the node. If this fails, the node is OFFLINE. + try: + self.connect() + except PgConnectionFailed: + status["status"] = OFFLINE + return status + + # Check if the node is running in primary or standby mode. + try: + with self.conn.cursor() as cursor: + cursor.execute("SELECT pg_is_in_recovery()") + in_recovery = cursor.fetchone()[0] + status["status"] = STANDBY if in_recovery else PRIMARY + except psycopg2.InternalError as exception: + self.disconnect() + raise RetrievingPgReplicationStatusFailed( + "SELECT pg_is_in_recovery() failed: %s" % format_ex(exception)) + + # Retrieve system_id and timeline_id. + try: + with self.conn.cursor() as cursor: + cursor.execute("IDENTIFY_SYSTEM") + row = cursor.fetchone() + system_id, timeline_id, *_ = row + status["system_id"] = system_id + status["timeline_id"] = timeline_id + except psycopg2.InternalError as exception: + self.disconnect() + raise RetrievingPgReplicationStatusFailed( + "IDENTIFY_SYSTEM failed: %s" % format_ex(exception)) + + return status + + +class PgConnectionViaPgbouncer(PgConnection): + """This PostgreSQL connection class is used to setup a connection + to the PostgreSQL cluster, via the pgbouncer instance.""" + def __init__(self, node_config, pgbouncer_config): + """Instantiate a new connection. The node_config and the + pgbouncer_config will be combined to get the connection parameters + for connecting to the PostgreSQL server.""" + self.node_config = node_config + + # First, apply all the connection parameters as defined for the node. + # This is fully handled by the parent class. + super().__init__(node_config) + + # Secondly, override parameters to redirect the connection to + # the pgbouncer instance. + self.conn_params["host"] = pgbouncer_config["host"] + self.conn_params["port"] = pgbouncer_config["port"] + + # Note that we don't setup a replication connection here. This is + # unfortunately not possible, because pgbouncer does not support + # this type of connection. If this would ever become possible in + # pgbouncer, I will definitely switch to such connection, since it + # allows for doing some extra checks. + + def verify_connection(self): + """Check if the connection via pgbouncer ends up with the + configured node.""" + # This is done in a somewhat convoluted way with a subprocess and a + # timer. This is done, because a connection is made through + # pgbouncer, and pgbouncer will try really hard to connect the user + # to its known backend service. When that service is unavailable, + # then pgbouncer accepts the connection, but the communication will + # then stall for quite a while. For swift operation of backend + # switching, we therefore use the timeout setup here, to reconfigure + # the system more quickly. + # + # Note that in most situations this shouldn't be an issue, since we + # will always reload pgbouncer after a configuration change and the + # connection from below will, because of that, work as intended + # right away. We must be prepared for the odd case out though, since + # we're going for HA here.""" + def check_func(report_func): + # Setup the database connection + try: + self.connect() + except Exception as exception: + return report_func(False, exception) + + # Check if we're connected to the requested node. + with self.conn.cursor() as cursor: + try: + cursor.execute(VERIFY_QUERY, { + "host": self.node_config["host"], + "port": self.node_config["port"] + }) + result = cursor.fetchone()[0] + self.disconnect() + + if result is not None: + raise ConnectedToWrongBackend(result) + except Exception as exception: + self.disconnect() + return report_func(False, exception) + + # When the verify query did not return an error message, then we + # are in the green. + return report_func(True, None) + + parent_conn, child_conn = multiprocessing.Pipe() + def report_func(true_or_false, exception): + child_conn.send([true_or_false, exception]) + child_conn.close() + proc = multiprocessing.Process(target=check_func, args=(report_func,)) + proc.start() + proc.join(self.node_config["connect_timeout"]) + if proc.is_alive(): + proc.terminate() + proc.join() + return (False, PgConnectionFailed("Connection attempt timed out")) + result = parent_conn.recv() + proc.join() + return result + + +class PgBouncerConsoleConnection(PgConnection): + """This PostgreSQL connection class is used to setup a console + connection to a pgbouncer server. This kind of connection can be + used to control the pgbouncer instance via admin commands. + This connection is used by pgbouncemgr to reload the configuration + of pgbouncer when the cluster state changes.""" + def __init__(self, pgbouncer_config): + super().__init__(pgbouncer_config) + + # For the console connection, the database name "pgbouncer" + # must be used. + self.conn_params["dbname"] = "pgbouncer" + + # The default ping query does not work when connected to the + # pgbouncer console. Here's a simple replacement for it. + self.ping_query = "SHOW VERSION" + + def connect(self): + """Connect to the pgbouncer console. After connecting, the autocommit + feature will be disabled for the connection, so the underlying + PostgreSQL client library won't automatically try to setup a + transaction. Transactions are not supported by the pgbouncer + console.""" + result = super().connect() + self.conn.autocommit = True + return result + + def reload(self): + """Send the 'RELOAD' command to the pgbouncer console, in order to + reload the configuration file.""" + try: + self.connect() + with self.conn.cursor() as cursor: + cursor.execute('RELOAD') + if cursor.statusmessage == 'RELOAD': + return + raise ReloadingPgbouncerFailed( + "Unexpected status message: %s" % cursor.statusmessage) + except Exception as exception: + raise ReloadingPgbouncerFailed( + "An exception occurred: %s" % format_ex(exception)) + diff --git a/pgbouncemgr/state.py b/pgbouncemgr/state.py index 2a151b9..1e05656 100644 --- a/pgbouncemgr/state.py +++ b/pgbouncemgr/state.py @@ -109,8 +109,8 @@ class InvalidNodeStatus(StateException): class State(): @staticmethod - def fromConfig(config, logger): - state = State(logger) + def from_config(config, logger): + state = State(logger) for node_id, settings in config.nodes.items(): node_config = NodeConfig(node_id) for k, v in settings.items(): @@ -237,7 +237,7 @@ class State(): @leader_node_id.setter def leader_node_id(self, leader_node_id): """Sets the id of the leader node.""" - try: + try: node = self.get_node(leader_node_id) except UnknownNodeRequested: self.log.warning( diff --git a/pgbouncemgr/state_store.py b/pgbouncemgr/state_store.py index 02c5387..b4db532 100644 --- a/pgbouncemgr/state_store.py +++ b/pgbouncemgr/state_store.py @@ -11,7 +11,7 @@ from pgbouncemgr.state import \ class StateStoreException(Exception): - """Used for all exceptions that are raised from pgbouncemgr.state_store.""" + """Used for all exceptions that are raised from pgbouncemgr.state_store.""" class StateStore(): @@ -42,10 +42,10 @@ class StateStore(): # Copy the state over to the state object. for key in [ - "system_id", - "timeline_id", - "active_pgbouncer_config", - "leader_node_id"]: + "system_id", + "timeline_id", + "active_pgbouncer_config", + "leader_node_id"]: if key in loaded_state: try: setattr(self.state, key, loaded_state[key]) @@ -65,13 +65,11 @@ class StateStore(): The error can be found in the err property.""" new_state = json.dumps(self.state.export(), sort_keys=True, indent=2) try: - self.err = None swap_path = "%s..SWAP" % self.path with open(swap_path, "w") as file_handle: print(new_state, file=file_handle) rename(swap_path, self.path) - return True except Exception as exception: - self.err = "Storing state to file (%s) failed: %s" % ( - self.path, format_ex(exception)) - return False + raise StateStoreException( + "Storing state to file (%s) failed: %s" % ( + self.path, format_ex(exception))) diff --git a/tests/test_drop_privileges.py b/tests/test_drop_privileges.py new file mode 100644 index 0000000..617e6b0 --- /dev/null +++ b/tests/test_drop_privileges.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +import unittest +from os import geteuid +from pgbouncemgr.drop_privileges import * + + +class DropPrivilegesTests(unittest.TestCase): + def test_givenKnownUsername_GetUid_ReturnsUid(self): + user, uid = get_uid('daemon') + self.assertEqual('daemon', user) + self.assertEqual(1, uid) + + def test_givenUnknownUsername_GetUid_RaisesException(self): + with self.assertRaises(DropPrivilegesException) as context: + get_uid("nobodier_than_ever") + self.assertIn("name not found", str(context.exception)) + self.assertIn("nobodier_than_ever", str(context.exception)) + + def test_givenKnownUserId_GetUid_ReturnsUserAndUid(self): + user, uid = get_uid(1) + self.assertEqual('daemon', user) + self.assertEqual(1, uid) + + def test_givenUnknownUserId_GetUid_RaisesException(self): + with self.assertRaises(DropPrivilegesException) as context: + get_uid(22222) + self.assertIn("Invalid run user: 22222", str(context.exception)) + self.assertIn("uid not found: 22222", str(context.exception)) + + def test_givenKownGroupName_GetGid_ReturnsGid(self): + group, gid = get_gid('daemon') + self.assertEqual('daemon', group) + self.assertEqual(gid, 1) + + def test_givenUnknownGroupName_GetGid_RaisesException(self): + with self.assertRaises(DropPrivilegesException) as context: + get_uid("groupies_are_none") + self.assertIn("name not found", str(context.exception)) + self.assertIn("groupies_are_none", str(context.exception)) + + def test_givenKnownGid_GetGid_ReturnsGid(self): + group, gid = get_gid(2) + self.assertEqual(2, gid) + self.assertEqual('bin', group) + + def test_givenUnknownGid_GetGid_RaisesException(self): + with self.assertRaises(DropPrivilegesException) as context: + get_gid(33333) + self.assertIn("Invalid run group: 33333", str(context.exception)) + self.assertIn("gid not found: 33333", str(context.exception)) + + + def test_givenProblem_DropPrivileges_RaisesException(self): + if geteuid() > 0: + with self.assertRaises(DropPrivilegesException) as context: + drop_privileges(1, 0) + self.assertIn("Operation not permitted", str(context.exception)) + else: + # Root is allowed to change the uid/gid, so in case the + # tests are run as the root user, use an alternative error + # scenario. + with self.assertRaises(DropPrivilegesException) as context: + drop_privileges(22222, 0) + self.assertIn("uid not found: 22222", str(context.exception)) + + def test_givenProblem_DropPrivileges_RaisesException(self): + if geteuid() > 0: + with self.assertRaises(DropPrivilegesException) as context: + drop_privileges(1, 0) + self.assertIn("Operation not permitted", str(context.exception)) + else: + # Root is allowed to change the uid/gid, so in case the + # tests are run as the root user, use an alternative error + # scenario. + with self.assertRaises(DropPrivilegesException) as context: + drop_privileges(22222, 0) + self.assertIn("uid not found: 22222", str(context.exception)) diff --git a/tests/test_logger.py b/tests/test_logger.py index 9219128..30ba2d8 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -15,8 +15,8 @@ class LoggerTests(unittest.TestCase): def test_Logger_WritesToAllTargets(self): logger = Logger() - logger.append(MemoryLogTarget()) - logger.append(MemoryLogTarget()) + logger.append(MemoryLog()) + logger.append(MemoryLog()) send_logs(logger) self.assertEqual([ @@ -36,30 +36,30 @@ class LoggerTests(unittest.TestCase): class SyslogLargetTests(unittest.TestCase): def test_GivenInvalidFacility_ExceptionIsRaised(self): - with self.assertRaises(SyslogLogTargetException) as context: - SyslogLogTarget("my app", "LOG_WRONG") + with self.assertRaises(SyslogLogException) as context: + SyslogLog("my app", "LOG_WRONG") self.assertIn("Invalid syslog facility provided", str(context.exception)) self.assertIn("'LOG_WRONG'", str(context.exception)) def test_GivenValidFacility_LogTargetIsCreated(self): - SyslogLogTarget("my app", "LOG_LOCAL0") + SyslogLog("my app", "LOG_LOCAL0") class ConsoleLogTargetTests(unittest.TestCase): def test_CanCreateSilentConsoleLogger(self): - console = ConsoleLogTarget(False, False) + console = ConsoleLog(False, False) self.assertFalse(console.verbose_enabled) self.assertFalse(console.debug_enabled) def test_CanCreateVerboseConsoleLogger(self): - console = ConsoleLogTarget(True, False) + console = ConsoleLog(True, False) self.assertTrue(console.verbose_enabled) self.assertFalse(console.debug_enabled) def test_CanCreateDebuggingConsoleLogger(self): - console = ConsoleLogTarget(False, True) + console = ConsoleLog(False, True) self.assertTrue(console.verbose_enabled) self.assertTrue(console.debug_enabled) @@ -67,7 +67,7 @@ class ConsoleLogTargetTests(unittest.TestCase): def test_CanCreateVerboseDebuggingConsoleLogger(self): """Basically the same as a debugging console logger, since the debug flag enables the verbose flag as well.""" - console = ConsoleLogTarget(True, True) + console = ConsoleLog(True, True) self.assertTrue(console.verbose_enabled) self.assertTrue(console.debug_enabled) diff --git a/tests/test_manager.py b/tests/test_manager.py index fcd7b52..feea781 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -35,7 +35,7 @@ class ManagerTests(unittest.TestCase): # Check the logging setup. self.assertEqual(1, len(mgr.log)) console = mgr.log[0] - self.assertEqual('ConsoleLogTarget', type(console).__name__) + self.assertEqual('ConsoleLog', type(console).__name__) self.assertTrue(console.verbose_enabled) self.assertTrue(console.debug_enabled) diff --git a/tests/test_state_store.py b/tests/test_state_store.py index 329e8b7..7b1d7be 100644 --- a/tests/test_state_store.py +++ b/tests/test_state_store.py @@ -10,9 +10,8 @@ from pgbouncemgr.state_store import * def make_test_file_path(filename): - return os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "testfiles", filename) + testfiles_path = os.path.dirname(os.path.realpath(__file__)) + return os.path.join(testfiles_path, "testfiles", filename) class StateStoreTests(unittest.TestCase): def test_GivenNonExistingStateFile_OnLoad_StateStoreDoesNotLoadState(self): @@ -60,3 +59,9 @@ class StateStoreTests(unittest.TestCase): if tmpfile and os.path.exists(tmpfile.name): os.unlink(tmpfile.name) + def test_GivenError_OnSave_ExceptionIsRaised(self): + state = State(Logger()) + with self.assertRaises(StateStoreException) as context: + StateStore("/tmp/path/that/does/not/exist/fofr/statefile", state).store() + self.assertIn("Storing state to file", str(context.exception)) + self.assertIn("failed: FileNotFoundError", str(context.exception))