Add pytest test suite
Some checks failed
/ build (push) Failing after 5s

174 tests covering URL normalization, FTS5 query sanitization, SSRF/CSRF
guards, sharing-mode logic, DB schema and upsert paths, handler
end-to-end flows, and gateway body-size / mesh-whitelist guards. Each
recent bug-fix commit (6ffd38d, 1bc695f, 8dffd8c) has an explicit
regression test in test_regressions.py. One xfail documents a minor
latent bug in clean_url where port 80 is not stripped from upgraded
https URLs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Derick Phan 2026-04-24 15:03:29 -07:00
parent 8dffd8ccea
commit 44a16dea98
No known key found for this signature in database
18 changed files with 1673 additions and 0 deletions

128
conftest.py Normal file
View file

@ -0,0 +1,128 @@
"""Shared pytest fixtures for TinyWeb tests.
Three fixtures cover most tests: `temp_db` swaps the SQLite path to a
per-test tempfile, `seeded_db` layers sample rows on top, and `csrf_session`
primes the thread-local CSRF token that handlers read.
"""
import socket
import sys
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent))
import db as db_module
import handlers as handlers_module
@pytest.fixture
def temp_db(tmp_path, monkeypatch):
    """Give each test its own freshly initialised SQLite database.

    `db.DATA_DIR` and `db.DATABASE` are pointed at a directory under
    `tmp_path`, the connection pool is emptied before and after the test so
    no connection leaks between tests, and `init_db()` is run so the schema
    exists. Yields the path of the database file.
    """

    def _drain_pool():
        # Close every pooled connection (ignoring close errors on purpose)
        # and forget them all.
        with db_module._pool_lock:
            for pooled in db_module._pool:
                try:
                    pooled.close()
                except Exception:
                    pass
            db_module._pool.clear()

    sandbox = tmp_path / "tinyweb"
    sandbox.mkdir()
    db_file = sandbox / "index.db"
    monkeypatch.setattr(db_module, "DATA_DIR", str(sandbox))
    monkeypatch.setattr(db_module, "DATABASE", str(db_file))

    _drain_pool()
    db_module.init_db()
    yield db_file
    _drain_pool()
@pytest.fixture
def seeded_db(temp_db):
    """Fill the temporary database with four sample pages, their tags and one link.

    Returns the same value that `temp_db` provides.
    """
    insert_page_sql = (
        "INSERT INTO pages (url, title, body, note, last_modified) "
        "VALUES (?, ?, ?, ?, '2026-04-01T00:00:00')"
    )
    sample_pages = [
        ("https://example.com/rust-intro", "Rust Intro", "A gentle introduction to rust borrow checker.", "notes on ownership"),
        ("https://example.com/python-tips", "Python Tips", "Daily python tricks for readable code.", ""),
        ("https://example.com/ocaml-why", "Why OCaml", "Type systems and inference in ocaml.", "private thoughts"),
        ("https://news.example.org/mesh", "Mesh Networking", "Reticulum and LoRa for decentralized networks.", ""),
    ]
    sample_tags = [
        ("https://example.com/rust-intro", ["rust", "public"]),
        ("https://example.com/python-tips", ["python"]),
        ("https://example.com/ocaml-why", ["ocaml", "private"]),
        ("https://news.example.org/mesh", ["mesh", "public"]),
    ]

    conn = db_module.get_db()
    try:
        for page in sample_pages:
            conn.execute(insert_page_sql, page)
        conn.commit()

        # Map each URL to the row id it was assigned.
        id_by_url = {}
        for record in conn.execute("SELECT id, url FROM pages").fetchall():
            id_by_url[record["url"]] = record["id"]

        resolved_tags = [(id_by_url[page_url], names) for page_url, names in sample_tags]
        for page_id, names in resolved_tags:
            for tag in names:
                conn.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (tag,))
                tag_id = conn.execute(
                    "SELECT id FROM tags WHERE name = ?", (tag,)
                ).fetchone()[0]
                conn.execute(
                    "INSERT OR IGNORE INTO page_tags (page_id, tag_id) VALUES (?, ?)",
                    (page_id, tag_id),
                )

        # One outgoing link, attached to the rust-intro page.
        conn.execute(
            "INSERT INTO links (page_id, url, label) VALUES (?, ?, ?)",
            (id_by_url["https://example.com/rust-intro"], "https://example.com/rust-advanced", "advanced rust guide"),
        )
        conn.commit()
    finally:
        db_module.return_db(conn)
    return temp_db
@pytest.fixture
def csrf_session(monkeypatch):
    """Install a fixed CSRF token on the handlers' thread-local for one test.

    Yields the token and removes it again afterwards if it is still set.
    """
    fixed_token = "test-csrf-token"
    handlers_module._request_local.csrf_token = fixed_token
    yield fixed_token
    try:
        del handlers_module._request_local.csrf_token
    except AttributeError:
        # Already gone; nothing to clean up.
        pass
def patch_dns_fail(monkeypatch):
    """Replace `socket.getaddrinfo` with a stub that always raises `socket.gaierror`."""

    def _always_fail(*args, **kwargs):
        raise socket.gaierror("test: DNS disabled")

    monkeypatch.setattr(socket, "getaddrinfo", _always_fail)
def patch_dns_ok(monkeypatch, address="93.184.216.34"):
    """Replace `socket.getaddrinfo` with a stub that resolves every host to `address`.

    The stub returns a single IPv4 entry; a falsy port is reported as 80.
    """
    monkeypatch.setattr(
        socket,
        "getaddrinfo",
        lambda host, port, *args, **kwargs: [
            (socket.AF_INET, socket.SOCK_STREAM, 0, "", (address, port or 80))
        ],
    )
def patch_dns_private(monkeypatch, address="127.0.0.1"):
    """Replace `socket.getaddrinfo` with a stub that resolves every host to `address`.

    The default address is loopback, standing in for a private/blocked IP.
    """

    def _resolve_to_private(host, port, *args, **kwargs):
        sockaddr = (address, port or 80)
        entry = (socket.AF_INET, socket.SOCK_STREAM, 0, "", sockaddr)
        return [entry]

    monkeypatch.setattr(socket, "getaddrinfo", _resolve_to_private)

5
pytest.ini Normal file
View file

@ -0,0 +1,5 @@
[pytest]
testpaths = tests
python_files = test_*.py
filterwarnings =
ignore::DeprecationWarning

2
requirements-dev.txt Normal file
View file

@ -0,0 +1,2 @@
-r requirements.txt
pytest

60
tests/test_csrf.py Normal file
View file

@ -0,0 +1,60 @@
"""Tests for `_check_csrf` — form-submission CSRF protection.
Every POST handler calls this to verify the submitted _csrf field matches
the token stored in the thread-local (which is seeded from the cookie by
`dispatch_request`). Missing or mismatched tokens must fail closed.
"""
import handlers as handlers_module
from handlers import _check_csrf, _csrf_field, _get_csrf_token
def _set_token(token):
    """Store `token` as the server-side CSRF token on the handlers' thread-local."""
    setattr(handlers_module._request_local, "csrf_token", token)
def _clear_token():
    """Remove the server-side CSRF token from the thread-local, if one is set."""
    try:
        del handlers_module._request_local.csrf_token
    except AttributeError:
        pass
def teardown_function(_):
    """Per-test teardown hook: drop any CSRF token a test left on the thread-local."""
    _clear_token()
def test_rejects_missing_token_in_body():
    """A form body without any `_csrf` field is rejected."""
    _set_token("server-side-token")
    form = {}
    assert _check_csrf(form) is False
def test_rejects_empty_token_in_body():
    """An empty `_csrf` value is rejected even though the field is present."""
    _set_token("server-side-token")
    form = {"_csrf": [""]}
    assert _check_csrf(form) is False
def test_rejects_mismatched_token():
    """A submitted token that differs from the server-side one is rejected."""
    _set_token("server-side-token")
    form = {"_csrf": ["attacker-token"]}
    assert _check_csrf(form) is False
def test_accepts_matching_token():
    """A submitted token equal to the server-side one is accepted."""
    expected = "server-side-token"
    _set_token(expected)
    assert _check_csrf({"_csrf": [expected]}) is True
def test_rejects_when_server_token_missing():
    """With no server-side token set, any submitted token must be refused.

    This shouldn't happen once dispatch_request has seeded the token, but the
    check is expected to fail closed regardless.
    """
    _clear_token()
    form = {"_csrf": ["anything"]}
    assert _check_csrf(form) is False
def test_csrf_field_renders_current_token():
    """The rendered hidden field carries the `_csrf` name and the current token."""
    _set_token("abc123")
    markup = _csrf_field()
    for fragment in ('name="_csrf"', 'value="abc123"'):
        assert fragment in markup
def test_get_csrf_token_returns_empty_when_unset():
    """Without a stored token, `_get_csrf_token` yields the empty string."""
    _clear_token()
    current = _get_csrf_token()
    assert current == ""

155
tests/test_db_index_url.py Normal file
View file

@ -0,0 +1,155 @@
"""Tests for `index_url` — the main write path.
Covers UPSERT behavior, links being replaced on re-index, FTS index staying
in sync via triggers, and the connection pool returning clean connections.
"""
from unittest.mock import patch
from conftest import patch_dns_ok
import db as db_module
from db import get_db, return_db, index_url
def _mock_fetch_page(title="Test Page", body="test body text", links=None, meta=""):
    """Build a stand-in for `db.fetch_page` that ignores the URL and returns canned data.

    The stand-in returns the tuple `(title, body, links, meta)`; a falsy
    `links` is replaced by an empty list.
    """
    canned_links = links or []
    return lambda url: (title, body, canned_links, meta)
def test_insert_creates_page_row_and_fts_entry(temp_db, monkeypatch):
    """Indexing a new URL stores the page row and makes it findable via `pages_fts`."""
    patch_dns_ok(monkeypatch)
    fake_fetch = _mock_fetch_page(
        title="Rust Intro", body="ownership and borrowing basics", links=[],
    )
    monkeypatch.setattr(db_module, "fetch_page", fake_fetch)
    index_url("https://example.com/rust")

    conn = get_db()
    try:
        page = conn.execute("SELECT id, title, body FROM pages").fetchone()
        assert page is not None
        assert page["title"] == "Rust Intro"
        assert "ownership" in page["body"]
        # The full-text index must contain exactly this row.
        matches = conn.execute(
            "SELECT rowid FROM pages_fts WHERE pages_fts MATCH 'ownership*'"
        ).fetchall()
        assert len(matches) == 1
        assert matches[0]["rowid"] == page["id"]
    finally:
        return_db(conn)
def test_re_indexing_same_url_updates_in_place(temp_db, monkeypatch):
    """Indexing one URL twice leaves a single row holding the newer title."""
    patch_dns_ok(monkeypatch)
    target = "https://example.com/page"
    versions = [
        _mock_fetch_page(title="First Title", body="first body", links=[]),
        _mock_fetch_page(title="Second Title", body="second body", links=[]),
    ]
    for fake_fetch in versions:
        monkeypatch.setattr(db_module, "fetch_page", fake_fetch)
        index_url(target)

    conn = get_db()
    try:
        stored = conn.execute("SELECT title, body FROM pages").fetchall()
    finally:
        return_db(conn)
    assert len(stored) == 1, "re-indexing should UPDATE not INSERT"
    assert stored[0]["title"] == "Second Title"
def test_links_replaced_on_reindex(temp_db, monkeypatch):
    """After a re-index only the links from the latest fetch remain."""
    patch_dns_ok(monkeypatch)
    source = "https://example.com/src"
    first_links = [("https://example.com/a", "first"), ("https://example.com/b", "second")]
    second_links = [("https://example.com/c", "third-only")]

    for link_set in (first_links, second_links):
        monkeypatch.setattr(
            db_module, "fetch_page", _mock_fetch_page(title="T", body="b", links=link_set),
        )
        index_url(source)

    conn = get_db()
    try:
        link_rows = conn.execute("SELECT url FROM links").fetchall()
    finally:
        return_db(conn)
    remaining = {link["url"] for link in link_rows}
    assert remaining == {"https://example.com/c"}, "old links should be deleted on reindex"
def test_url_cleaned_before_insert(temp_db, monkeypatch):
    """The stored URL has its tracking parameter and fragment removed.

    `index_url` is expected to apply `clean_url` before writing, so tracking
    params don't create duplicate rows.
    """
    patch_dns_ok(monkeypatch)
    monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page(title="T", body="b"))
    index_url("https://example.com/page?utm_source=twitter#frag")

    conn = get_db()
    try:
        stored = conn.execute("SELECT url FROM pages").fetchall()
    finally:
        return_db(conn)
    assert len(stored) == 1
    assert stored[0]["url"] == "https://example.com/page"
def test_summary_populated_from_meta_description(temp_db, monkeypatch):
    """A sufficiently long meta description is stored verbatim as the summary."""
    patch_dns_ok(monkeypatch)
    description = "A thoughtful description that exceeds twenty chars"
    fake_fetch = _mock_fetch_page(title="T", body="b", meta=description)
    monkeypatch.setattr(db_module, "fetch_page", fake_fetch)
    index_url("https://example.com/page")

    conn = get_db()
    try:
        page = conn.execute("SELECT summary FROM pages").fetchone()
    finally:
        return_db(conn)
    assert page["summary"] == description
def test_short_meta_description_not_stored_as_summary(temp_db, monkeypatch):
    """A very short meta description leaves the summary column empty."""
    patch_dns_ok(monkeypatch)
    fake_fetch = _mock_fetch_page(title="T", body="b", meta="too short")
    monkeypatch.setattr(db_module, "fetch_page", fake_fetch)
    index_url("https://example.com/page")

    conn = get_db()
    try:
        page = conn.execute("SELECT summary FROM pages").fetchone()
    finally:
        return_db(conn)
    assert page["summary"] == ""
def test_pool_returns_clean_connection(temp_db, monkeypatch):
    """Regression for 1bc695f: uncommitted work must not survive `return_db`.

    A connection is handed back with an uncommitted INSERT; the next consumer
    must not be able to see that row.
    """
    dirty_url = "https://dirty.example.com/"
    patch_dns_ok(monkeypatch)
    monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page(title="T", body="b"))
    index_url("https://example.com/one")

    # Borrow a connection and leave an uncommitted write on it.
    first = get_db()
    first.execute(
        "INSERT INTO pages (url, title, body) VALUES (?, ?, ?)",
        (dirty_url, "dirty", "dirty"),
    )
    # NOTE: deliberately no commit; this is the state that must be discarded.
    return_db(first)

    # Whoever borrows next must not observe the dirty row.
    second = get_db()
    try:
        visible = {page["url"] for page in second.execute("SELECT url FROM pages").fetchall()}
    finally:
        return_db(second)
    assert "https://dirty.example.com/" not in visible

90
tests/test_db_schema.py Normal file
View file

@ -0,0 +1,90 @@
"""Tests for `init_db` and the settings key-value store.
`init_db` is called unconditionally on startup, so it must be idempotent
and create every table/trigger the rest of the app expects.
"""
from db import get_db, return_db, init_db, get_setting, set_setting, get_site_name
# Every table the schema tests expect to find after init_db().
EXPECTED_TABLES = {
    # Regular tables:
    "pages",
    "links",
    "settings",
    "subscriptions",
    "remote_pages",
    "tags",
    "page_tags",
    "chunks",
    # FTS5 virtual tables:
    "pages_fts",
    "remote_pages_fts",
}
def test_all_expected_tables_exist(temp_db):
    """Every name in EXPECTED_TABLES is present in `sqlite_master` after init."""
    conn = get_db()
    try:
        found = {
            entry["name"]
            for entry in conn.execute(
                "SELECT name FROM sqlite_master WHERE type IN ('table') AND name NOT LIKE 'sqlite_%'"
            ).fetchall()
        }
    finally:
        return_db(conn)
    missing = EXPECTED_TABLES - found
    assert not missing, f"tables missing after init_db: {missing}"
def test_fts_triggers_exist(temp_db):
    """The three `pages` triggers that keep `pages_fts` in sync are defined."""
    conn = get_db()
    try:
        trigger_rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type = 'trigger'"
        ).fetchall()
        defined = {entry["name"] for entry in trigger_rows}
    finally:
        return_db(conn)
    # One trigger each for insert, delete and update on pages.
    for trigger in ("pages_ai", "pages_ad", "pages_au"):
        assert trigger in defined, f"missing trigger {trigger}"
def test_init_db_is_idempotent(temp_db):
    """Repeated `init_db()` calls neither raise nor create a second `pages` entry."""
    for _ in range(2):
        # The fixture already initialised the DB once; these are extra runs.
        init_db()

    conn = get_db()
    try:
        occurrences = conn.execute(
            "SELECT count(*) FROM sqlite_master WHERE name = 'pages'"
        ).fetchone()[0]
    finally:
        return_db(conn)
    assert occurrences == 1
def test_get_setting_returns_default_when_missing(temp_db):
    """An unknown key yields the supplied default, or "" when none is given."""
    with_default = get_setting("nonexistent", "fallback")
    without_default = get_setting("nonexistent")
    assert with_default == "fallback"
    assert without_default == ""
def test_set_setting_then_get(temp_db):
    """A value written with `set_setting` is read back by `get_setting`."""
    value = "my-personal-index"
    set_setting("site_name", value)
    assert get_setting("site_name") == value
def test_set_setting_updates_existing(temp_db):
    """Writing the same key twice keeps only the latest value."""
    for value in ("first", "second"):
        set_setting("key", value)
    assert get_setting("key") == "second"
def test_get_site_name_has_default(temp_db):
    """On a fresh database the site name is "tinyweb"."""
    site_name = get_site_name()
    assert site_name == "tinyweb"
def test_get_site_name_reflects_override(temp_db):
    """A stored `site_name` setting is what `get_site_name` reports."""
    override = "custom-site"
    set_setting("site_name", override)
    assert get_site_name() == override
def test_foreign_keys_pragma_enabled(temp_db):
    """Connections from the pool report `PRAGMA foreign_keys` as 1.

    The suite relies on this so that CASCADE deletes work.
    """
    conn = get_db()
    try:
        pragma = conn.execute("PRAGMA foreign_keys").fetchone()
    finally:
        return_db(conn)
    enabled = pragma[0]
    assert enabled == 1

113
tests/test_fts_sanitizer.py Normal file
View file

@ -0,0 +1,113 @@
"""Tests for `_sanitize_fts_query`.
The sanitizer is the boundary between user input and FTS5 MATCH syntax.
Commit 1bc695f tightened it after noticing that colons and operator words
could escape the quoting. These tests keep that regression dead.
"""
import pytest
from handlers import _sanitize_fts_query
def test_empty_query_returns_no_match_token():
    """Empty and whitespace-only input both map to the `""` token."""
    for blank in ("", " "):
        assert _sanitize_fts_query(blank) == '""'
def test_single_word_becomes_prefix_match():
    """A lone word gets a trailing `*` and no quotes."""
    sanitized = _sanitize_fts_query("rust")
    assert sanitized == "rust*"
def test_multi_word_quotes_all_but_last():
    """Every word but the last is quoted; the last becomes a prefix match."""
    expected = '"rust" "borrow" checker*'
    assert _sanitize_fts_query("rust borrow checker") == expected
def test_stopwords_are_dropped():
    """Stopwords vanish, leaving only the meaningful word as a prefix match."""
    # "the" and "a" are removed; "cat" survives and receives the star.
    sanitized = _sanitize_fts_query("the a cat")
    assert sanitized == "cat*"
def test_all_stopwords_returns_no_match_token():
    """A query made only of stopwords collapses to the `""` token."""
    sanitized = _sanitize_fts_query("the and or")
    assert sanitized == '""'
@pytest.mark.parametrize("bad_char", ["'", "(", ")", "+", "-", "^", "~", ":"])
def test_fts5_operators_stripped_from_tokens(bad_char):
    """Regression for 1bc695f: FTS5 special characters never reach the output.

    `"` and `*` are not in the parameter list because the sanitizer itself
    emits them (quoting and the trailing prefix star).
    """
    out = _sanitize_fts_query("foo" + bad_char + "bar")
    assert bad_char not in out, f"{bad_char!r} leaked into {out!r}"
def test_asterisk_only_appears_as_trailing_prefix():
    """A `*` typed by the user may only survive as the single trailing prefix star."""
    out = _sanitize_fts_query("foo*bar")
    star_count = out.count("*")
    assert star_count <= 1
    if star_count:
        assert out.endswith("*")
def test_quote_in_input_does_not_break_out_of_quoted_token():
    """A `"` in user input must not close the sanitizer's protective quoting.

    The sanitizer wraps each non-last token in double quotes; if a stray `"`
    from the user slipped through, the resulting FTS5 expression would be
    interpreted as broken syntax or, worse, a column filter.
    """
    out = _sanitize_fts_query('foo"bar baz"qux')
    # Quotes in the output must come in pairs.
    assert out.count('"') % 2 == 0
    # Check each whitespace-separated token on its own. The previous regex
    # approach (`"[^"]*"`) could never capture an embedded quote, so its inner
    # assertion was vacuous. A token may contain quotes only as one opening
    # and one closing delimiter around a quote-free body.
    for token in out.split():
        if '"' not in token:
            continue
        assert len(token) >= 2, f"dangling quote in {out!r}"
        assert token[0] == '"' and token[-1] == '"', f"stray quote in {token!r}"
        assert '"' not in token[1:-1], f"embedded quote in {token!r}"
@pytest.mark.parametrize("op", ["AND", "OR", "NOT", "NEAR", "and", "or", "not", "near"])
def test_fts5_operator_words_dropped(op):
    """AND/OR/NOT/NEAR must not survive sanitisation in any form.

    An operator word left unquoted would be interpreted by FTS5 as an
    operator. The output is therefore compared token by token, with the
    sanitizer's own decoration (`"` quoting and the trailing `*`) removed.
    Splitting on `"` alone only catches an operator that survived as a quoted
    token and misses an unquoted one, which is the dangerous case.
    """
    out = _sanitize_fts_query(f"foo {op} bar")
    surviving = [token.strip('"*').upper() for token in out.split()]
    assert op.upper() not in surviving, f"operator {op!r} survived in {out!r}"
def test_injection_payload_produces_valid_fts5():
    """End-to-end: realistic injection payloads must sanitise to valid FTS5.

    Each sanitized payload is run as a MATCH expression against a throwaway
    in-memory FTS5 table; if the sanitizer leaks operator characters the MATCH
    either raises or interprets malicious syntax.
    """
    import sqlite3

    payloads = [
        'foo": OR bar NOT baz AND qux*()',
        '" OR 1=1 --',
        "title:secret AND public",
        "(((",
        "^^^~~~",
    ]
    conn = sqlite3.connect(":memory:")
    # try/finally so the connection is closed even when a payload makes MATCH
    # raise and the test fails.
    try:
        conn.execute("CREATE VIRTUAL TABLE t USING fts5(body)")
        conn.execute("INSERT INTO t (body) VALUES ('hello world')")
        for payload in payloads:
            q = _sanitize_fts_query(payload)
            # Must not raise: if operators leaked, FTS5 would error or mis-parse.
            conn.execute("SELECT * FROM t WHERE t MATCH ?", (q,)).fetchall()
    finally:
        conn.close()
def test_whitespace_only_tokens_dropped():
    """Tokens that are empty once special characters are stripped leave no bare quotes."""
    only_quotes = '""" "" ""'
    sanitized = _sanitize_fts_query(only_quotes)
    assert sanitized == '""'
def test_colon_stripped():
    """Regression for 1bc695f: the FTS5 column-filter colon never reaches the output."""
    sanitized = _sanitize_fts_query("title:secret")
    assert sanitized.count(":") == 0

View file

@ -0,0 +1,164 @@
"""Tests for gateway-level guards: body-size cap and Reticulum surface whitelist.
Regression targets from commit 1bc695f: a 16 MiB upload limit (DoS guard)
and a strict GET-/api/sites-only whitelist for requests arriving over the
Reticulum mesh (CSRF can't protect mesh callers, so gate by whitelist).
"""
import io
import pytest
import app as app_module
from gateway import GatewayHandler, MAX_BODY_SIZE
class FakeHeaders:
    """Tiny stand-in for http.server request headers, supporting only `get`."""

    def __init__(self, items=None):
        # Copy the input so later changes to the caller's mapping don't show through.
        self._items = dict(items) if items else {}

    def get(self, key, default=None):
        """Return the header stored under `key`, or `default` when it is absent."""
        if key in self._items:
            return self._items[key]
        return default
class FakeGatewayHandler(GatewayHandler):
    """GatewayHandler that skips the socket-bound constructor and records responses.

    Response calls are written into `self._captured` instead of a socket.
    """

    def __init__(self, path="/", method="POST", headers=None, rfile=None):
        # Deliberately does not call the parent __init__.
        self.path = path
        self.command = method
        self.headers = FakeHeaders(headers or {})
        self.rfile = rfile if rfile else io.BytesIO()
        self.wfile = io.BytesIO()
        self._captured = dict(error=None, status=None, headers=[], body_written=None)

    def send_error(self, code, msg=""):
        """Record the error as a `(code, msg)` pair."""
        self._captured["error"] = (code, msg)

    def send_response(self, code):
        """Record the response status code."""
        self._captured["status"] = code

    def send_header(self, k, v):
        """Append the header pair to the recorded list."""
        self._captured["headers"].append((k, v))

    def end_headers(self):
        """Nothing to flush in the fake."""
        pass
def test_post_over_size_cap_rejected_with_413():
    """Regression for 1bc695f: a declared body one byte over MAX_BODY_SIZE gets a 413.

    The request must be rejected without the body being read into memory.
    """
    declared_length = str(MAX_BODY_SIZE + 1)
    handler = FakeGatewayHandler(
        path="/add", method="POST", headers={"Content-Length": declared_length},
    )
    handler._forward("POST")

    recorded = handler._captured["error"]
    assert recorded is not None
    status, _reason = recorded
    assert status == 413
def test_post_at_size_cap_accepted():
    """A declared body of exactly MAX_BODY_SIZE must not trip the size check."""
    from gateway import GatewayState

    # The rfile is empty; only the 413 check matters here, not whether the
    # request as a whole succeeds.
    handler = FakeGatewayHandler(
        path="/_does_not_matter",
        method="POST",
        headers={"Content-Length": str(MAX_BODY_SIZE)},
        rfile=io.BytesIO(b""),
    )

    # Swap in a canned local_dispatch so _forward stays off the network path,
    # and always put the real one back.
    saved_dispatch = GatewayState.local_dispatch
    GatewayState.local_dispatch = lambda data: {
        "status": 404, "content_type": "text/plain", "body": "nope",
    }
    try:
        handler._forward("POST")
    finally:
        GatewayState.local_dispatch = saved_dispatch

    # Any other error is acceptable; a 413 is not, because the cap is inclusive.
    recorded = handler._captured["error"]
    if recorded:
        assert recorded[0] != 413
def test_negative_content_length_rejected():
    """A negative Content-Length is answered with a 400."""
    handler = FakeGatewayHandler(
        path="/add", method="POST", headers={"Content-Length": "-1"},
    )
    handler._forward("POST")

    recorded = handler._captured["error"]
    assert recorded is not None
    status, _reason = recorded
    assert status == 400
def test_invalid_content_length_rejected():
    """A non-numeric Content-Length is answered with a 400."""
    handler = FakeGatewayHandler(
        path="/add", method="POST", headers={"Content-Length": "abc"},
    )
    handler._forward("POST")

    recorded = handler._captured["error"]
    assert recorded is not None
    status, _reason = recorded
    assert status == 400
# -------- Reticulum mesh surface whitelist --------
def test_mesh_rejects_non_api_sites_get():
    """Regression for 1bc695f: a mesh GET for anything but /api/sites gets a 403."""
    payload = {"method": "GET", "path": "/pages", "query": {}, "body": {}, "gateway_host": ""}
    call_args = dict(
        path="/tinyweb", data=payload,
        request_id="x", link_id="y", remote_identity=None, requested_at=0,
    )
    resp = app_module.rns_request_handler(**call_args)
    assert resp["status"] == 403
def test_mesh_rejects_post_to_api_sites():
    """Over the mesh even /api/sites is refused when the method is POST."""
    payload = {"method": "POST", "path": "/api/sites", "query": {}, "body": {}, "gateway_host": ""}
    call_args = dict(
        path="/tinyweb", data=payload,
        request_id="x", link_id="y", remote_identity=None, requested_at=0,
    )
    resp = app_module.rns_request_handler(**call_args)
    assert resp["status"] == 403
def test_mesh_rejects_sensitive_local_endpoints():
    """Each listed local endpoint answers 403 when requested over the mesh."""
    sensitive_paths = ("/add", "/delete/1", "/style", "/import", "/export")
    for path in sensitive_paths:
        payload = {"method": "GET", "path": path, "query": {}, "body": {}, "gateway_host": ""}
        resp = app_module.rns_request_handler(
            path="/tinyweb", data=payload,
            request_id="x", link_id="y", remote_identity=None, requested_at=0,
        )
        assert resp["status"] == 403, f"path {path!r} leaked through mesh whitelist"
def test_mesh_allows_api_sites_get(temp_db, csrf_session):
    """Sanity check for the one whitelisted method/path combination."""
    payload = {"method": "GET", "path": "/api/sites", "query": {}, "body": {}, "gateway_host": ""}
    resp = app_module.rns_request_handler(
        path="/tinyweb", data=payload,
        request_id="x", link_id="y", remote_identity=None, requested_at=0,
    )
    # 200 is the happy path; 403 is tolerated for when sharing is disabled by default.
    acceptable = (200, 403)
    assert resp["status"] in acceptable
def test_mesh_handles_missing_data_payload():
    """Regression-minded check: a `None` payload yields a 403 rather than a crash."""
    call_args = dict(
        path="/tinyweb", data=None,
        request_id="x", link_id="y", remote_identity=None, requested_at=0,
    )
    resp = app_module.rns_request_handler(**call_args)
    # The defaults (method=GET, path=/) are not on the whitelist.
    assert resp["status"] == 403

View file

@ -0,0 +1,174 @@
"""Tests for `handle_bulk_action`, edit flow, and the bulk-delete confirm step.
The bulk-delete confirmation flow is a data-loss guard added in commit
8dffd8c: a stray POST without `confirmed=1` must render the confirmation
page instead of actually deleting.
"""
from db import get_db, return_db
from handlers import (
handle_bulk_action,
handle_edit_form,
handle_edit_submit,
handle_pages,
)
def _all_urls(seeded_db):
    """Return the set of every URL currently stored in `pages`."""
    conn = get_db()
    try:
        page_rows = conn.execute("SELECT url FROM pages").fetchall()
        return {page["url"] for page in page_rows}
    finally:
        return_db(conn)
def _page_id(seeded_db, url):
    """Look up the `pages.id` of the row whose URL equals `url`."""
    conn = get_db()
    try:
        match = conn.execute("SELECT id FROM pages WHERE url = ?", (url,)).fetchone()
        return match["id"]
    finally:
        return_db(conn)
def test_bulk_delete_without_confirmed_renders_confirm_page(seeded_db, csrf_session):
    """Regression for 8dffd8c: an unconfirmed bulk delete only shows the confirm page."""
    pid = _page_id(seeded_db, "https://example.com/rust-intro")
    before = _all_urls(seeded_db)

    form = {"ids": [str(pid)], "action": ["delete"]}
    resp = handle_bulk_action(form)
    page = resp["body"]

    assert resp["status"] == 200
    assert "confirm delete" in page.lower()
    assert "Rust Intro" in page
    # The follow-up form has to carry the hidden confirmed=1 field.
    assert 'name="confirmed" value="1"' in page
    # Most important: the set of stored pages is unchanged.
    assert _all_urls(seeded_db) == before
def test_bulk_delete_with_confirmed_actually_deletes(seeded_db, csrf_session):
    """With confirmed=1 the selected page is removed and the others stay."""
    pid = _page_id(seeded_db, "https://example.com/rust-intro")
    form = {"ids": [str(pid)], "action": ["delete"], "confirmed": ["1"]}
    resp = handle_bulk_action(form)

    # A confirmed delete answers with a redirect.
    assert resp["status"] in (302, 303)
    remaining = _all_urls(seeded_db)
    assert "https://example.com/rust-intro" not in remaining
    # An unrelated page is still there.
    assert "https://example.com/python-tips" in remaining
def test_bulk_delete_with_no_ids_redirects(seeded_db, csrf_session):
    """A confirmed delete with an empty id list redirects and deletes nothing."""
    all_seeded = {
        "https://example.com/rust-intro",
        "https://example.com/python-tips",
        "https://example.com/ocaml-why",
        "https://news.example.org/mesh",
    }
    form = {"ids": [], "action": ["delete"], "confirmed": ["1"]}
    resp = handle_bulk_action(form)
    assert resp["status"] in (302, 303)
    assert _all_urls(seeded_db) == all_seeded
def test_bulk_delete_rejects_non_integer_ids(seeded_db, csrf_session):
    """An id that is not an integer produces a 400."""
    form = {"ids": ["not-a-number"], "action": ["delete"], "confirmed": ["1"]}
    resp = handle_bulk_action(form)
    assert resp["status"] == 400
def test_bulk_retag_add_mode_merges_tags(seeded_db, csrf_session):
    """Retagging in "add" mode keeps the existing tag and adds the new ones."""
    pid = _page_id(seeded_db, "https://example.com/python-tips")
    form = {
        "ids": [str(pid)],
        "action": ["retag"],
        "bulk_tags": ["scripting, tutorials"],
        "tag_mode": ["add"],
    }
    handle_bulk_action(form)

    conn = get_db()
    try:
        tag_rows = conn.execute(
            "SELECT t.name FROM tags t JOIN page_tags pt ON pt.tag_id = t.id "
            "WHERE pt.page_id = ? ORDER BY t.name",
            (pid,),
        ).fetchall()
    finally:
        return_db(conn)
    names = [entry["name"] for entry in tag_rows]
    assert "python" in names  # pre-existing tag survives
    assert "scripting" in names  # newly added
    assert "tutorials" in names
def test_bulk_retag_replace_mode_overwrites_tags(seeded_db, csrf_session):
    """Retagging in "replace" mode leaves exactly the submitted tags."""
    pid = _page_id(seeded_db, "https://example.com/python-tips")
    form = {
        "ids": [str(pid)],
        "action": ["retag"],
        "bulk_tags": ["one, two"],
        "tag_mode": ["replace"],
    }
    handle_bulk_action(form)

    conn = get_db()
    try:
        tag_rows = conn.execute(
            "SELECT t.name FROM tags t JOIN page_tags pt ON pt.tag_id = t.id "
            "WHERE pt.page_id = ?",
            (pid,),
        ).fetchall()
    finally:
        return_db(conn)
    names = {entry["name"] for entry in tag_rows}
    assert names == {"one", "two"}
    assert "python" not in names
def test_edit_form_renders_current_values(seeded_db, csrf_session):
    """The edit form for a seeded page shows its title and an existing tag."""
    pid = _page_id(seeded_db, "https://example.com/rust-intro")
    resp = handle_edit_form(pid)
    markup = resp["body"]
    assert resp["status"] == 200
    assert "Rust Intro" in markup
    # The page's current tags are expected in the tag field.
    assert "rust" in markup
def test_edit_form_404_for_unknown_page(temp_db, csrf_session):
    """Requesting the edit form for an id that doesn't exist gives a 404."""
    unknown_id = 99999
    resp = handle_edit_form(unknown_id)
    assert resp["status"] == 404
def test_edit_submit_updates_title_and_note(seeded_db, csrf_session):
    """Submitting the edit form persists the new title and note."""
    pid = _page_id(seeded_db, "https://example.com/rust-intro")
    form = {
        "title": ["New Rust Title"],
        "note": ["new annotation"],
        "tags": ["rust, updated"],
    }
    handle_edit_submit(pid, form)

    conn = get_db()
    try:
        stored = conn.execute(
            "SELECT title, note FROM pages WHERE id = ?", (pid,)
        ).fetchone()
    finally:
        return_db(conn)
    assert stored["title"] == "New Rust Title"
    assert stored["note"] == "new annotation"
def test_handle_pages_lists_indexed_pages(seeded_db, csrf_session):
    """The pages listing answers 200 and mentions every seeded title."""
    resp = handle_pages({})
    assert resp["status"] == 200
    listing = resp["body"]
    seeded_titles = ("Rust Intro", "Python Tips", "Why OCaml", "Mesh Networking")
    for title in seeded_titles:
        assert title in listing

View file

@ -0,0 +1,63 @@
"""Tests for `handle_search` — the home page + primary user flow."""
from handlers import handle_search
def test_empty_index_empty_query_shows_welcome(temp_db, csrf_session):
    """With nothing indexed and no query, the welcome panel and its links appear."""
    resp = handle_search({})
    assert resp["status"] == 200
    page = resp["body"]
    assert "Your index is empty" in page
    # The starting points offered by the welcome panel.
    for href in ("/add", "/style", "/subscriptions"):
        assert href in page
def test_empty_index_with_query_shows_no_results(temp_db, csrf_session):
    """Searching an empty index reports that nothing was found."""
    query = {"q": ["rust"]}
    resp = handle_search(query)
    assert resp["status"] == 200
    assert "No results in your index" in resp["body"]
def test_populated_index_with_matching_query_returns_results(seeded_db, csrf_session):
    """A matching query lists the page and shows the total page count."""
    resp = handle_search({"q": ["rust"]})
    page = resp["body"]
    assert resp["status"] == 200
    assert "Rust Intro" in page
    # The meta line reports how many pages are indexed.
    assert "4 pages indexed" in page
def test_query_only_matches_relevant_pages(seeded_db, csrf_session):
    """Searching "ocaml" surfaces the OCaml page and none of the others checked."""
    page = handle_search({"q": ["ocaml"]})["body"]
    assert "Why OCaml" in page
    for unrelated in ("Python Tips", "Rust Intro"):
        assert unrelated not in page
def test_pagination_query_param_respected(seeded_db, csrf_session):
    """Asking for a page number far past the end still renders with a 200."""
    query = {"q": ["example"], "p": ["99"]}
    resp = handle_search(query)
    assert resp["status"] == 200
def test_trusted_sites_fallback_surfaces_when_query_matches_link_label(seeded_db, csrf_session):
    """A query matching only a link label brings up the "trusted sites" fallback.

    Links extracted from indexed pages act as a fallback when direct results
    are absent or thin; labels are substring-matched case-insensitively.
    """
    page = handle_search({"q": ["advanced"]})["body"]
    # "advanced rust guide" is the label of the link seeded on rust-intro.
    assert "advanced rust guide" in page
    assert "trusted sites" in page
def test_page_count_in_meta_line(seeded_db, csrf_session):
    """Even without a query, the page reports the four seeded pages."""
    page = handle_search({})["body"]
    assert "4 pages indexed" in page
def test_csp_and_security_headers_not_in_handler_but_via_dispatch(seeded_db, csrf_session):
    """The handler's own response carries no Content-Security-Policy header.

    Security headers are wrapped around the response by dispatch_request;
    this test documents that boundary so future refactors don't break
    assumptions.
    """
    resp = handle_search({})
    handler_headers = resp.get("headers", {})
    assert "Content-Security-Policy" not in handler_headers

112
tests/test_handlers_subs.py Normal file
View file

@ -0,0 +1,112 @@
"""Tests for subscription handlers.
Subscription add validates the destination hash (32-char hex) locally
before calling `fetch_remote_sites`; browse uses cached remote_pages when
available and falls back to a live fetch otherwise.
"""
from unittest.mock import patch
import handlers as handlers_module
from db import get_db, return_db
from handlers import handle_subscription_add, handle_subscription_browse
# A syntactically valid destination hash: 32 hex characters.
VALID_HASH = "".join(["a"] * 32)
def _subscription_count():
    """Return how many rows the `subscriptions` table currently holds."""
    conn = get_db()
    try:
        (total,) = conn.execute("SELECT count(*) FROM subscriptions").fetchone()
        return total
    finally:
        return_db(conn)
def test_rejects_empty_dest_hash(temp_db, csrf_session):
    """An empty destination hash is refused and nothing is stored."""
    form = {"dest_hash": [""]}
    resp = handle_subscription_add(form)
    assert "32-character" in resp["body"]
    assert _subscription_count() == 0
def test_rejects_wrong_length(temp_db, csrf_session):
    """A hash shorter than 32 characters is refused and nothing is stored."""
    form = {"dest_hash": ["abc123"]}
    resp = handle_subscription_add(form)
    assert "32-character" in resp["body"]
    assert _subscription_count() == 0
def test_rejects_non_hex(temp_db, csrf_session):
    """A 32-character hash containing non-hex characters is refused."""
    not_hex = "z" * 32
    resp = handle_subscription_add({"dest_hash": [not_hex]})
    assert "hex" in resp["body"].lower()
    assert _subscription_count() == 0
def test_rejects_unreachable_peer(temp_db, csrf_session):
    """When the remote fetch raises ConnectionError, no subscription is recorded."""
    with patch.object(
        handlers_module, "fetch_remote_sites", side_effect=ConnectionError("unreachable"),
    ):
        resp = handle_subscription_add({"dest_hash": [VALID_HASH]})
    assert "Could not reach" in resp["body"]
    assert _subscription_count() == 0
def test_rejects_peer_with_sharing_disabled(temp_db, csrf_session):
    """When the remote fetch raises PermissionError, no subscription is recorded."""
    with patch.object(
        handlers_module, "fetch_remote_sites", side_effect=PermissionError("sharing disabled"),
    ):
        resp = handle_subscription_add({"dest_hash": [VALID_HASH]})
    assert "sharing disabled" in resp["body"]
    assert _subscription_count() == 0
def test_successful_add_records_subscription(temp_db, csrf_session):
    """A reachable peer results in a confirmation message and one stored row."""
    peer_reply = {"name": "alice", "sites": []}
    with patch.object(handlers_module, "fetch_remote_sites", return_value=peer_reply):
        resp = handle_subscription_add({"dest_hash": [VALID_HASH]})
    assert "Subscribed to alice" in resp["body"]
    assert _subscription_count() == 1
def test_dest_hash_strips_angle_brackets(temp_db, csrf_session):
    """Users often paste hashes as `<aaa...>` from RNS log output; strip them.

    The bracketed hash must be accepted exactly like the bare one: the handler
    reports success for the peer and one subscription row is recorded.
    """
    with patch.object(handlers_module, "fetch_remote_sites") as fetch:
        fetch.return_value = {"name": "bob", "sites": []}
        resp = handle_subscription_add({"dest_hash": [f"<{VALID_HASH}>"]})
    # Previously `resp` was captured but never checked; assert the success
    # message too, mirroring the plain-hash success test.
    assert "Subscribed to bob" in resp["body"]
    assert _subscription_count() == 1
def test_browse_unknown_subscription_is_404(temp_db, csrf_session):
    """Browsing a subscription id that was never created returns status 404."""
    missing_id = 99999
    status = handle_subscription_browse(missing_id)["status"]
    assert status == 404
def test_browse_marks_already_indexed_urls(seeded_db, csrf_session):
    """Browse output distinguishes already-indexed remote pages from new ones.

    Sets up one subscription with two cached remote pages (one duplicating a
    local URL, one new) and checks the rendered page mentions both titles,
    the "already indexed" marker, and a "1 new" count.
    """
    insert_remote_page = (
        "INSERT INTO remote_pages (subscription_id, url, title, note, tags) "
        "VALUES (?, ?, ?, ?, ?)"
    )
    conn = get_db()
    try:
        conn.execute(
            "INSERT INTO subscriptions (dest_hash, name) VALUES (?, ?)",
            (VALID_HASH, "alice"),
        )
        sub_id = conn.execute("SELECT id FROM subscriptions").fetchone()["id"]
        remote_rows = (
            # Duplicate of a URL the local index already has.
            (sub_id, "https://example.com/rust-intro", "Alice rust pick", "", ""),
            # Not present locally, so it should be offered as new.
            (sub_id, "https://new.example.com/shiny", "Shiny New Link", "note", "tag1"),
        )
        for remote_row in remote_rows:
            conn.execute(insert_remote_page, remote_row)
        conn.commit()
    finally:
        return_db(conn)

    resp = handle_subscription_browse(sub_id)
    page = resp["body"]
    assert resp["status"] == 200
    assert "already indexed" in page
    # Both titles are rendered: the duplicate in the "already indexed"
    # section, the new one in the selectable section.
    for title in ("Alice rust pick", "Shiny New Link"):
        assert title in page
    # Count summary: "2 site(s) available, 1 new"
    assert "1 new" in page

101
tests/test_handlers_tags.py Normal file
View file

@ -0,0 +1,101 @@
"""Tests for tag helpers and the tag browse handler.
Tags are stored via a join table, so orphaned rows in `tags` can accumulate
if `_cleanup_orphaned_tags` isn't called after deletion/retagging. Tag
counts shown in the UI rely on this being right.
"""
from db import get_db, return_db
from handlers import (
_cleanup_orphaned_tags,
_get_page_tags,
_set_page_tags,
handle_tag_browse,
handle_tags,
)
def _page_id(url):
    """Return the `pages.id` stored for `url`, or None when no row matches."""
    conn = get_db()
    try:
        found = conn.execute("SELECT id FROM pages WHERE url = ?", (url,)).fetchone()
        if not found:
            return None
        return found["id"]
    finally:
        return_db(conn)
def _tag_names():
    """Return the set of every name currently present in the `tags` table."""
    conn = get_db()
    try:
        rows = conn.execute("SELECT name FROM tags").fetchall()
        return {row["name"] for row in rows}
    finally:
        return_db(conn)
def test_get_page_tags_returns_sorted_names(seeded_db):
    """Tags for the rust-intro page come back alphabetically and include both seeds."""
    page = _page_id("https://example.com/rust-intro")
    names = _get_page_tags(page)
    # Alphabetical order: the list equals its own sorted copy.
    assert names == sorted(names)
    for expected in ("rust", "public"):
        assert expected in names
def test_set_page_tags_replaces_existing(seeded_db):
    """Setting tags on a page discards its old tags in favour of the new list."""
    page = _page_id("https://example.com/rust-intro")
    conn = get_db()
    try:
        _set_page_tags(page, "brand, new, tags", conn)
        conn.commit()
    finally:
        return_db(conn)
    assert _get_page_tags(page) == ["brand", "new", "tags"]
def test_set_page_tags_splits_on_comma_and_lowercases(seeded_db):
    """A mixed-case, comma-separated string is stored as separate lowercase tags."""
    page = _page_id("https://example.com/python-tips")
    conn = get_db()
    try:
        _set_page_tags(page, "Foo, BAR, baz", conn)
        conn.commit()
    finally:
        return_db(conn)
    stored = set(_get_page_tags(page))
    assert stored == {"foo", "bar", "baz"}
def test_cleanup_orphaned_tags_removes_unreferenced(seeded_db):
    """Cleanup prunes tags no page references and keeps those still in use."""
    # Untag one page entirely; tags that were unique to it become orphans.
    page = _page_id("https://example.com/rust-intro")
    conn = get_db()
    try:
        # An empty string means "no tags".
        _set_page_tags(page, "", conn)
        # `rust` was only on the rust-intro page; `public` is also on mesh.
        _cleanup_orphaned_tags(conn)
        conn.commit()
    finally:
        return_db(conn)
    remaining = _tag_names()
    assert "rust" not in remaining  # pruned
    assert "public" in remaining  # still on mesh
def test_handle_tag_browse_filters_by_tag(seeded_db, csrf_session):
    """Browsing the `rust` tag shows the Rust page and omits the other titles."""
    resp = handle_tag_browse("rust", {})
    assert resp["status"] == 200
    html = resp["body"]
    assert "Rust Intro" in html
    for other_title in ("Python Tips", "Why OCaml"):
        assert other_title not in html
def test_handle_tag_browse_unknown_tag_is_graceful(seeded_db, csrf_session):
    """A tag nobody uses still renders a page (status 200) rather than an error."""
    status = handle_tag_browse("no-such-tag", {})["status"]
    # Zero results is a valid page, not a failure.
    assert status == 200
def test_handle_tags_lists_all_tags_with_counts(seeded_db, csrf_session):
    """The tag index returns 200 and mentions each of the six expected tag names."""
    resp = handle_tags()
    assert resp["status"] == 200
    html = resp["body"]
    expected_tags = ("rust", "python", "ocaml", "mesh", "public", "private")
    for name in expected_tags:
        assert name in html

View file

@ -0,0 +1,138 @@
"""Tests for link extraction inside `fetch_page`.
Link extraction powers the "trusted sites" fallback on empty searches and
feeds the `links` table. Rules: same-domain only, skip binary extensions,
skip Wikipedia special pages, resolve relatives via urljoin.
"""
from unittest.mock import patch
from conftest import patch_dns_ok
import db as db_module
class FakeResponse:
    """Minimal stand-in for an HTTP response object used by these tests.

    Exposes `text`, `status_code`, `is_redirect` (always False), an empty
    `headers` dict, and a `raise_for_status()` method.
    """

    def __init__(self, text, status_code=200):
        self.headers = {}
        self.is_redirect = False
        self.status_code = status_code
        self.text = text

    def raise_for_status(self):
        """Raise an Exception for status codes of 400 and above; else do nothing."""
        if self.status_code < 400:
            return
        raise Exception(f"status {self.status_code}")
def _fetch_with_html(monkeypatch, url, html):
    """Run `db.fetch_page(url)` with `html` served as the mocked response body.

    DNS is patched via `patch_dns_ok`, and the `requests` name inside the db
    module is replaced so `requests.get` returns a `FakeResponse`.
    """
    patch_dns_ok(monkeypatch)
    with patch.object(db_module, "requests") as fake_requests:
        canned = FakeResponse(html)
        fake_requests.get.return_value = canned
        fetched = db_module.fetch_page(url)
    return fetched
def test_only_same_domain_links_kept(monkeypatch):
    """Only links on the fetched page's own host survive; subdomains do not."""
    html = """
<html><body>
<a href="https://example.com/a">same</a>
<a href="https://other.com/b">cross</a>
<a href="https://sub.example.com/c">subdomain</a>
</body></html>
"""
    fetched = _fetch_with_html(monkeypatch, "https://example.com/", html)
    _, _, links, _ = fetched
    hrefs = [href for href, _text in links]
    assert "https://example.com/a" in hrefs
    for foreign in ("https://other.com/b", "https://sub.example.com/c"):
        assert foreign not in hrefs
def test_binary_extensions_skipped(monkeypatch):
    """Links ending in .png/.pdf/.zip/.mp3/.css are dropped; a normal page is kept."""
    html = """
<html><body>
<a href="/real-page">keep</a>
<a href="/image.png">skip</a>
<a href="/doc.pdf">skip</a>
<a href="/archive.zip">skip</a>
<a href="/song.mp3">skip</a>
<a href="/styles.css">skip</a>
</body></html>
"""
    fetched = _fetch_with_html(monkeypatch, "https://example.com/", html)
    _, _, links, _ = fetched
    hrefs = [href for href, _text in links]
    assert "https://example.com/real-page" in hrefs
    for ext in (".png", ".pdf", ".zip", ".mp3", ".css"):
        leaked = [href for href in hrefs if href.endswith(ext)]
        assert not leaked, f"{ext} leaked through"
def test_wikipedia_special_pages_skipped(monkeypatch):
    """Special:/Talk:/User:/Category: wiki links are dropped; Main_Page is kept."""
    html = """
<html><body>
<a href="/wiki/Main_Page">keep</a>
<a href="/wiki/Special:Random">skip</a>
<a href="/wiki/Talk:Foo">skip</a>
<a href="/wiki/User:Jimbo">skip</a>
<a href="/wiki/Category:Bar">skip</a>
</body></html>
"""
    fetched = _fetch_with_html(monkeypatch, "https://example.com/", html)
    _, _, links, _ = fetched
    hrefs = [href for href, _text in links]
    assert "https://example.com/wiki/Main_Page" in hrefs
    for skip in ("Special:Random", "Talk:Foo", "User:Jimbo", "Category:Bar"):
        leaked = [href for href in hrefs if skip in href]
        assert not leaked, f"wiki {skip!r} leaked"
def test_relative_urls_resolved(monkeypatch):
    """A root-relative href is resolved against the fetched page's URL."""
    html = """<html><body><a href="/relative/path">r</a></body></html>"""
    fetched = _fetch_with_html(monkeypatch, "https://example.com/start", html)
    _, _, links, _ = fetched
    hrefs = [href for href, _text in links]
    assert "https://example.com/relative/path" in hrefs
def test_fragment_stripped_from_extracted_links(monkeypatch):
    """Extracted links carry no `#fragment`; the bare page URL is returned."""
    html = """<html><body><a href="/page#section">r</a></body></html>"""
    fetched = _fetch_with_html(monkeypatch, "https://example.com/", html)
    _, _, links, _ = fetched
    hrefs = [href for href, _text in links]
    assert "https://example.com/page" in hrefs
    with_fragment = [href for href in hrefs if "#" in href]
    assert not with_fragment
def test_duplicate_links_deduped(monkeypatch):
    """Three anchors pointing at the same href produce a single extracted link."""
    html = """
<html><body>
<a href="/a">first</a>
<a href="/a">second</a>
<a href="/a">third</a>
</body></html>
"""
    fetched = _fetch_with_html(monkeypatch, "https://example.com/", html)
    _, _, links, _ = fetched
    hrefs = [href for href, _text in links]
    occurrences = hrefs.count("https://example.com/a")
    assert occurrences == 1
def test_label_truncated_to_200(monkeypatch):
    """A 500-character anchor text yields a label of at most 200 characters."""
    long_text = "x" * 500
    html = f'<html><body><a href="/p">{long_text}</a></body></html>'
    fetched = _fetch_with_html(monkeypatch, "https://example.com/", html)
    _, _, links, _ = fetched
    assert len(links) == 1
    only_link = links[0]
    _, label = only_link
    assert len(label) <= 200
def test_meta_description_extracted(monkeypatch):
    """The content of `<meta name="description">` is returned as the fourth value."""
    html = """
<html><head>
<meta name="description" content="the real description">
</head><body><p>body content</p></body></html>
"""
    fetched = _fetch_with_html(monkeypatch, "https://example.com/", html)
    _title, _body, _links, description = fetched
    assert description == "the real description"
def test_og_description_fallback(monkeypatch):
    """With no `<meta name=description>`, the og:description content is used."""
    html = """
<html><head>
<meta property="og:description" content="open graph fallback">
</head><body><p>body</p></body></html>
"""
    fetched = _fetch_with_html(monkeypatch, "https://example.com/", html)
    _, _, _, description = fetched
    assert description == "open graph fallback"

58
tests/test_pagination.py Normal file
View file

@ -0,0 +1,58 @@
"""Tests for `_paginate` and `_page_nav`."""
from handlers import _paginate, _page_nav, PER_PAGE
def test_paginate_default_is_one():
    """An empty query dict resolves to page 1."""
    page = _paginate({})
    assert page == 1
def test_paginate_reads_query_string():
    """The `p` query value is parsed into the page number."""
    query = {"p": ["3"]}
    page = _paginate(query)
    assert page == 3
def test_paginate_clamps_to_one():
    """Zero and negative page numbers are clamped up to 1."""
    for raw in ("0", "-5"):
        assert _paginate({"p": [raw]}) == 1
def test_paginate_handles_bad_input():
    """A non-numeric value or an empty value list both fall back to page 1."""
    for values in (["not-a-number"], []):
        assert _paginate({"p": values}) == 1
def test_paginate_custom_key():
    """The `key` argument selects which query parameter holds the page number."""
    query = {"batch": ["7"]}
    page = _paginate(query, key="batch")
    assert page == 7
def test_page_nav_empty_when_single_page():
    """No navigation markup is produced for PER_PAGE results or for zero results."""
    for total in (PER_PAGE, 0):
        assert _page_nav(1, total, "/?q=foo") == ""
def test_page_nav_shows_next_on_first_page():
    """Page 1 of 3 offers a next link, no prev link, and the position label."""
    total = PER_PAGE * 3
    nav = _page_nav(1, total, "/?q=foo")
    assert "next" in nav
    assert "prev" not in nav
    assert "page 1 of 3" in nav
def test_page_nav_shows_both_in_middle():
    """Page 2 of 3 offers both a next and a prev link."""
    total = PER_PAGE * 3
    nav = _page_nav(2, total, "/?q=foo")
    for direction in ("next", "prev"):
        assert direction in nav
def test_page_nav_shows_prev_on_last_page():
    """Page 3 of 3 offers a prev link, no next link, and the position label."""
    total = PER_PAGE * 3
    nav = _page_nav(3, total, "/?q=foo")
    assert "next" not in nav
    assert "prev" in nav
    assert "page 3 of 3" in nav
def test_page_nav_handles_query_string_separator():
    """Pagination links append `p` with `&` or `?` depending on the base URL."""
    total = PER_PAGE * 2
    # Base URL already carries a query string: the page param joins with "&".
    with_query = _page_nav(1, total, "/?q=foo")
    assert "&p=2" in with_query
    # Base URL has no query string: the page param starts one with "?".
    without_query = _page_nav(1, total, "/pages")
    assert "?p=2" in without_query

107
tests/test_regressions.py Normal file
View file

@ -0,0 +1,107 @@
"""Aggregator of regression tests tied to specific bug-fix commits.
Each test here guards against a specific bug that was once shipped. Running
just this file gives a one-line-per-bug audit:
pytest tests/test_regressions.py -v
The test bodies are intentionally small; for the exhaustive behavior of each
module, see the topical test files (test_fts_sanitizer.py, test_url_cleanup.py,
etc.). This file's job is to make the bug catalog scannable.
"""
import socket
from unittest.mock import patch
import pytest
import app as app_module
import db as db_module
import handlers as handlers_module
from conftest import patch_dns_fail, patch_dns_ok
from db import clean_url
from handlers import _sanitize_fts_query, handle_bulk_action
def test_6ffd38d_clean_url_preserves_www_when_bare_domain_fails(monkeypatch):
    """6ffd38d: `www.` must be kept when DNS lookups fail.

    `clean_url` used to strip `www.` unconditionally, which produced
    unreachable clean URLs for sites that only serve at `www.`.
    """
    patch_dns_fail(monkeypatch)
    original = "https://www.example.com/page"
    assert clean_url(original) == "https://www.example.com/page"
def test_1bc695f_fts_sanitizer_strips_colon():
    """1bc695f: a colon is an FTS5 column filter and must not survive sanitizing."""
    sanitized = _sanitize_fts_query("title:secret body:exposed")
    assert ":" not in sanitized
@pytest.mark.parametrize("op", ["AND", "OR", "NOT", "NEAR"])
def test_1bc695f_fts_sanitizer_drops_operator_words(op):
    """1bc695f: AND/OR/NOT/NEAR must not come through as standalone tokens.

    Left in place, such a word would be interpreted as an FTS5 operator if it
    landed on the unquoted last token.
    """
    sanitized = _sanitize_fts_query(f"foo {op} bar")
    # Compare bare words: drop the quoting and any trailing prefix "*".
    unquoted = sanitized.replace('"', '')
    bare_tokens = [token.rstrip("*") for token in unquoted.split()]
    assert op not in bare_tokens
def test_1bc695f_gateway_rejects_oversize_body():
    """1bc695f: a Content-Length one byte over MAX_BODY_SIZE is answered with 413.

    The body-size cap exists to prevent memory-exhaustion DoS.
    """
    from tests.test_gateway_limits import FakeGatewayHandler
    from gateway import MAX_BODY_SIZE

    oversize_headers = {"Content-Length": str(MAX_BODY_SIZE + 1)}
    handler = FakeGatewayHandler(path="/add", method="POST", headers=oversize_headers)
    handler._forward("POST")
    assert handler._captured["error"] and handler._captured["error"][0] == 413
def test_1bc695f_mesh_rejects_non_whitelisted_paths():
    """1bc695f: a POST /add arriving over the mesh is refused with 403.

    Reticulum callers are limited to GET /api/sites; CSRF cannot authenticate
    mesh callers.
    """
    mesh_request = {
        "method": "POST",
        "path": "/add",
        "query": {},
        "body": {},
        "gateway_host": "",
    }
    resp = app_module.rns_request_handler(
        path="/tinyweb",
        data=mesh_request,
        request_id="x",
        link_id="y",
        remote_identity=None,
        requested_at=0,
    )
    assert resp["status"] == 403
def test_1bc695f_pool_returns_clean_connection(temp_db, monkeypatch):
    """1bc695f: an uncommitted insert must not be visible to the next pool user.

    Uncommitted transactions on a pooled connection used to leak into the
    next consumer.
    """
    from db import get_db, return_db

    leaked_url = "https://leak.example.com/"
    writer = get_db()
    writer.execute(
        "INSERT INTO pages (url, title, body) VALUES (?, ?, ?)",
        (leaked_url, "should not persist", "body"),
    )
    # Deliberately returned without a commit.
    return_db(writer)

    reader = get_db()
    try:
        rows = reader.execute("SELECT url FROM pages").fetchall()
        seen_urls = {row["url"] for row in rows}
    finally:
        return_db(reader)
    assert leaked_url not in seen_urls
def test_8dffd8c_bulk_delete_requires_confirmation(seeded_db, csrf_session):
    """8dffd8c: bulk delete without confirmed=1 renders a confirm page, not a delete.

    The JS confirm on /pages is a first-line filter only, so the handler must
    not delete anything until the request is confirmed.
    """
    from db import get_db, return_db

    conn = get_db()
    try:
        target_id = conn.execute("SELECT id FROM pages LIMIT 1").fetchone()["id"]
        count_before = conn.execute("SELECT count(*) FROM pages").fetchone()[0]
    finally:
        return_db(conn)

    form = {"ids": [str(target_id)], "action": ["delete"]}
    resp = handle_bulk_action(form)
    assert "confirm delete" in resp["body"].lower()

    conn = get_db()
    try:
        count_after = conn.execute("SELECT count(*) FROM pages").fetchone()[0]
    finally:
        return_db(conn)
    assert count_before == count_after, "bulk delete ran without confirmation"

View file

@ -0,0 +1,38 @@
"""Tests for `_page_is_shared`.
This function decides whether a page is exposed over Reticulum to
subscribers. Getting it wrong means either a privacy leak or silently
hiding pages the user meant to share — both are worth a regression net.
"""
import pytest
from handlers import _page_is_shared
@pytest.mark.parametrize("mode", ["exclude_private", "require_public"])
def test_private_tag_always_excludes(mode):
    """A `private` tag makes the page unshared in both modes, even next to `public`."""
    for tags in (["private"], ["public", "private"]):
        assert _page_is_shared(tags, mode) is False
def test_exclude_private_defaults_to_shared():
    """In `exclude_private` mode, untagged and arbitrarily tagged pages are shared."""
    mode = "exclude_private"
    for tags in ([], ["random-tag"]):
        assert _page_is_shared(tags, mode) is True
def test_require_public_needs_public_tag():
    """In `require_public` mode, only a page tagged `public` is shared."""
    mode = "require_public"
    for tags in ([], ["rust"]):
        assert _page_is_shared(tags, mode) is False
    assert _page_is_shared(["public"], mode) is True
def test_require_public_still_vetoes_private():
    """When a page is tagged both `public` and `private`, `private` wins."""
    tags = ["public", "private"]
    shared = _page_is_shared(tags, "require_public")
    assert shared is False
def test_unknown_mode_treated_as_exclude_private():
    """An unknown mode behaves like `exclude_private`.

    Untagged pages are shared and `private` pages are not.
    """
    bogus_mode = "totally-bogus-mode"
    assert _page_is_shared([], bogus_mode) is True
    assert _page_is_shared(["private"], bogus_mode) is False

64
tests/test_ssrf.py Normal file
View file

@ -0,0 +1,64 @@
"""Tests for `_validate_url_target` — SSRF prevention.
Any URL the app fetches must resolve to a public IP; private/internal/
loopback addresses must be rejected so attacker-controlled URLs cannot
reach internal services via our HTTP client.
"""
import socket
from unittest.mock import patch
import pytest
from db import _validate_url_target
def _mock_getaddrinfo(address):
    """Return a function suitable as a `socket.getaddrinfo` replacement.

    The replacement ignores the requested host and always answers with a
    single result pointing at `address`. The address family is chosen from
    the literal itself (a ":" means IPv6), and the port falls back to 80 when
    the caller passes a falsy port.

    The sockaddr tuple mirrors what the real `getaddrinfo` produces for each
    family: `(address, port)` for AF_INET and
    `(address, port, flowinfo, scope_id)` for AF_INET6. Previously the IPv6
    case returned a 2-tuple, a shape the real resolver never yields.
    """
    is_ipv6 = ":" in address

    def f(host, port, *args, **kwargs):
        resolved_port = port or 80
        if is_ipv6:
            family = socket.AF_INET6
            # flowinfo and scope_id are zero for a plain global address.
            sockaddr = (address, resolved_port, 0, 0)
        else:
            family = socket.AF_INET
            sockaddr = (address, resolved_port)
        return [(family, socket.SOCK_STREAM, 0, "", sockaddr)]

    return f
@pytest.mark.parametrize(
    "blocked_ip",
    [
        # IPv4 literals
        "127.0.0.1", "127.1.2.3",
        "10.0.0.1", "10.255.255.255",
        "172.16.0.1", "172.31.255.255",
        "192.168.0.1", "192.168.255.255",
        "169.254.169.254",
        "0.0.0.0",
        # IPv6 literals
        "::1",
        "fc00::1",
        "fe80::1",
    ],
)
def test_blocks_private_and_loopback(monkeypatch, blocked_ip):
    """A hostname resolving to any listed address is rejected as "blocked"."""
    fake_resolver = _mock_getaddrinfo(blocked_ip)
    monkeypatch.setattr(socket, "getaddrinfo", fake_resolver)
    with pytest.raises(ValueError, match="blocked"):
        _validate_url_target("https://evil.example.com/internal")
def test_allows_public_ipv4(monkeypatch):
    """A hostname resolving to 8.8.8.8 passes validation without raising."""
    fake_resolver = _mock_getaddrinfo("8.8.8.8")
    monkeypatch.setattr(socket, "getaddrinfo", fake_resolver)
    # Passing means simply not raising.
    _validate_url_target("https://dns.example.com/")
def test_allows_public_ipv6(monkeypatch):
    """A hostname resolving to 2001:4860:4860::8888 passes without raising."""
    fake_resolver = _mock_getaddrinfo("2001:4860:4860::8888")
    monkeypatch.setattr(socket, "getaddrinfo", fake_resolver)
    # Passing means simply not raising.
    _validate_url_target("https://v6.example.com/")
def test_rejects_unresolvable_hostname(monkeypatch):
    """A resolver failure (`socket.gaierror`) surfaces as ValueError "Cannot resolve"."""

    def failing_resolver(*args, **kwargs):
        raise socket.gaierror("no such host")

    monkeypatch.setattr(socket, "getaddrinfo", failing_resolver)
    with pytest.raises(ValueError, match="Cannot resolve"):
        _validate_url_target("https://does-not-exist.example.com/")
def test_rejects_missing_hostname():
    """A URL with an empty host component is rejected with "No hostname"."""
    hostless_url = "http:///path-only"
    with pytest.raises(ValueError, match="No hostname"):
        _validate_url_target(hostless_url)

101
tests/test_url_cleanup.py Normal file
View file

@ -0,0 +1,101 @@
"""Tests for `clean_url` — URL normalization and tracking-param stripping.
Clean URLs are the deduplication key in the pages table, so any change to
this function can silently cause duplicate rows or mask legitimate saves.
"""
import pytest
from conftest import patch_dns_ok, patch_dns_fail
from db import clean_url, TRACKING_PARAMS
def test_strips_fragment(monkeypatch):
    """The `#fragment` part is removed from the cleaned URL."""
    patch_dns_ok(monkeypatch)
    cleaned = clean_url("https://example.com/page#section")
    assert cleaned == "https://example.com/page"
def test_prefers_https(monkeypatch):
    """An `http://` URL is upgraded to `https://`."""
    patch_dns_ok(monkeypatch)
    cleaned = clean_url("http://example.com/page")
    assert cleaned == "https://example.com/page"
def test_lowercases_hostname(monkeypatch):
    """An upper-case hostname is normalized to lower case."""
    patch_dns_ok(monkeypatch)
    cleaned = clean_url("https://EXAMPLE.COM/page")
    assert cleaned == "https://example.com/page"
def test_preserves_path_case(monkeypatch):
    """Paths are case-sensitive, so their casing is left untouched."""
    patch_dns_ok(monkeypatch)
    mixed_case = "https://example.com/Foo/Bar"
    cleaned = clean_url(mixed_case)
    assert cleaned == "https://example.com/Foo/Bar"
def test_strips_default_https_port(monkeypatch):
    """An explicit `:443` on an https URL is dropped."""
    patch_dns_ok(monkeypatch)
    cleaned = clean_url("https://example.com:443/page")
    assert cleaned == "https://example.com/page"
@pytest.mark.xfail(
    reason=(
        "clean_url upgrades http->https before the port-default check, "
        "so port 80 is not stripped. Minor dedup bug — harmless but worth fixing."
    )
)
def test_strips_http_port_80(monkeypatch):
    """Expected: http://foo:80 → https://foo (scheme upgrade plus port strip).

    Marked xfail: per the note above, the scheme is upgraded to https *before*
    the port check runs, so `scheme == "http" and port == 80` is never true by
    then and the `:80` is left in place.
    """
    patch_dns_ok(monkeypatch)
    cleaned = clean_url("http://example.com:80/page")
    assert cleaned == "https://example.com/page"
def test_preserves_non_default_port(monkeypatch):
    """A non-default port such as `:8443` stays in the cleaned URL."""
    patch_dns_ok(monkeypatch)
    cleaned = clean_url("https://example.com:8443/page")
    assert cleaned == "https://example.com:8443/page"
def test_strips_trailing_slash(monkeypatch):
    """A trailing slash after a non-root path is removed."""
    patch_dns_ok(monkeypatch)
    cleaned = clean_url("https://example.com/page/")
    assert cleaned == "https://example.com/page"
def test_root_slash_preserved(monkeypatch):
    """The slash of a bare root URL is kept."""
    patch_dns_ok(monkeypatch)
    root_url = "https://example.com/"
    cleaned = clean_url(root_url)
    assert cleaned == "https://example.com/"
@pytest.mark.parametrize("param", sorted(TRACKING_PARAMS))
def test_tracking_params_stripped(monkeypatch, param):
    """Every name in TRACKING_PARAMS is removed; unrelated params survive.

    The check looks at the parsed query keys rather than searching the whole
    URL string. A plain substring test would fail spuriously for any short
    tracking key that happens to occur inside the scheme, host, path, or the
    `keep=yes` pair.
    """
    from urllib.parse import parse_qs, urlsplit

    patch_dns_ok(monkeypatch)
    result = clean_url(f"https://example.com/page?{param}=value&keep=yes")
    remaining_keys = parse_qs(urlsplit(result).query, keep_blank_values=True)
    assert param not in remaining_keys
    assert "keep=yes" in result
def test_strips_www_when_nonwww_resolves(monkeypatch):
    """With DNS lookups succeeding, the `www.` prefix is stripped to canonicalize."""
    patch_dns_ok(monkeypatch)
    cleaned = clean_url("https://www.example.com/page")
    assert cleaned == "https://example.com/page"
def test_preserves_www_when_nonwww_does_not_resolve(monkeypatch):
    """Regression for 6ffd38d: keep `www.` when DNS lookups fail.

    Some sites only serve their content at `www.domain.tld` and the bare
    domain doesn't resolve. Stripping `www.` there produced a URL that could
    never be fetched or deduped against the real one.
    """
    patch_dns_fail(monkeypatch)
    www_url = "https://www.example.com/page"
    cleaned = clean_url(www_url)
    assert cleaned == "https://www.example.com/page"
def test_query_params_sorted_for_stable_ordering(monkeypatch):
    """Two URLs differing only in query-parameter order clean to the same string."""
    patch_dns_ok(monkeypatch)
    reversed_order = clean_url("https://example.com/page?b=2&a=1")
    natural_order = clean_url("https://example.com/page?a=1&b=2")
    assert reversed_order == natural_order
def test_path_and_query_preserved_through_cleanup(monkeypatch):
    """Path and ordinary query params survive; the `utm_source` param is removed."""
    patch_dns_ok(monkeypatch)
    cleaned = clean_url("https://example.com/path/to/page?id=42&utm_source=twitter")
    expected = "https://example.com/path/to/page?id=42"
    assert cleaned == expected