Skip to content

Commit

Permalink
Merge pull request #627 from hnousiainen/htn_local_webserver_user_aut…
Browse files Browse the repository at this point in the history
…hentication

webserver: support authentication on the webserver
  • Loading branch information
alexole authored Sep 21, 2024
2 parents 6f077c1 + 8c1c2b0 commit 69d1115
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 10 deletions.
12 changes: 10 additions & 2 deletions golang/pghoard_postgres_command_go.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func run() (int, error) {
verPtr := flag.Bool("version", false, "show program version")
hostPtr := flag.String("host", PGHOARD_HOST, "pghoard service host")
portPtr := flag.Int("port", PGHOARD_PORT, "pghoard service port")
usernamePtr := flag.String("username", "", "pghoard service username")
passwordPtr := flag.String("password", "", "pghoard service password")
sitePtr := flag.String("site", "", "pghoard backup site")
xlogPtr := flag.String("xlog", "", "xlog file name")
outputPtr := flag.String("output", "", "output file")
Expand Down Expand Up @@ -82,7 +84,7 @@ func run() (int, error) {
retry_seconds := *riPtr
for {
attempt += 1
rc, err := restore_command(url, *outputPtr, *xlogPtr)
rc, err := restore_command(url, *outputPtr, *xlogPtr, *usernamePtr, *passwordPtr)
if rc != EXIT_RESTORE_FAIL {
return rc, err
}
Expand All @@ -101,13 +103,16 @@ func archive_command(url string) (int, error) {
return EXIT_ABORT, errors.New("archive_command not yet implemented")
}

func restore_command(url string, output string, xlog string) (int, error) {
func restore_command(url string, output string, xlog string, username string, password string) (int, error) {
var output_path string
var req *http.Request
var err error

if output == "" {
req, err = http.NewRequest("HEAD", url, nil)
if username != "" && password != "" {
req.SetBasicAuth(username, password)
}
} else {
/* Construct absolute path for output - postgres calls this command with a relative path to its xlog
directory. Note that os.path.join strips preceding components if a new components starts with a
Expand Down Expand Up @@ -136,6 +141,9 @@ func restore_command(url string, output string, xlog string) (int, error) {
}
req, err = http.NewRequest("GET", url, nil)
req.Header.Set("x-pghoard-target-path", output_path)
if username != "" && password != "" {
req.SetBasicAuth(username, password)
}
}

client := &http.Client{}
Expand Down
1 change: 1 addition & 0 deletions pghoard.spec
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ BuildRequires: python3-devel
BuildRequires: python3-flake8
BuildRequires: python3-pylint
BuildRequires: python3-pytest
BuildRequires: systemd

%undefine _missing_build_ids_terminate_build
%define debug_package %{nil}
Expand Down
2 changes: 2 additions & 0 deletions pghoard/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def set_and_check_config_defaults(config, *, check_commands=True, check_pgdata=T
config.setdefault("backup_location", None)
config.setdefault("http_address", PGHOARD_HOST)
config.setdefault("http_port", PGHOARD_PORT)
config.setdefault("webserver_username", None)
config.setdefault("webserver_password", None)
config.setdefault("alert_file_dir", config.get("backup_location") or os.getcwd())
config.setdefault("json_state_file_path", "/var/lib/pghoard/pghoard_state.json")
config.setdefault("maintenance_mode_file", "/var/lib/pghoard/maintenance_mode_file")
Expand Down
7 changes: 5 additions & 2 deletions pghoard/object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Optional

from requests import Session
from requests.auth import HTTPBasicAuth
from rohmu import dates


Expand Down Expand Up @@ -72,14 +73,16 @@ def get_file_bytes(self, name):


class HTTPRestore(ObjectStore):
def __init__(self, host, port, site, pgdata=None):
def __init__(self, host, port, site, pgdata=None, *, username=None, password=None):
super().__init__(storage=None, prefix=None, site=site, pgdata=pgdata)
self.host = host
self.port = port
self.session = Session()
if username and password:
self.session.auth = HTTPBasicAuth(username, password)

def _url(self, path):
return "http://{host}:{port}/{site}/{path}".format(host=self.host, port=self.port, site=self.site, path=path)
return f"http://{self.host}:{self.port}/{self.site}/{path}"

def list_basebackups(self):
response = self.session.get(self._url("basebackup"))
Expand Down
14 changes: 12 additions & 2 deletions pghoard/postgres_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import argparse
import base64
import os
import socket
import sys
Expand Down Expand Up @@ -45,10 +46,17 @@ def __init__(self, message, exit_code=EXIT_FAIL):
self.exit_code = exit_code


def http_request(host, port, method, path, headers=None):
def http_request(host, port, method, path, headers=None, *, username=None, password=None):
conn = HTTPConnection(host=host, port=port)
if headers is not None:
headers = headers.copy()
else:
headers = {}
if username and password:
auth_str = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode()
headers["Authorization"] = f"Basic {auth_str}"
try:
conn.request(method, path, headers=headers or {})
conn.request(method, path, headers=headers)
resp = conn.getresponse()
finally:
conn.close()
Expand Down Expand Up @@ -112,6 +120,8 @@ def main(args=None):
parser.add_argument("--version", action="version", help="show program version", version=version.__version__)
parser.add_argument("--host", type=str, default=PGHOARD_HOST, help="pghoard service host")
parser.add_argument("--port", type=int, default=PGHOARD_PORT, help="pghoard service port")
parser.add_argument("--username", type=str, help="pghoard service username")
parser.add_argument("--password", type=str, help="pghoard service password")
parser.add_argument("--site", type=str, required=True, help="pghoard backup site")
parser.add_argument("--xlog", type=str, required=True, help="xlog file name")
parser.add_argument("--output", type=str, help="output file")
Expand Down
19 changes: 16 additions & 3 deletions pghoard/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def create_recovery_conf(
site,
*,
port=PGHOARD_PORT,
webserver_username=None,
webserver_password=None,
primary_conninfo=None,
recovery_end_command=None,
recovery_target_action=None,
Expand All @@ -113,6 +115,13 @@ def create_recovery_conf(
"--xlog",
"%f",
]
if webserver_username and webserver_password:
restore_command.extend([
"--username",
webserver_username,
"--password",
webserver_password,
])
with open(os.path.join(dirpath, "PG_VERSION"), "r") as fp:
v = Version(fp.read().strip())
pg_version = v.major if v.major >= 10 else float(f"{v.major}.{v.minor}")
Expand Down Expand Up @@ -213,9 +222,11 @@ def generic_args(require_config=True, require_site=False):

cmd.add_argument("--site", help="pghoard site", required=require_site)

def host_port_args():
def host_port_user_args():
cmd.add_argument("--host", help="pghoard repository host", default=PGHOARD_HOST)
cmd.add_argument("--port", help="pghoard repository port", default=PGHOARD_PORT)
cmd.add_argument("--username", help="pghoard repository username")
cmd.add_argument("--password", help="pghoard repository password")

def target_args():
cmd.add_argument("--basebackup", help="pghoard basebackup", default="latest")
Expand Down Expand Up @@ -266,7 +277,7 @@ def target_args():
)

cmd = add_cmd(self.list_basebackups_http)
host_port_args()
host_port_user_args()
generic_args(require_config=False, require_site=True)

cmd = add_cmd(self.list_basebackups)
Expand All @@ -280,7 +291,7 @@ def target_args():

def list_basebackups_http(self, arg):
"""List available basebackups from a HTTP source"""
self.storage = HTTPRestore(arg.host, arg.port, arg.site)
self.storage = HTTPRestore(arg.host, arg.port, arg.site, username=arg.username, password=arg.password)
self.storage.show_basebackup_list(verbose=arg.verbose)

def _get_site_prefix(self, site):
Expand Down Expand Up @@ -609,6 +620,8 @@ def _get_basebackup(
dirpath=pgdata,
site=site,
port=self.config["http_port"],
webserver_username=self.config.get("webserver_username"),
webserver_password=self.config.get("webserver_password"),
primary_conninfo=primary_conninfo,
recovery_end_command=recovery_end_command,
recovery_target_action=recovery_target_action,
Expand Down
23 changes: 23 additions & 0 deletions pghoard/webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Copyright (c) 2016 Ohmu Ltd
See LICENSE for details
"""
import base64
import ipaddress
import logging
import os
Expand Down Expand Up @@ -242,6 +243,22 @@ class RequestHandler(BaseHTTPRequestHandler):
disable_nagle_algorithm = True
server_version = "pghoard/" + __version__
server: OwnHTTPServer
_expected_auth_header = None

def _authentication_check(self):
if self.server.config.get("webserver_username") and self.server.config.get("webserver_password"):
if self._expected_auth_header is None:
auth_data_raw = self.server.config["webserver_username"] + ":" + self.server.config["webserver_password"]
auth_data_b64 = base64.b64encode(auth_data_raw.encode("utf-8")).decode()
self._expected_auth_header = f"Basic {auth_data_b64}"
if self.headers.get("Authorization") != self._expected_auth_header:
self.send_response(401)
self.send_header("WWW-Authenticate", 'Basic realm="pghoard"')
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(b"Authentication required")
return False
return True

@contextmanager
def _response_handler(self, method):
Expand Down Expand Up @@ -645,12 +662,16 @@ def handle_archival_request(self, site, filename, filetype):
raise HttpResponse(status=201)

def do_PUT(self):
if not self._authentication_check():
return
with self._response_handler("PUT") as path:
site, obtype, obname = self._parse_request(path)
assert obtype in ("basebackup", "xlog", "timeline")
self.handle_archival_request(site, obname, obtype)

def do_HEAD(self):
if not self._authentication_check():
return
with self._response_handler("HEAD") as path:
site, obtype, obname = self._parse_request(path)
if self.headers.get("x-pghoard-target-path"):
Expand All @@ -664,6 +685,8 @@ def do_HEAD(self):
raise HttpResponse(status=200, headers=headers)

def do_GET(self):
if not self._authentication_check():
return
with self._response_handler("GET") as path:
site, obtype, obname = self._parse_request(path)
if obtype == "basebackup":
Expand Down
13 changes: 12 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ def fixture_pghoard(db, tmpdir, request):
yield from pghoard_base(db, tmpdir, request)


@pytest.fixture(name="pghoard_with_userauth")
def fixture_pghoard_with_userauth(db, tmpdir, request):
yield from pghoard_base(db, tmpdir, request, username="testuser", password="testpass")


@pytest.fixture(name="pghoard_ipv4_hostname")
def fixture_pghoard_ipv4_hostname(db, tmpdir, request):
yield from pghoard_base(db, tmpdir, request, listen_http_address="localhost")
Expand Down Expand Up @@ -362,7 +367,9 @@ def pghoard_base(
active_backup_mode="pg_receivexlog",
slot_name=None,
compression_count=None,
listen_http_address="127.0.0.1"
listen_http_address="127.0.0.1",
username=None,
password=None
):
test_site = request.function.__name__

Expand Down Expand Up @@ -418,6 +425,10 @@ def pghoard_base(
if compression_count is not None:
config["compression"]["thread_count"] = compression_count

if username is not None and password is not None:
config["webserver_username"] = username
config["webserver_password"] = password

confpath = os.path.join(str(tmpdir), "config.json")
with open(confpath, "w") as fp:
json.dump(config, fp)
Expand Down
61 changes: 61 additions & 0 deletions test/test_webserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Copyright (c) 2015 Ohmu Ltd
See LICENSE for details
"""
import base64
import json
import logging
import os
Expand Down Expand Up @@ -40,6 +41,19 @@ def http_restore(pghoard):
return HTTPRestore("localhost", pghoard.config["http_port"], site=pghoard.test_site, pgdata=pgdata)


@pytest.fixture
def http_restore_with_userauth(pghoard_with_userauth):
pgdata = get_pg_wal_directory(pghoard_with_userauth.config["backup_sites"][pghoard_with_userauth.test_site])
return HTTPRestore(
"localhost",
pghoard_with_userauth.config["http_port"],
site=pghoard_with_userauth.test_site,
pgdata=pgdata,
username=pghoard_with_userauth.config["webserver_username"],
password=pghoard_with_userauth.config["webserver_password"]
)


class TestWebServer:
def test_requesting_status(self, pghoard):
pghoard.write_backup_state_to_json_file()
Expand Down Expand Up @@ -774,6 +788,53 @@ def test_uncontrolled_target_path(self, pghoard):
status = conn.getresponse().status
assert status == 400

def test_requesting_status_with_user_authentiction(self, pghoard_with_userauth):
pghoard_with_userauth.write_backup_state_to_json_file()
conn = HTTPConnection(host="127.0.0.1", port=pghoard_with_userauth.config["http_port"])
conn.request("GET", "/status")
response = conn.getresponse()
assert response.status == 401

username = pghoard_with_userauth.config["webserver_username"]
password = pghoard_with_userauth.config["webserver_password"]
auth_str = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode()
headers = {"Authorization": f"Basic {auth_str}"}

conn = HTTPConnection(host="127.0.0.1", port=pghoard_with_userauth.config["http_port"])
conn.request("GET", "/status", headers=headers)
response = conn.getresponse()
assert response.status == 200

response_parsed = json.loads(response.read().decode("utf-8"))
# "startup_time": "2016-06-23T14:53:25.840787",
assert response_parsed["startup_time"] is not None

conn.request("GET", "/status/somesite", headers=headers)
response = conn.getresponse()
assert response.status == 400

conn.request("GET", "/somesite/status", headers=headers)
response = conn.getresponse()
assert response.status == 404

conn.request("GET", "/{}/status".format(pghoard_with_userauth.test_site), headers=headers)
response = conn.getresponse()
assert response.status == 501

def test_basebackups_with_user_authentication(self, capsys, db, http_restore_with_userauth, pghoard_with_userauth): # pylint: disable=redefined-outer-name
final_location = self._run_and_wait_basebackup(pghoard_with_userauth, db, "pipe")
backups = http_restore_with_userauth.list_basebackups()
assert len(backups) == 1
assert backups[0]["size"] > 0
assert backups[0]["name"] == os.path.join(
pghoard_with_userauth.test_site, "basebackup", os.path.basename(final_location)
)
# make sure they show up on the printable listing, too
http_restore_with_userauth.show_basebackup_list()
out, _ = capsys.readouterr()
assert "{} MB".format(int(backups[0]["metadata"]["original-file-size"]) // (1024 ** 2)) in out
assert backups[0]["name"] in out


@pytest.fixture(name="download_results_processor")
def fixture_download_results_processor() -> DownloadResultsProcessor:
Expand Down

0 comments on commit 69d1115

Please sign in to comment.