Added a separate module for dropping privileges, to clean up the manager code.

This commit is contained in:
Maurice Makaay 2019-12-12 10:16:46 +01:00
parent e5dbf96cb4
commit f744dbec4f
12 changed files with 492 additions and 44 deletions

View File

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-few-public-methods,missing-docstring
"""Functionality to drop privileges to those of a specified user/group id."""
from os import setuid, setgid, setuid, setgroups, getuid, geteuid, getgroups
from pwd import getpwnam, getpwuid
from grp import getgrnam, getgrgid
from pgbouncemgr.logger import format_ex
class DropPrivilegesException(Exception):
"""Used for all exceptions that are raised from
pgbouncemgr.drop_privileges."""
def drop_privileges(user, group):
"""Drop privileges to those of the provided system user / group.
When dropping the privileges fails, an exception will be raised.
Otherwise a tuple (user, uid, group, gid) will be returned."""
user, uid = get_uid(user)
group, gid = get_gid(group)
try:
# Clear all groups when the current user is root (because only root
# has permission to do so).
if geteuid() == 0:
setgroups([])
# When the requested group is not already in the currently
# effective groups, then set the gid.
if gid not in getgroups():
setgid(gid)
# When the requested user is not already the currently
# effective user, then set the uid.
if uid != getuid():
setuid(uid)
except Exception as exception:
raise DropPrivilegesException(
"Could not drop privileges to %s:%s (%d:%d): %s" %
(user, group, uid, gid, format_ex(exception)))
return (user, uid, group, gid)
def get_uid(user):
try:
try:
uid = int(user)
entry = getpwuid(uid)
except ValueError:
entry = getpwnam(user)
return (entry.pw_name, entry.pw_uid)
except Exception as exception:
raise DropPrivilegesException(
"Invalid run user: %s (%s)" % (user, format_ex(exception)))
def get_gid(group):
try:
try:
gid = int(group)
entry = getgrgid(gid)
except ValueError:
entry = getgrnam(group)
return (entry.gr_name, entry.gr_gid)
except Exception as exception:
raise DropPrivilegesException(
"Invalid run group: %s (%s)" % (group, format_ex(exception)))

View File

@ -8,8 +8,8 @@ import re
class LoggerException(Exception):
"""Used for all exceptions that are raised from pgbouncemgr.logger."""
class SyslogLogTargetException(Exception):
"""Used for all exceptions that are raised from the SyslogLogTarget."""
class SyslogLogException(Exception):
"""Used for all exceptions that are raised from the SyslogLog log target."""
class Logger(list):
@ -41,7 +41,7 @@ def format_ex(exception):
return "%s: %s" % (name, error)
class MemoryLogTarget(list):
class MemoryLog(list):
"""MemoryTarget is used to collect log messages in memory."""
def debug(self, msg):
self.append(["DEBUG", msg])
@ -56,7 +56,7 @@ class MemoryLogTarget(list):
self.append(["ERROR", msg])
class ConsoleLogTarget():
class ConsoleLog():
"""ConsoleTarget is used to send log messages to the console."""
def __init__(self, verbose, debug):
self.verbose_enabled = verbose or debug
@ -79,7 +79,7 @@ class ConsoleLogTarget():
print("[ERROR] %s" % msg)
class SyslogLogTarget():
class SyslogLog():
"""Syslogtarget is used to send log messages to syslog."""
def __init__(self, ident, facility):
facility = self._resolve_facility(facility)
@ -108,6 +108,6 @@ class SyslogLogTarget():
try:
return int(getattr(syslog, str.upper(facility)))
except (AttributeError, ValueError):
raise SyslogLogTargetException(
raise SyslogLogException(
"Invalid syslog facility provided (facility=%s)" %
(facility if facility is None else repr(facility)))

View File

@ -1,12 +1,17 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-few-public-methods,missing-docstring
"""The manager implements the main process that keeps track of changes in the
PostgreSQL cluster and that reconfigures pgbouncer when needed."""
from time import sleep
from argparse import ArgumentParser
from pgbouncemgr.logger import Logger, ConsoleLogTarget, SyslogLogTarget
from pgbouncemgr.logger import Logger, ConsoleLog, SyslogLog, format_ex
from pgbouncemgr.config import Config
from pgbouncemgr.state import State
from pgbouncemgr.drop_privileges import drop_privileges
from pgbouncemgr.state_store import StateStore
from pgbouncemgr.node_poller import NodePoller
DEFAULT_CONFIG = "/etc/pgbouncer/pgbouncemgr.yaml"
@ -19,23 +24,25 @@ class Manager():
self.config = Config(args.config)
self._create_logger(args)
self._create_state()
self.node_poller = NodePoller(self.state)
def _create_logger(self, args):
self.log = Logger()
self.log.append(ConsoleLogTarget(args.verbose, args.debug))
self.log.append(ConsoleLog(args.verbose, args.debug))
if args.log_facility.lower() != 'none':
self.log.append(SyslogLogTarget("pgbouncemgr", args.log_facility))
self.log.append(SyslogLog("pgbouncemgr", args.log_facility))
def _create_state(self):
self.state = State.fromConfig(self.config, self.log)
self.state = State.from_config(self.config, self.log)
self.state_store = StateStore(self.config.state_file, self.state)
self.state_store.load()
def start(self):
self.log.info("Not yet!")
self.log.debug("Work in progres...")
self.log.warning("Beware!")
self.log.error("I will crash now")
def run(self):
"""Starts the manager."""
self.drop_privileges(self.config.run_user, self.config.run_group)
while True:
self.node_poller.poll()
sleep(self.config.poll_interval_in_sec)
def _parse_arguments(args):
@ -52,7 +59,7 @@ def _parse_arguments(args):
"-f", "--log-facility",
default=DEFAULT_LOG_FACILITY,
help="syslog facility to use or 'none' to disable syslog logging " +
"(default: %s)" % DEFAULT_LOG_FACILITY)
"(default: %s)" % DEFAULT_LOG_FACILITY)
parser.add_argument(
"--config",
default=DEFAULT_CONFIG,
@ -61,4 +68,4 @@ def _parse_arguments(args):
if __name__ == "__main__":
Manager(None).start()
Manager(None).run()

View File

@ -1,7 +1,13 @@
# -*- coding: utf-8 -*-
# no-pylint: disable=missing-docstring,too-many-instance-attributes
import os
from pgbouncemgr.config import InvalidConfigValue
class NodeConfig():
"""NodeConfig holds the configuration for a single PostgreSQL node
in the PostgreSQL cluster."""
def __init__(self, node_id):
self.node_id = node_id
self._pgbouncer_config = None
@ -26,7 +32,7 @@ class NodeConfig():
def export(self):
"""Exports the data for the node configuration, that we want
to end up in the state data."""
to end up in the state data that is stored in the state store."""
return {
"pgbouncer_config": self.pgbouncer_config,
"host": self.host,

View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
# no-pylint: disable=missing-docstring
class NodePoller():
"""The NodePoller is used to poll all the nodes that are available
in the state object, and to update their status according to
the results."""
def __init__(self, state):
self.state = state
def poll(self):
for node in self.state.nodes.values():
print(repr(node.config))

271
pgbouncemgr/postgres.py Normal file
View File

@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-
# no-pylint: disable=missing-docstring,no-self-use,broad-except
"""This module provides various classes that uare used for connecting to
a PostgreSQL or pgbouncer server."""
import psycopg2
import multiprocessing
from psycopg2.extras imoprt LogicalReplicationConnection
from pgbouncemgr.logger import format_ex
class PgException(Exception):
"""Used for all exceptions that are raised from pgbouncemgr.postgres."""
class PgConnectionFailed(PgException):
"""Raised when connecting to the database server failed."""
def __init__(self, exception):
super().__init__(
"Could not connect to %s: %s" % (format_ex(exception)))
class RetrievingPgReplicationStatusFailed(PgException):
"""Raised when the replication status cannot be determined."""
class ReloadingPgbouncerFailed(PgException):
"""Raised when reloading the pgbouncer configuration fails."""
class ConnectedToWrongBackend(PgException):
"""Raised when the pgbouncer instance is not connected to the
correct PostgreSQL backend service."""
def __init__(self, msg):
super().__init__(
"The pgbouncer is not connected to the expected PostgreSQL " +
"backend service: %s" % msg)
# Return values for the PgConnection.connect() method.
CONNECTED = 'CONNECTED'
REUSED = 'REUSED'
RECONNECTED = 'RECONNECTED'
class PgConnection():
"""Implements a connection to a PostgreSQL server."""
def __init__(self, config):
self.conn_params = self._create_conn_params(config)
self.ping_query = "SELECT 1"
self.conn = None
def _create_conn_params(self, config):
"""Use only connection parameters that don't have value None."""
return dict(
(k, v) for k, v in config.items()
if v is not None)
def connect(self):
"""Connect to the database server. When a connection exists,
then check if it is still oeprational. If yes, then reuse
this connection. If no, or when no connection exists, then
setup a new connection.
Raises an exeption when the database connection cannot be setup.
returns CONNECTED, REUSED or RECONNECTED when the connection
was setup successfully."""
reconnected = False
if self.conn is not None:
try:
with self.conn.cursor() as cursor:
cursor.execute(self.ping_query)
return REUSED
except psycopg2.OperationalError:
reconnected = True
self.disconnect()
try:
self.conn = psycopg2.connect(**self.conn_params)
return RECONNECTED if reconnected else CONNECTED
except psycopg2.OperationalError as exception:
self.disconnect()
raise PgConnectionFailed(exception)
def disconnect(self):
"""Disconnect from the database server."""
try:
if self.conn:
self.conn.close()
except Exception:
pass
self.conn = None
# Return values for the PgReplicationConnection status.
OFFLINE = "OFFLINE"
PRIMARY = "PRIMARY"
STANDBY = "STANDBY"
class PgReplicationConnection(PgConnection):
"""This PostgresQL connection class is used to setup a replication
connection to a PostgreSQL database server, which can be used
to retrieve the replication status for the server."""
def __init__(self, node_config):
super().__init__(node_config)
self.conn_params["connection_factory"] = LogicalReplicationConnection
def get_replication_status(self):
"""Returns the replication status for a node. This is an array,
containing the keys "status" (OFFLINE, PRIMARY or STANDBY),
"system_id" and "timeline_id"."""
status = {
"status": None,
"system_id": None,
"timeline_id": None
}
# Try to connect to the node. If this fails, the node is OFFLINE.
try:
self.connect()
except PgConnectionFailed:
status["status"] = OFFLINE
return status
# Check if the node is running in primary or standby mode.
try:
with self.conn.cursor() as cursor:
cursor.execute("SELECT pg_is_in_recovery()")
in_recovery = cursor.fetchone()[0]
status["status"] = STANDBY if in_recovery else PRIMARY
except psycopg2.InternalError as exception:
self.disconnect()
raise RetrievingPgReplicationStatusFailed(
"SELECT pg_is_in_recovery() failed: %s" % format_ex(exception))
# Retrieve system_id and timeline_id.
try:
with self.conn.cursor() as cursor:
cursor.execute("IDENTIFY_SYSTEM")
row = cursor.fetchone()
system_id, timeline_id, *_ = row
status["system_id"] = system_id
status["timeline_id"] = timeline_id
except psycopg2.InternalError as exception:
self.disconnect()
raise RetrievingPgReplicationStatusFailed(
"IDENTIFY_SYSTEM failed: %s" % format_ex(exception))
return status
class PgConnectionViaPgbouncer(PgConnection):
"""This PostgreSQL connection class is used to setup a connection
to the PostgreSQL cluster, via the pgbouncer instance."""
def __init__(self, node_config, pgbouncer_config):
"""Instantiate a new connection. The node_config and the
pgbouncer_config will be combined to get the connection parameters
for connecting to the PostgreSQL server."""
self.node_config = node_config
# First, apply all the connection parameters as defined for the node.
# This is fully handled by the parent class.
super().__init__(node_config)
# Secondly, override parameters to redirect the connection to
# the pgbouncer instance.
self.conn_params["host"] = pgbouncer_config["host"]
self.conn_params["port"] = pgbouncer_config["port"]
# Note that we don't setup a replication connection here. This is
# unfortunately not possible, because pgbouncer does not support
# this type of connection. If this would ever become possible in
# pgbouncer, I will definitely switch to such connection, since it
# allows for doing some extra checks.
def verify_connection(self):
"""Check if the connection via pgbouncer ends up with the
configured node."""
# This is done in a somewhat convoluted way with a subprocess and a
# timer. This is done, because a connection is made through
# pgbouncer, and pgbouncer will try really hard to connect the user
# to its known backend service. When that service is unavailable,
# then pgbouncer accepts the connection, but the communication will
# then stall for quite a while. For swift operation of backend
# switching, we therefore use the timeout setup here, to reconfigure
# the system more quickly.
#
# Note that in most situations this shouldn't be an issue, since we
# will always reload pgbouncer after a configuration change and the
# connection from below will, because of that, work as intended
# right away. We must be prepared for the odd case out though, since
# we're going for HA here."""
def check_func(report_func):
# Setup the database connection
try:
self.connect()
except Exception as exception:
return report_func(False, exception)
# Check if we're connected to the requested node.
with self.conn.cursor() as cursor:
try:
cursor.execute(VERIFY_QUERY, {
"host": self.node_config["host"],
"port": self.node_config["port"]
})
result = cursor.fetchone()[0]
self.disconnect()
if result is not None:
raise ConnectedToWrongBackend(result)
except Exception as exception:
self.disconnect()
return report_func(False, exception)
# When the verify query did not return an error message, then we
# are in the green.
return report_func(True, None)
parent_conn, child_conn = multiprocessing.Pipe()
def report_func(true_or_false, exception):
child_conn.send([true_or_false, exception])
child_conn.close()
proc = multiprocessing.Process(target=check_func, args=(report_func,))
proc.start()
proc.join(self.node_config["connect_timeout"])
if proc.is_alive():
proc.terminate()
proc.join()
return (False, PgConnectionFailed("Connection attempt timed out"))
result = parent_conn.recv()
proc.join()
return result
class PgBouncerConsoleConnection(PgConnection):
"""This PostgreSQL connection class is used to setup a console
connection to a pgbouncer server. This kind of connection can be
used to control the pgbouncer instance via admin commands.
This connection is used by pgbouncemgr to reload the configuration
of pgbouncer when the cluster state changes."""
def __init__(self, pgbouncer_config):
super().__init__(pgbouncer_config)
# For the console connection, the database name "pgbouncer"
# must be used.
self.conn_params["dbname"] = "pgbouncer"
# The default ping query does not work when connected to the
# pgbouncer console. Here's a simple replacement for it.
self.ping_query = "SHOW VERSION"
def connect(self):
"""Connect to the pgbouncer console. After connecting, the autocommit
feature will be disabled for the connection, so the underlying
PostgreSQL client library won't automatically try to setup a
transaction. Transactions are not supported by the pgbouncer
console."""
result = super().connect()
self.conn.autocommit = True
return result
def reload(self):
"""Send the 'RELOAD' command to the pgbouncer console, in order to
reload the configuration file."""
try:
self.connect()
with self.conn.cursor() as cursor:
cursor.execute('RELOAD')
if cursor.statusmessage == 'RELOAD':
return
raise ReloadingPgbouncerFailed(
"Unexpected status message: %s" % cursor.statusmessage)
except Exception as exception:
raise ReloadingPgbouncerFailed(
"An exception occurred: %s" % format_ex(exception))

View File

@ -109,8 +109,8 @@ class InvalidNodeStatus(StateException):
class State():
@staticmethod
def fromConfig(config, logger):
state = State(logger)
def from_config(config, logger):
state = State(logger)
for node_id, settings in config.nodes.items():
node_config = NodeConfig(node_id)
for k, v in settings.items():
@ -237,7 +237,7 @@ class State():
@leader_node_id.setter
def leader_node_id(self, leader_node_id):
"""Sets the id of the leader node."""
try:
try:
node = self.get_node(leader_node_id)
except UnknownNodeRequested:
self.log.warning(

View File

@ -11,7 +11,7 @@ from pgbouncemgr.state import \
class StateStoreException(Exception):
"""Used for all exceptions that are raised from pgbouncemgr.state_store."""
"""Used for all exceptions that are raised from pgbouncemgr.state_store."""
class StateStore():
@ -42,10 +42,10 @@ class StateStore():
# Copy the state over to the state object.
for key in [
"system_id",
"timeline_id",
"active_pgbouncer_config",
"leader_node_id"]:
"system_id",
"timeline_id",
"active_pgbouncer_config",
"leader_node_id"]:
if key in loaded_state:
try:
setattr(self.state, key, loaded_state[key])
@ -65,13 +65,11 @@ class StateStore():
The error can be found in the err property."""
new_state = json.dumps(self.state.export(), sort_keys=True, indent=2)
try:
self.err = None
swap_path = "%s..SWAP" % self.path
with open(swap_path, "w") as file_handle:
print(new_state, file=file_handle)
rename(swap_path, self.path)
return True
except Exception as exception:
self.err = "Storing state to file (%s) failed: %s" % (
self.path, format_ex(exception))
return False
raise StateStoreException(
"Storing state to file (%s) failed: %s" % (
self.path, format_ex(exception)))

View File

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
import unittest
from os import geteuid
from pgbouncemgr.drop_privileges import *
class DropPrivilegesTests(unittest.TestCase):
def test_givenKnownUsername_GetUid_ReturnsUid(self):
user, uid = get_uid('daemon')
self.assertEqual('daemon', user)
self.assertEqual(1, uid)
def test_givenUnknownUsername_GetUid_RaisesException(self):
with self.assertRaises(DropPrivilegesException) as context:
get_uid("nobodier_than_ever")
self.assertIn("name not found", str(context.exception))
self.assertIn("nobodier_than_ever", str(context.exception))
def test_givenKnownUserId_GetUid_ReturnsUserAndUid(self):
user, uid = get_uid(1)
self.assertEqual('daemon', user)
self.assertEqual(1, uid)
def test_givenUnknownUserId_GetUid_RaisesException(self):
with self.assertRaises(DropPrivilegesException) as context:
get_uid(22222)
self.assertIn("Invalid run user: 22222", str(context.exception))
self.assertIn("uid not found: 22222", str(context.exception))
def test_givenKownGroupName_GetGid_ReturnsGid(self):
group, gid = get_gid('daemon')
self.assertEqual('daemon', group)
self.assertEqual(gid, 1)
def test_givenUnknownGroupName_GetGid_RaisesException(self):
with self.assertRaises(DropPrivilegesException) as context:
get_uid("groupies_are_none")
self.assertIn("name not found", str(context.exception))
self.assertIn("groupies_are_none", str(context.exception))
def test_givenKnownGid_GetGid_ReturnsGid(self):
group, gid = get_gid(2)
self.assertEqual(2, gid)
self.assertEqual('bin', group)
def test_givenUnknownGid_GetGid_RaisesException(self):
with self.assertRaises(DropPrivilegesException) as context:
get_gid(33333)
self.assertIn("Invalid run group: 33333", str(context.exception))
self.assertIn("gid not found: 33333", str(context.exception))
def test_givenProblem_DropPrivileges_RaisesException(self):
if geteuid() > 0:
with self.assertRaises(DropPrivilegesException) as context:
drop_privileges(1, 0)
self.assertIn("Operation not permitted", str(context.exception))
else:
# Root is allowed to change the uid/gid, so in case the
# tests are run as the root user, use an alternative error
# scenario.
with self.assertRaises(DropPrivilegesException) as context:
drop_privileges(22222, 0)
self.assertIn("uid not found: 22222", str(context.exception))
def test_givenProblem_DropPrivileges_RaisesException(self):
if geteuid() > 0:
with self.assertRaises(DropPrivilegesException) as context:
drop_privileges(1, 0)
self.assertIn("Operation not permitted", str(context.exception))
else:
# Root is allowed to change the uid/gid, so in case the
# tests are run as the root user, use an alternative error
# scenario.
with self.assertRaises(DropPrivilegesException) as context:
drop_privileges(22222, 0)
self.assertIn("uid not found: 22222", str(context.exception))

View File

@ -15,8 +15,8 @@ class LoggerTests(unittest.TestCase):
def test_Logger_WritesToAllTargets(self):
logger = Logger()
logger.append(MemoryLogTarget())
logger.append(MemoryLogTarget())
logger.append(MemoryLog())
logger.append(MemoryLog())
send_logs(logger)
self.assertEqual([
@ -36,30 +36,30 @@ class LoggerTests(unittest.TestCase):
class SyslogLargetTests(unittest.TestCase):
def test_GivenInvalidFacility_ExceptionIsRaised(self):
with self.assertRaises(SyslogLogTargetException) as context:
SyslogLogTarget("my app", "LOG_WRONG")
with self.assertRaises(SyslogLogException) as context:
SyslogLog("my app", "LOG_WRONG")
self.assertIn("Invalid syslog facility provided", str(context.exception))
self.assertIn("'LOG_WRONG'", str(context.exception))
def test_GivenValidFacility_LogTargetIsCreated(self):
SyslogLogTarget("my app", "LOG_LOCAL0")
SyslogLog("my app", "LOG_LOCAL0")
class ConsoleLogTargetTests(unittest.TestCase):
def test_CanCreateSilentConsoleLogger(self):
console = ConsoleLogTarget(False, False)
console = ConsoleLog(False, False)
self.assertFalse(console.verbose_enabled)
self.assertFalse(console.debug_enabled)
def test_CanCreateVerboseConsoleLogger(self):
console = ConsoleLogTarget(True, False)
console = ConsoleLog(True, False)
self.assertTrue(console.verbose_enabled)
self.assertFalse(console.debug_enabled)
def test_CanCreateDebuggingConsoleLogger(self):
console = ConsoleLogTarget(False, True)
console = ConsoleLog(False, True)
self.assertTrue(console.verbose_enabled)
self.assertTrue(console.debug_enabled)
@ -67,7 +67,7 @@ class ConsoleLogTargetTests(unittest.TestCase):
def test_CanCreateVerboseDebuggingConsoleLogger(self):
"""Basically the same as a debugging console logger,
since the debug flag enables the verbose flag as well."""
console = ConsoleLogTarget(True, True)
console = ConsoleLog(True, True)
self.assertTrue(console.verbose_enabled)
self.assertTrue(console.debug_enabled)

View File

@ -35,7 +35,7 @@ class ManagerTests(unittest.TestCase):
# Check the logging setup.
self.assertEqual(1, len(mgr.log))
console = mgr.log[0]
self.assertEqual('ConsoleLogTarget', type(console).__name__)
self.assertEqual('ConsoleLog', type(console).__name__)
self.assertTrue(console.verbose_enabled)
self.assertTrue(console.debug_enabled)

View File

@ -10,9 +10,8 @@ from pgbouncemgr.state_store import *
def make_test_file_path(filename):
return os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"testfiles", filename)
testfiles_path = os.path.dirname(os.path.realpath(__file__))
return os.path.join(testfiles_path, "testfiles", filename)
class StateStoreTests(unittest.TestCase):
def test_GivenNonExistingStateFile_OnLoad_StateStoreDoesNotLoadState(self):
@ -60,3 +59,9 @@ class StateStoreTests(unittest.TestCase):
if tmpfile and os.path.exists(tmpfile.name):
os.unlink(tmpfile.name)
def test_GivenError_OnSave_ExceptionIsRaised(self):
state = State(Logger())
with self.assertRaises(StateStoreException) as context:
StateStore("/tmp/path/that/does/not/exist/fofr/statefile", state).store()
self.assertIn("Storing state to file", str(context.exception))
self.assertIn("failed: FileNotFoundError", str(context.exception))