diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..9a2f26e --- /dev/null +++ b/conftest.py @@ -0,0 +1,128 @@ +"""Shared pytest fixtures for TinyWeb tests. + +Three fixtures cover most tests: `temp_db` swaps the SQLite path to a +per-test tempfile, `seeded_db` layers sample rows on top, and `csrf_session` +primes the thread-local CSRF token that handlers read. +""" +import socket +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).parent)) + +import db as db_module +import handlers as handlers_module + + +@pytest.fixture +def temp_db(tmp_path, monkeypatch): + """Isolated SQLite DB per test. + + Swaps `db.DATABASE` and `db.DATA_DIR` to a tempdir, clears the connection + pool before and after so state doesn't leak across tests, and calls + `init_db()` so every schema object exists. + """ + data_dir = tmp_path / "tinyweb" + data_dir.mkdir() + db_path = data_dir / "index.db" + + monkeypatch.setattr(db_module, "DATA_DIR", str(data_dir)) + monkeypatch.setattr(db_module, "DATABASE", str(db_path)) + + with db_module._pool_lock: + for conn in db_module._pool: + try: + conn.close() + except Exception: + pass + db_module._pool.clear() + + db_module.init_db() + yield db_path + + with db_module._pool_lock: + for conn in db_module._pool: + try: + conn.close() + except Exception: + pass + db_module._pool.clear() + + +@pytest.fixture +def seeded_db(temp_db): + """A temp DB with a small, realistic set of pages/tags/links.""" + db = db_module.get_db() + try: + rows = [ + ("https://example.com/rust-intro", "Rust Intro", "A gentle introduction to rust borrow checker.", "notes on ownership"), + ("https://example.com/python-tips", "Python Tips", "Daily python tricks for readable code.", ""), + ("https://example.com/ocaml-why", "Why OCaml", "Type systems and inference in ocaml.", "private thoughts"), + ("https://news.example.org/mesh", "Mesh Networking", "Reticulum and LoRa for decentralized networks.", ""), + ] + for 
url, title, body, note in rows: + db.execute( + "INSERT INTO pages (url, title, body, note, last_modified) " + "VALUES (?, ?, ?, ?, '2026-04-01T00:00:00')", + (url, title, body, note), + ) + db.commit() + page_ids = { + row["url"]: row["id"] + for row in db.execute("SELECT id, url FROM pages").fetchall() + } + tag_rows = [ + (page_ids["https://example.com/rust-intro"], ["rust", "public"]), + (page_ids["https://example.com/python-tips"], ["python"]), + (page_ids["https://example.com/ocaml-why"], ["ocaml", "private"]), + (page_ids["https://news.example.org/mesh"], ["mesh", "public"]), + ] + for pid, tags in tag_rows: + for name in tags: + db.execute("INSERT OR IGNORE INTO tags (name) VALUES (?)", (name,)) + tid = db.execute("SELECT id FROM tags WHERE name = ?", (name,)).fetchone()[0] + db.execute( + "INSERT OR IGNORE INTO page_tags (page_id, tag_id) VALUES (?, ?)", + (pid, tid), + ) + db.execute( + "INSERT INTO links (page_id, url, label) VALUES (?, ?, ?)", + (page_ids["https://example.com/rust-intro"], "https://example.com/rust-advanced", "advanced rust guide"), + ) + db.commit() + finally: + db_module.return_db(db) + return temp_db + + +@pytest.fixture +def csrf_session(monkeypatch): + """Prime the CSRF thread-local so handler code that calls _get_csrf_token works.""" + token = "test-csrf-token" + handlers_module._request_local.csrf_token = token + yield token + if hasattr(handlers_module._request_local, "csrf_token"): + del handlers_module._request_local.csrf_token + + +def patch_dns_fail(monkeypatch): + """Make every socket.getaddrinfo call raise gaierror for the rest of this test.""" + def boom(*args, **kwargs): + raise socket.gaierror("test: DNS disabled") + monkeypatch.setattr(socket, "getaddrinfo", boom) + + +def patch_dns_ok(monkeypatch, address="93.184.216.34"): + """Make every getaddrinfo return a single public IP for the rest of this test.""" + def ok(host, port, *args, **kwargs): + return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (address, port or 
80))] + monkeypatch.setattr(socket, "getaddrinfo", ok) + + +def patch_dns_private(monkeypatch, address="127.0.0.1"): + """Make every getaddrinfo return a private/blocked IP for the rest of this test.""" + def private(host, port, *args, **kwargs): + return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (address, port or 80))] + monkeypatch.setattr(socket, "getaddrinfo", private) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..16d6cc5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +testpaths = tests +python_files = test_*.py +filterwarnings = + ignore::DeprecationWarning diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..26b77f6 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,2 @@ +-r requirements.txt +pytest diff --git a/tests/test_csrf.py b/tests/test_csrf.py new file mode 100644 index 0000000..43b4487 --- /dev/null +++ b/tests/test_csrf.py @@ -0,0 +1,60 @@ +"""Tests for `_check_csrf` — form-submission CSRF protection. + +Every POST handler calls this to verify the submitted _csrf field matches +the token stored in the thread-local (which is seeded from the cookie by +`dispatch_request`). Missing or mismatched tokens must fail closed. 
+""" +import handlers as handlers_module +from handlers import _check_csrf, _csrf_field, _get_csrf_token + + +def _set_token(token): + handlers_module._request_local.csrf_token = token + + +def _clear_token(): + if hasattr(handlers_module._request_local, "csrf_token"): + del handlers_module._request_local.csrf_token + + +def teardown_function(_): + _clear_token() + + +def test_rejects_missing_token_in_body(): + _set_token("server-side-token") + assert _check_csrf({}) is False + + +def test_rejects_empty_token_in_body(): + _set_token("server-side-token") + assert _check_csrf({"_csrf": [""]}) is False + + +def test_rejects_mismatched_token(): + _set_token("server-side-token") + assert _check_csrf({"_csrf": ["attacker-token"]}) is False + + +def test_accepts_matching_token(): + _set_token("server-side-token") + assert _check_csrf({"_csrf": ["server-side-token"]}) is True + + +def test_rejects_when_server_token_missing(): + """If the server-side token is empty (shouldn't happen after dispatch_request + seeds it, but be defensive), the check must fail closed.""" + _clear_token() + assert _check_csrf({"_csrf": ["anything"]}) is False + + +def test_csrf_field_renders_current_token(): + _set_token("abc123") + field = _csrf_field() + assert 'name="_csrf"' in field + assert 'value="abc123"' in field + + +def test_get_csrf_token_returns_empty_when_unset(): + _clear_token() + assert _get_csrf_token() == "" diff --git a/tests/test_db_index_url.py b/tests/test_db_index_url.py new file mode 100644 index 0000000..50f73ce --- /dev/null +++ b/tests/test_db_index_url.py @@ -0,0 +1,155 @@ +"""Tests for `index_url` — the main write path. + +Covers UPSERT behavior, links being replaced on re-index, FTS index staying +in sync via triggers, and the connection pool returning clean connections. 
+""" +from unittest.mock import patch + +from conftest import patch_dns_ok +import db as db_module +from db import get_db, return_db, index_url + + +def _mock_fetch_page(title="Test Page", body="test body text", links=None, meta=""): + """Return a replacement for db.fetch_page that yields canned data.""" + links = links or [] + def fake(url): + return (title, body, links, meta) + return fake + + +def test_insert_creates_page_row_and_fts_entry(temp_db, monkeypatch): + patch_dns_ok(monkeypatch) + monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page( + title="Rust Intro", body="ownership and borrowing basics", links=[], + )) + index_url("https://example.com/rust") + + db = get_db() + try: + row = db.execute("SELECT id, title, body FROM pages").fetchone() + assert row is not None + assert row["title"] == "Rust Intro" + assert "ownership" in row["body"] + # Verify FTS trigger fired. + fts_hits = db.execute( + "SELECT rowid FROM pages_fts WHERE pages_fts MATCH 'ownership*'" + ).fetchall() + assert len(fts_hits) == 1 + assert fts_hits[0]["rowid"] == row["id"] + finally: + return_db(db) + + +def test_re_indexing_same_url_updates_in_place(temp_db, monkeypatch): + patch_dns_ok(monkeypatch) + monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page( + title="First Title", body="first body", links=[], + )) + index_url("https://example.com/page") + + monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page( + title="Second Title", body="second body", links=[], + )) + index_url("https://example.com/page") + + db = get_db() + try: + rows = db.execute("SELECT title, body FROM pages").fetchall() + finally: + return_db(db) + assert len(rows) == 1, "re-indexing should UPDATE not INSERT" + assert rows[0]["title"] == "Second Title" + + +def test_links_replaced_on_reindex(temp_db, monkeypatch): + patch_dns_ok(monkeypatch) + monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page( + title="T", body="b", + links=[("https://example.com/a", "first"), 
("https://example.com/b", "second")], + )) + index_url("https://example.com/src") + + monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page( + title="T", body="b", + links=[("https://example.com/c", "third-only")], + )) + index_url("https://example.com/src") + + db = get_db() + try: + rows = db.execute("SELECT url FROM links").fetchall() + finally: + return_db(db) + urls = {r["url"] for r in rows} + assert urls == {"https://example.com/c"}, "old links should be deleted on reindex" + + +def test_url_cleaned_before_insert(temp_db, monkeypatch): + """index_url should apply clean_url before touching the DB, so tracking params + don't create duplicate rows.""" + patch_dns_ok(monkeypatch) + monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page(title="T", body="b")) + index_url("https://example.com/page?utm_source=twitter#frag") + + db = get_db() + try: + rows = db.execute("SELECT url FROM pages").fetchall() + finally: + return_db(db) + assert len(rows) == 1 + assert rows[0]["url"] == "https://example.com/page" + + +def test_summary_populated_from_meta_description(temp_db, monkeypatch): + patch_dns_ok(monkeypatch) + long_meta = "A thoughtful description that exceeds twenty chars" + monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page( + title="T", body="b", meta=long_meta, + )) + index_url("https://example.com/page") + + db = get_db() + try: + row = db.execute("SELECT summary FROM pages").fetchone() + finally: + return_db(db) + assert row["summary"] == long_meta + + +def test_short_meta_description_not_stored_as_summary(temp_db, monkeypatch): + patch_dns_ok(monkeypatch) + monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page( + title="T", body="b", meta="too short", + )) + index_url("https://example.com/page") + + db = get_db() + try: + row = db.execute("SELECT summary FROM pages").fetchone() + finally: + return_db(db) + assert row["summary"] == "" + + +def test_pool_returns_clean_connection(temp_db, monkeypatch): + """Regression for 
1bc695f — `return_db` should roll back uncommitted work + so the next consumer doesn't see stale state.""" + patch_dns_ok(monkeypatch) + monkeypatch.setattr(db_module, "fetch_page", _mock_fetch_page(title="T", body="b")) + index_url("https://example.com/one") + + # Take a connection, make a dirty uncommitted change, return it. + db = get_db() + db.execute("INSERT INTO pages (url, title, body) VALUES (?, ?, ?)", + ("https://dirty.example.com/", "dirty", "dirty")) + # NOTE: no commit here — this is the dirty state we want rolled back. + return_db(db) + + # A later consumer must not see the dirty row. + db2 = get_db() + try: + urls = {r["url"] for r in db2.execute("SELECT url FROM pages").fetchall()} + finally: + return_db(db2) + assert "https://dirty.example.com/" not in urls diff --git a/tests/test_db_schema.py b/tests/test_db_schema.py new file mode 100644 index 0000000..5a4f77c --- /dev/null +++ b/tests/test_db_schema.py @@ -0,0 +1,90 @@ +"""Tests for `init_db` and the settings key-value store. + +`init_db` is called unconditionally on startup, so it must be idempotent +and create every table/trigger the rest of the app expects. 
+""" +from db import get_db, return_db, init_db, get_setting, set_setting, get_site_name + + +EXPECTED_TABLES = { + "pages", "links", "settings", "subscriptions", + "remote_pages", "tags", "page_tags", "chunks", + # FTS5 virtual tables: + "pages_fts", "remote_pages_fts", +} + + +def test_all_expected_tables_exist(temp_db): + db = get_db() + try: + rows = db.execute( + "SELECT name FROM sqlite_master WHERE type IN ('table') AND name NOT LIKE 'sqlite_%'" + ).fetchall() + names = {r["name"] for r in rows} + finally: + return_db(db) + missing = EXPECTED_TABLES - names + assert not missing, f"tables missing after init_db: {missing}" + + +def test_fts_triggers_exist(temp_db): + db = get_db() + try: + rows = db.execute( + "SELECT name FROM sqlite_master WHERE type = 'trigger'" + ).fetchall() + names = {r["name"] for r in rows} + finally: + return_db(db) + # These triggers keep pages_fts in sync with pages on insert/update/delete. + for trigger in ("pages_ai", "pages_ad", "pages_au"): + assert trigger in names, f"missing trigger {trigger}" + + +def test_init_db_is_idempotent(temp_db): + """Running init_db twice on the same DB must not error or duplicate anything.""" + init_db() + init_db() # second call should be a no-op + db = get_db() + try: + count = db.execute( + "SELECT count(*) FROM sqlite_master WHERE name = 'pages'" + ).fetchone()[0] + finally: + return_db(db) + assert count == 1 + + +def test_get_setting_returns_default_when_missing(temp_db): + assert get_setting("nonexistent", "fallback") == "fallback" + assert get_setting("nonexistent") == "" + + +def test_set_setting_then_get(temp_db): + set_setting("site_name", "my-personal-index") + assert get_setting("site_name") == "my-personal-index" + + +def test_set_setting_updates_existing(temp_db): + set_setting("key", "first") + set_setting("key", "second") + assert get_setting("key") == "second" + + +def test_get_site_name_has_default(temp_db): + assert get_site_name() == "tinyweb" + + +def 
test_get_site_name_reflects_override(temp_db): + set_setting("site_name", "custom-site") + assert get_site_name() == "custom-site" + + +def test_foreign_keys_pragma_enabled(temp_db): + """Pool connections should have foreign_keys=ON so CASCADE deletes work.""" + db = get_db() + try: + row = db.execute("PRAGMA foreign_keys").fetchone() + finally: + return_db(db) + assert row[0] == 1 diff --git a/tests/test_fts_sanitizer.py b/tests/test_fts_sanitizer.py new file mode 100644 index 0000000..ad061da --- /dev/null +++ b/tests/test_fts_sanitizer.py @@ -0,0 +1,113 @@ +"""Tests for `_sanitize_fts_query`. + +The sanitizer is the boundary between user input and FTS5 MATCH syntax. +Commit 1bc695f tightened it after noticing that colons and operator words +could escape the quoting. These tests keep that regression dead. +""" +import pytest + +from handlers import _sanitize_fts_query + + +def test_empty_query_returns_no_match_token(): + assert _sanitize_fts_query("") == '""' + assert _sanitize_fts_query(" ") == '""' + + +def test_single_word_becomes_prefix_match(): + assert _sanitize_fts_query("rust") == "rust*" + + +def test_multi_word_quotes_all_but_last(): + result = _sanitize_fts_query("rust borrow checker") + assert result == '"rust" "borrow" checker*' + + +def test_stopwords_are_dropped(): + # "the" and "a" should vanish; only "cat" remains (and gets prefix star). + assert _sanitize_fts_query("the a cat") == "cat*" + + +def test_all_stopwords_returns_no_match_token(): + assert _sanitize_fts_query("the and or") == '""' + + +@pytest.mark.parametrize("bad_char", ["'", "(", ")", "+", "-", "^", "~", ":"]) +def test_fts5_operators_stripped_from_tokens(bad_char): + """FTS5 special chars inside user tokens must not survive — regression for 1bc695f. + + The sanitizer legitimately adds `"` around tokens and a trailing `*` for prefix + matching; both are excluded from this check. 
+ """ + payload = f"foo{bad_char}bar" + out = _sanitize_fts_query(payload) + assert bad_char not in out, f"{bad_char!r} leaked into {out!r}" + + +def test_asterisk_only_appears_as_trailing_prefix(): + """Input `*` should not become an in-token asterisk; the sanitizer's trailing `*` is fine.""" + out = _sanitize_fts_query("foo*bar") + assert out.count("*") <= 1 + if "*" in out: + assert out.endswith("*") + + +def test_quote_in_input_does_not_break_out_of_quoted_token(): + """A `"` in user input must not close the sanitizer's protective quoting. + + The sanitizer wraps each non-last token in double quotes; if a stray `"` from + the user slipped through, the resulting FTS5 expression would be interpreted + as broken syntax or, worse, a column filter. + """ + out = _sanitize_fts_query('foo"bar baz"qux') + # Each pair of quotes in the output should be balanced and around a clean token. + assert out.count('"') % 2 == 0 + # No embedded quotes inside a quoted region. + import re + for match in re.findall(r'"[^"]*"', out): + inner = match[1:-1] + assert '"' not in inner + + +@pytest.mark.parametrize("op", ["AND", "OR", "NOT", "NEAR", "and", "or", "not", "near"]) +def test_fts5_operator_words_dropped(op): + """AND/OR/NOT/NEAR would be interpreted as operators on the unquoted last token.""" + out = _sanitize_fts_query(f"foo {op} bar") + # the operator word itself should not appear + assert op.upper() not in out.upper().split('"'), f"operator {op!r} survived in {out!r}" + + +def test_injection_payload_produces_valid_fts5(): + """End-to-end: a realistic injection payload must produce syntactically valid FTS5. + + We run the sanitized output through a throwaway FTS5 table; if the sanitizer + leaks operator characters the MATCH either raises or interprets malicious syntax. 
+ """ + import sqlite3 + conn = sqlite3.connect(":memory:") + conn.execute("CREATE VIRTUAL TABLE t USING fts5(body)") + conn.execute("INSERT INTO t (body) VALUES ('hello world')") + + for payload in [ + 'foo": OR bar NOT baz AND qux*()', + '" OR 1=1 --', + "title:secret AND public", + "(((", + "^^^~~~", + ]: + q = _sanitize_fts_query(payload) + # Must not raise — if operators leaked, FTS5 would error or mis-parse. + conn.execute("SELECT * FROM t WHERE t MATCH ?", (q,)).fetchall() + conn.close() + + +def test_whitespace_only_tokens_dropped(): + # tokens that become empty after stripping special chars should not produce bare quotes + out = _sanitize_fts_query('""" "" ""') + assert out == '""' + + +def test_colon_stripped(): + """Regression for 1bc695f — colon is an FTS5 column filter and must be stripped.""" + out = _sanitize_fts_query("title:secret") + assert ":" not in out diff --git a/tests/test_gateway_limits.py b/tests/test_gateway_limits.py new file mode 100644 index 0000000..6033c3a --- /dev/null +++ b/tests/test_gateway_limits.py @@ -0,0 +1,164 @@ +"""Tests for gateway-level guards: body-size cap and Reticulum surface whitelist. + +Regression targets from commit 1bc695f — a 16 MiB upload limit (DoS guard) +and a strict GET-/api/sites-only whitelist for requests arriving over the +Reticulum mesh (CSRF can't protect mesh callers, so gate by whitelist). 
+""" +import io + +import pytest + +import app as app_module +from gateway import GatewayHandler, MAX_BODY_SIZE + + +class FakeHeaders: + """Minimal replacement for http.server request headers.""" + def __init__(self, items=None): + self._items = dict(items or {}) + + def get(self, key, default=None): + return self._items.get(key, default) + + +class FakeGatewayHandler(GatewayHandler): + """Bypass the socket-bound __init__ and capture response calls in memory.""" + def __init__(self, path="/", method="POST", headers=None, rfile=None): + self.path = path + self.command = method + self.headers = FakeHeaders(headers or {}) + self.rfile = rfile or io.BytesIO() + self.wfile = io.BytesIO() + self._captured = { + "error": None, "status": None, "headers": [], "body_written": None, + } + + def send_error(self, code, msg=""): + self._captured["error"] = (code, msg) + + def send_response(self, code): + self._captured["status"] = code + + def send_header(self, k, v): + self._captured["headers"].append((k, v)) + + def end_headers(self): + pass + + +def test_post_over_size_cap_rejected_with_413(): + """Regression for 1bc695f: request bodies over MAX_BODY_SIZE must be rejected + without being read into memory.""" + oversize = MAX_BODY_SIZE + 1 + handler = FakeGatewayHandler( + path="/add", + method="POST", + headers={"Content-Length": str(oversize)}, + ) + handler._forward("POST") + assert handler._captured["error"] is not None + code, _msg = handler._captured["error"] + assert code == 413 + + +def test_post_at_size_cap_accepted(): + """A body exactly at MAX_BODY_SIZE should not be rejected by the size check.""" + handler = FakeGatewayHandler( + path="/_does_not_matter", + method="POST", + headers={"Content-Length": str(MAX_BODY_SIZE)}, + # rfile has no data; handler will try to read; local_dispatch isn't set. + # We only care that the 413 check passes, not that the request succeeds. 
+ rfile=io.BytesIO(b""), + ) + # Stub out local_dispatch so _forward doesn't try the network path. + from gateway import GatewayState + original = GatewayState.local_dispatch + GatewayState.local_dispatch = lambda data: { + "status": 404, "content_type": "text/plain", "body": "nope", + } + try: + handler._forward("POST") + finally: + GatewayState.local_dispatch = original + # Not a 413, because the body is exactly at the cap (cap is inclusive). + if handler._captured["error"]: + assert handler._captured["error"][0] != 413 + + +def test_negative_content_length_rejected(): + handler = FakeGatewayHandler( + path="/add", + method="POST", + headers={"Content-Length": "-1"}, + ) + handler._forward("POST") + assert handler._captured["error"] is not None + code, _msg = handler._captured["error"] + assert code == 400 + + +def test_invalid_content_length_rejected(): + handler = FakeGatewayHandler( + path="/add", + method="POST", + headers={"Content-Length": "abc"}, + ) + handler._forward("POST") + assert handler._captured["error"] is not None + code, _msg = handler._captured["error"] + assert code == 400 + + +# -------- Reticulum mesh surface whitelist -------- + + +def test_mesh_rejects_non_api_sites_get(): + """Regression for 1bc695f: remote mesh callers can only GET /api/sites.""" + resp = app_module.rns_request_handler( + path="/tinyweb", + data={"method": "GET", "path": "/pages", "query": {}, "body": {}, "gateway_host": ""}, + request_id="x", link_id="y", remote_identity=None, requested_at=0, + ) + assert resp["status"] == 403 + + +def test_mesh_rejects_post_to_api_sites(): + resp = app_module.rns_request_handler( + path="/tinyweb", + data={"method": "POST", "path": "/api/sites", "query": {}, "body": {}, "gateway_host": ""}, + request_id="x", link_id="y", remote_identity=None, requested_at=0, + ) + assert resp["status"] == 403 + + +def test_mesh_rejects_sensitive_local_endpoints(): + for path in ("/add", "/delete/1", "/style", "/import", "/export"): + resp = 
app_module.rns_request_handler( + path="/tinyweb", + data={"method": "GET", "path": path, "query": {}, "body": {}, "gateway_host": ""}, + request_id="x", link_id="y", remote_identity=None, requested_at=0, + ) + assert resp["status"] == 403, f"path {path!r} leaked through mesh whitelist" + + +def test_mesh_allows_api_sites_get(temp_db, csrf_session): + """Sanity check: the one whitelisted combination is accepted.""" + resp = app_module.rns_request_handler( + path="/tinyweb", + data={"method": "GET", "path": "/api/sites", "query": {}, "body": {}, "gateway_host": ""}, + request_id="x", link_id="y", remote_identity=None, requested_at=0, + ) + # Status depends on handler output; 200 is the happy path. + assert resp["status"] in (200, 403) # 403 if sharing is disabled by default + + +def test_mesh_handles_missing_data_payload(): + """Regression-minded check: a None or malformed data object shouldn't crash.""" + resp = app_module.rns_request_handler( + path="/tinyweb", + data=None, + request_id="x", link_id="y", remote_identity=None, requested_at=0, + ) + # Default data has method=GET, path=/ which is not in the whitelist. + assert resp["status"] == 403 diff --git a/tests/test_handlers_pages.py b/tests/test_handlers_pages.py new file mode 100644 index 0000000..ab4704c --- /dev/null +++ b/tests/test_handlers_pages.py @@ -0,0 +1,174 @@ +"""Tests for `handle_bulk_action`, edit flow, and the bulk-delete confirm step. + +The bulk-delete confirmation flow is a data-loss guard added in commit +8dffd8c — a stray POST without `confirmed=1` must render the confirmation +page instead of actually deleting. 
+""" +from db import get_db, return_db +from handlers import ( + handle_bulk_action, + handle_edit_form, + handle_edit_submit, + handle_pages, +) + + +def _all_urls(seeded_db): + db = get_db() + try: + return {r["url"] for r in db.execute("SELECT url FROM pages").fetchall()} + finally: + return_db(db) + + +def _page_id(seeded_db, url): + db = get_db() + try: + return db.execute("SELECT id FROM pages WHERE url = ?", (url,)).fetchone()["id"] + finally: + return_db(db) + + +def test_bulk_delete_without_confirmed_renders_confirm_page(seeded_db, csrf_session): + """Regression for 8dffd8c: bulk delete must NOT delete until confirmed=1 is set.""" + pid = _page_id(seeded_db, "https://example.com/rust-intro") + urls_before = _all_urls(seeded_db) + + resp = handle_bulk_action({ + "ids": [str(pid)], + "action": ["delete"], + }) + assert resp["status"] == 200 + assert "confirm delete" in resp["body"].lower() + assert "Rust Intro" in resp["body"] + # Must still show a hidden confirmed=1 field in the follow-up form. + assert 'name="confirmed" value="1"' in resp["body"] + + # Crucially: nothing should have been deleted. + assert _all_urls(seeded_db) == urls_before + + +def test_bulk_delete_with_confirmed_actually_deletes(seeded_db, csrf_session): + pid = _page_id(seeded_db, "https://example.com/rust-intro") + + resp = handle_bulk_action({ + "ids": [str(pid)], + "action": ["delete"], + "confirmed": ["1"], + }) + # Confirmed delete redirects back to /pages. + assert resp["status"] in (302, 303) + + urls = _all_urls(seeded_db) + assert "https://example.com/rust-intro" not in urls + # Other pages untouched. 
+ assert "https://example.com/python-tips" in urls + + +def test_bulk_delete_with_no_ids_redirects(seeded_db, csrf_session): + resp = handle_bulk_action({ + "ids": [], + "action": ["delete"], + "confirmed": ["1"], + }) + assert resp["status"] in (302, 303) + assert _all_urls(seeded_db) == { + "https://example.com/rust-intro", + "https://example.com/python-tips", + "https://example.com/ocaml-why", + "https://news.example.org/mesh", + } + + +def test_bulk_delete_rejects_non_integer_ids(seeded_db, csrf_session): + resp = handle_bulk_action({ + "ids": ["not-a-number"], + "action": ["delete"], + "confirmed": ["1"], + }) + assert resp["status"] == 400 + + +def test_bulk_retag_add_mode_merges_tags(seeded_db, csrf_session): + pid = _page_id(seeded_db, "https://example.com/python-tips") + + handle_bulk_action({ + "ids": [str(pid)], + "action": ["retag"], + "bulk_tags": ["scripting, tutorials"], + "tag_mode": ["add"], + }) + db = get_db() + try: + rows = db.execute( + "SELECT t.name FROM tags t JOIN page_tags pt ON pt.tag_id = t.id " + "WHERE pt.page_id = ? 
ORDER BY t.name", + (pid,), + ).fetchall() + finally: + return_db(db) + tags = [r["name"] for r in rows] + assert "python" in tags # existing kept + assert "scripting" in tags # new added + assert "tutorials" in tags + + +def test_bulk_retag_replace_mode_overwrites_tags(seeded_db, csrf_session): + pid = _page_id(seeded_db, "https://example.com/python-tips") + + handle_bulk_action({ + "ids": [str(pid)], + "action": ["retag"], + "bulk_tags": ["one, two"], + "tag_mode": ["replace"], + }) + db = get_db() + try: + rows = db.execute( + "SELECT t.name FROM tags t JOIN page_tags pt ON pt.tag_id = t.id " + "WHERE pt.page_id = ?", + (pid,), + ).fetchall() + finally: + return_db(db) + tags = {r["name"] for r in rows} + assert tags == {"one", "two"} + assert "python" not in tags + + +def test_edit_form_renders_current_values(seeded_db, csrf_session): + pid = _page_id(seeded_db, "https://example.com/rust-intro") + resp = handle_edit_form(pid) + assert resp["status"] == 200 + assert "Rust Intro" in resp["body"] + # Existing tags should appear in the tag field. + assert "rust" in resp["body"] + + +def test_edit_form_404_for_unknown_page(temp_db, csrf_session): + resp = handle_edit_form(99999) + assert resp["status"] == 404 + + +def test_edit_submit_updates_title_and_note(seeded_db, csrf_session): + pid = _page_id(seeded_db, "https://example.com/rust-intro") + handle_edit_submit(pid, { + "title": ["New Rust Title"], + "note": ["new annotation"], + "tags": ["rust, updated"], + }) + db = get_db() + try: + row = db.execute("SELECT title, note FROM pages WHERE id = ?", (pid,)).fetchone() + finally: + return_db(db) + assert row["title"] == "New Rust Title" + assert row["note"] == "new annotation" + + +def test_handle_pages_lists_indexed_pages(seeded_db, csrf_session): + resp = handle_pages({}) + assert resp["status"] == 200 + # Every seeded page title appears on the list page. 
+ for title in ("Rust Intro", "Python Tips", "Why OCaml", "Mesh Networking"): + assert title in resp["body"] diff --git a/tests/test_handlers_search.py b/tests/test_handlers_search.py new file mode 100644 index 0000000..f7d2f9e --- /dev/null +++ b/tests/test_handlers_search.py @@ -0,0 +1,63 @@ +"""Tests for `handle_search` — the home page + primary user flow.""" +from handlers import handle_search + + +def test_empty_index_empty_query_shows_welcome(temp_db, csrf_session): + resp = handle_search({}) + assert resp["status"] == 200 + body = resp["body"] + assert "Your index is empty" in body + # Links the welcome panel offers as equal-weight starting points. + assert "/add" in body + assert "/style" in body + assert "/subscriptions" in body + + +def test_empty_index_with_query_shows_no_results(temp_db, csrf_session): + resp = handle_search({"q": ["rust"]}) + assert resp["status"] == 200 + assert "No results in your index" in resp["body"] + + +def test_populated_index_with_matching_query_returns_results(seeded_db, csrf_session): + resp = handle_search({"q": ["rust"]}) + assert resp["status"] == 200 + assert "Rust Intro" in resp["body"] + # Page count shown in meta line. 
+ assert "4 pages indexed" in resp["body"] + + +def test_query_only_matches_relevant_pages(seeded_db, csrf_session): + resp = handle_search({"q": ["ocaml"]}) + body = resp["body"] + assert "Why OCaml" in body + assert "Python Tips" not in body + assert "Rust Intro" not in body + + +def test_pagination_query_param_respected(seeded_db, csrf_session): + """A high page number should still render without crashing.""" + resp = handle_search({"q": ["example"], "p": ["99"]}) + assert resp["status"] == 200 + + +def test_trusted_sites_fallback_surfaces_when_query_matches_link_label(seeded_db, csrf_session): + """Links extracted from indexed pages act as a fallback when direct results + are absent or thin; labels are substring-matched case-insensitively.""" + resp = handle_search({"q": ["advanced"]}) + body = resp["body"] + # The label "advanced rust guide" is on a link extracted from rust-intro. + assert "advanced rust guide" in body + assert "trusted sites" in body + + +def test_page_count_in_meta_line(seeded_db, csrf_session): + resp = handle_search({}) + assert "4 pages indexed" in resp["body"] + + +def test_csp_and_security_headers_not_in_handler_but_via_dispatch(seeded_db, csrf_session): + """Handler itself returns no security headers; dispatch_request wraps them. + This test documents the boundary so future refactors don't break assumptions.""" + resp = handle_search({}) + assert "headers" not in resp or "Content-Security-Policy" not in resp.get("headers", {}) diff --git a/tests/test_handlers_subs.py b/tests/test_handlers_subs.py new file mode 100644 index 0000000..93ee97d --- /dev/null +++ b/tests/test_handlers_subs.py @@ -0,0 +1,112 @@ +"""Tests for subscription handlers. + +Subscription add validates the destination hash (32-char hex) locally +before calling `fetch_remote_sites`; browse uses cached remote_pages when +available and falls back to a live fetch otherwise. 
+""" +from unittest.mock import patch + +import handlers as handlers_module +from db import get_db, return_db +from handlers import handle_subscription_add, handle_subscription_browse + + +VALID_HASH = "a" * 32 + + +def _subscription_count(): + db = get_db() + try: + return db.execute("SELECT count(*) FROM subscriptions").fetchone()[0] + finally: + return_db(db) + + +def test_rejects_empty_dest_hash(temp_db, csrf_session): + resp = handle_subscription_add({"dest_hash": [""]}) + assert "32-character" in resp["body"] + assert _subscription_count() == 0 + + +def test_rejects_wrong_length(temp_db, csrf_session): + resp = handle_subscription_add({"dest_hash": ["abc123"]}) + assert "32-character" in resp["body"] + assert _subscription_count() == 0 + + +def test_rejects_non_hex(temp_db, csrf_session): + resp = handle_subscription_add({"dest_hash": ["z" * 32]}) + assert "hex" in resp["body"].lower() + assert _subscription_count() == 0 + + +def test_rejects_unreachable_peer(temp_db, csrf_session): + with patch.object(handlers_module, "fetch_remote_sites") as fetch: + fetch.side_effect = ConnectionError("unreachable") + resp = handle_subscription_add({"dest_hash": [VALID_HASH]}) + assert "Could not reach" in resp["body"] + assert _subscription_count() == 0 + + +def test_rejects_peer_with_sharing_disabled(temp_db, csrf_session): + with patch.object(handlers_module, "fetch_remote_sites") as fetch: + fetch.side_effect = PermissionError("sharing disabled") + resp = handle_subscription_add({"dest_hash": [VALID_HASH]}) + assert "sharing disabled" in resp["body"] + assert _subscription_count() == 0 + + +def test_successful_add_records_subscription(temp_db, csrf_session): + with patch.object(handlers_module, "fetch_remote_sites") as fetch: + fetch.return_value = {"name": "alice", "sites": []} + resp = handle_subscription_add({"dest_hash": [VALID_HASH]}) + assert "Subscribed to alice" in resp["body"] + assert _subscription_count() == 1 + + +def 
test_dest_hash_strips_angle_brackets(temp_db, csrf_session): + """Users often paste hashes as `<hash>` from RNS log output; strip them.""" + with patch.object(handlers_module, "fetch_remote_sites") as fetch: + fetch.return_value = {"name": "bob", "sites": []} + resp = handle_subscription_add({"dest_hash": [f"<{VALID_HASH}>"]}) + assert _subscription_count() == 1 + + +def test_browse_unknown_subscription_is_404(temp_db, csrf_session): + resp = handle_subscription_browse(99999) + assert resp["status"] == 404 + + +def test_browse_marks_already_indexed_urls(seeded_db, csrf_session): + # Insert a subscription + some remote pages (one duplicate of local, one new). + db = get_db() + try: + db.execute( + "INSERT INTO subscriptions (dest_hash, name) VALUES (?, ?)", + (VALID_HASH, "alice"), + ) + sub_id = db.execute("SELECT id FROM subscriptions").fetchone()["id"] + db.execute( + "INSERT INTO remote_pages (subscription_id, url, title, note, tags) " + "VALUES (?, ?, ?, ?, ?)", + (sub_id, "https://example.com/rust-intro", "Alice rust pick", "", ""), + ) + db.execute( + "INSERT INTO remote_pages (subscription_id, url, title, note, tags) " + "VALUES (?, ?, ?, ?, ?)", + (sub_id, "https://new.example.com/shiny", "Shiny New Link", "note", "tag1"), + ) + db.commit() + finally: + return_db(db) + + resp = handle_subscription_browse(sub_id) + body = resp["body"] + assert resp["status"] == 200 + assert "already indexed" in body + # The duplicate URL should appear in the "already indexed" section. + assert "Alice rust pick" in body + # The new URL should be in the selectable section. + assert "Shiny New Link" in body + # Count summary: "2 site(s) available, 1 new" + assert "1 new" in body diff --git a/tests/test_handlers_tags.py b/tests/test_handlers_tags.py new file mode 100644 index 0000000..7ec8f05 --- /dev/null +++ b/tests/test_handlers_tags.py @@ -0,0 +1,101 @@ +"""Tests for tag helpers and the tag browse handler. 
+ +Tags are stored via a join table, so orphaned rows in `tags` can accumulate +if `_cleanup_orphaned_tags` isn't called after deletion/retagging. Tag +counts shown in the UI rely on this being right. +""" +from db import get_db, return_db +from handlers import ( + _cleanup_orphaned_tags, + _get_page_tags, + _set_page_tags, + handle_tag_browse, + handle_tags, +) + + +def _page_id(url): + db = get_db() + try: + row = db.execute("SELECT id FROM pages WHERE url = ?", (url,)).fetchone() + return row["id"] if row else None + finally: + return_db(db) + + +def _tag_names(): + db = get_db() + try: + return {r["name"] for r in db.execute("SELECT name FROM tags").fetchall()} + finally: + return_db(db) + + +def test_get_page_tags_returns_sorted_names(seeded_db): + pid = _page_id("https://example.com/rust-intro") + tags = _get_page_tags(pid) + assert tags == sorted(tags) # alphabetical + assert "rust" in tags + assert "public" in tags + + +def test_set_page_tags_replaces_existing(seeded_db): + pid = _page_id("https://example.com/rust-intro") + db = get_db() + try: + _set_page_tags(pid, "brand, new, tags", db) + db.commit() + finally: + return_db(db) + current = _get_page_tags(pid) + assert current == ["brand", "new", "tags"] + + +def test_set_page_tags_splits_on_comma_and_lowercases(seeded_db): + pid = _page_id("https://example.com/python-tips") + db = get_db() + try: + _set_page_tags(pid, "Foo, BAR, baz", db) + db.commit() + finally: + return_db(db) + assert set(_get_page_tags(pid)) == {"foo", "bar", "baz"} + + +def test_cleanup_orphaned_tags_removes_unreferenced(seeded_db): + # Clear all tags on one page; previously-unique tags become orphans. + pid = _page_id("https://example.com/rust-intro") + db = get_db() + try: + _set_page_tags(pid, "", db) # empty string = no tags + # `rust` was only on the rust-intro page; `public` is also on mesh. 
+ _cleanup_orphaned_tags(db) + db.commit() + finally: + return_db(db) + names = _tag_names() + assert "rust" not in names # pruned + assert "public" in names # still on mesh + + +def test_handle_tag_browse_filters_by_tag(seeded_db, csrf_session): + resp = handle_tag_browse("rust", {}) + assert resp["status"] == 200 + body = resp["body"] + assert "Rust Intro" in body + assert "Python Tips" not in body + assert "Why OCaml" not in body + + +def test_handle_tag_browse_unknown_tag_is_graceful(seeded_db, csrf_session): + resp = handle_tag_browse("no-such-tag", {}) + # Should render a valid page with zero results, not error. + assert resp["status"] == 200 + + +def test_handle_tags_lists_all_tags_with_counts(seeded_db, csrf_session): + resp = handle_tags() + assert resp["status"] == 200 + body = resp["body"] + for tag in ("rust", "python", "ocaml", "mesh", "public", "private"): + assert tag in body diff --git a/tests/test_link_extraction.py b/tests/test_link_extraction.py new file mode 100644 index 0000000..2d8c741 --- /dev/null +++ b/tests/test_link_extraction.py @@ -0,0 +1,138 @@ +"""Tests for link extraction inside `fetch_page`. + +Link extraction powers the "trusted sites" fallback on empty searches and +feeds the `links` table. Rules: same-domain only, skip binary extensions, +skip Wikipedia special pages, resolve relatives via urljoin. 
+""" +from unittest.mock import patch + +from conftest import patch_dns_ok +import db as db_module + + +class FakeResponse: + def __init__(self, text, status_code=200): + self.text = text + self.status_code = status_code + self.is_redirect = False + self.headers = {} + + def raise_for_status(self): + if self.status_code >= 400: + raise Exception(f"status {self.status_code}") + + +def _fetch_with_html(monkeypatch, url, html): + """Invoke fetch_page against `url` with `html` as the mocked response body.""" + patch_dns_ok(monkeypatch) + with patch.object(db_module, "requests") as mock_requests: + mock_requests.get.return_value = FakeResponse(html) + return db_module.fetch_page(url) + + +def test_only_same_domain_links_kept(monkeypatch): + html = """ + + same + cross + subdomain + + """ + _, _, links, _ = _fetch_with_html(monkeypatch, "https://example.com/", html) + urls = [u for u, _label in links] + assert "https://example.com/a" in urls + assert "https://other.com/b" not in urls + assert "https://sub.example.com/c" not in urls + + +def test_binary_extensions_skipped(monkeypatch): + html = """ + + keep + skip + skip + skip + skip + skip + + """ + _, _, links, _ = _fetch_with_html(monkeypatch, "https://example.com/", html) + urls = [u for u, _label in links] + assert "https://example.com/real-page" in urls + for ext in (".png", ".pdf", ".zip", ".mp3", ".css"): + assert not any(u.endswith(ext) for u in urls), f"{ext} leaked through" + + +def test_wikipedia_special_pages_skipped(monkeypatch): + html = """ + + keep + skip + skip + skip + skip + + """ + _, _, links, _ = _fetch_with_html(monkeypatch, "https://example.com/", html) + urls = [u for u, _label in links] + assert "https://example.com/wiki/Main_Page" in urls + for skip in ("Special:Random", "Talk:Foo", "User:Jimbo", "Category:Bar"): + assert not any(skip in u for u in urls), f"wiki {skip!r} leaked" + + +def test_relative_urls_resolved(monkeypatch): + html = """r""" + _, _, links, _ = _fetch_with_html(monkeypatch, 
"https://example.com/start", html) + urls = [u for u, _label in links] + assert "https://example.com/relative/path" in urls + + +def test_fragment_stripped_from_extracted_links(monkeypatch): + html = """r""" + _, _, links, _ = _fetch_with_html(monkeypatch, "https://example.com/", html) + urls = [u for u, _label in links] + assert "https://example.com/page" in urls + assert not any("#" in u for u in urls) + + +def test_duplicate_links_deduped(monkeypatch): + html = """ + + first + second + third + + """ + _, _, links, _ = _fetch_with_html(monkeypatch, "https://example.com/", html) + urls = [u for u, _label in links] + assert urls.count("https://example.com/a") == 1 + + +def test_label_truncated_to_200(monkeypatch): + long_text = "x" * 500 + html = f'{long_text}' + _, _, links, _ = _fetch_with_html(monkeypatch, "https://example.com/", html) + assert len(links) == 1 + _, label = links[0] + assert len(label) <= 200 + + +def test_meta_description_extracted(monkeypatch): + html = """ + + +

body content

+ """ + title, body, links, meta = _fetch_with_html(monkeypatch, "https://example.com/", html) + assert meta == "the real description" + + +def test_og_description_fallback(monkeypatch): + """When there's no <meta name="description">, og:description wins.""" + html = """ + + +

body

+ """ + _, _, _, meta = _fetch_with_html(monkeypatch, "https://example.com/", html) + assert meta == "open graph fallback" diff --git a/tests/test_pagination.py b/tests/test_pagination.py new file mode 100644 index 0000000..05077e0 --- /dev/null +++ b/tests/test_pagination.py @@ -0,0 +1,58 @@ +"""Tests for `_paginate` and `_page_nav`.""" +from handlers import _paginate, _page_nav, PER_PAGE + + +def test_paginate_default_is_one(): + assert _paginate({}) == 1 + + +def test_paginate_reads_query_string(): + assert _paginate({"p": ["3"]}) == 3 + + +def test_paginate_clamps_to_one(): + assert _paginate({"p": ["0"]}) == 1 + assert _paginate({"p": ["-5"]}) == 1 + + +def test_paginate_handles_bad_input(): + assert _paginate({"p": ["not-a-number"]}) == 1 + assert _paginate({"p": []}) == 1 + + +def test_paginate_custom_key(): + assert _paginate({"batch": ["7"]}, key="batch") == 7 + + +def test_page_nav_empty_when_single_page(): + assert _page_nav(1, PER_PAGE, "/?q=foo") == "" + assert _page_nav(1, 0, "/?q=foo") == "" + + +def test_page_nav_shows_next_on_first_page(): + out = _page_nav(1, PER_PAGE * 3, "/?q=foo") + assert "next" in out + assert "prev" not in out + assert "page 1 of 3" in out + + +def test_page_nav_shows_both_in_middle(): + out = _page_nav(2, PER_PAGE * 3, "/?q=foo") + assert "next" in out + assert "prev" in out + + +def test_page_nav_shows_prev_on_last_page(): + out = _page_nav(3, PER_PAGE * 3, "/?q=foo") + assert "next" not in out + assert "prev" in out + assert "page 3 of 3" in out + + +def test_page_nav_handles_query_string_separator(): + # when base_url already has ?, pagination links must use & + out = _page_nav(1, PER_PAGE * 2, "/?q=foo") + assert "&p=2" in out + # when base_url has no ?, pagination links use ? 
+ out = _page_nav(1, PER_PAGE * 2, "/pages") + assert "?p=2" in out diff --git a/tests/test_regressions.py b/tests/test_regressions.py new file mode 100644 index 0000000..f8a5df7 --- /dev/null +++ b/tests/test_regressions.py @@ -0,0 +1,107 @@ +"""Aggregator of regression tests tied to specific bug-fix commits. + +Each test here guards against a specific bug that was once shipped. Running +just this file gives a one-line-per-bug audit: + + pytest tests/test_regressions.py -v + +The test bodies are intentionally small; for the exhaustive behavior of each +module, see the topical test files (test_fts_sanitizer.py, test_url_cleanup.py, +etc.). This file's job is to make the bug catalog scannable. +""" +import socket +from unittest.mock import patch + +import pytest + +import app as app_module +import db as db_module +import handlers as handlers_module +from conftest import patch_dns_fail, patch_dns_ok +from db import clean_url +from handlers import _sanitize_fts_query, handle_bulk_action + + +def test_6ffd38d_clean_url_preserves_www_when_bare_domain_fails(monkeypatch): + """6ffd38d: `clean_url` used to strip `www.` unconditionally; for sites that + only serve at `www.`, this produced unreachable clean URLs.""" + patch_dns_fail(monkeypatch) + assert clean_url("https://www.example.com/page") == "https://www.example.com/page" + + +def test_1bc695f_fts_sanitizer_strips_colon(): + """1bc695f: FTS5 colon is a column filter — must not appear in sanitized output.""" + assert ":" not in _sanitize_fts_query("title:secret body:exposed") + + +@pytest.mark.parametrize("op", ["AND", "OR", "NOT", "NEAR"]) +def test_1bc695f_fts_sanitizer_drops_operator_words(op): + """1bc695f: operator words (AND/OR/NOT/NEAR) would be interpreted as FTS5 + operators if they landed on the unquoted last token.""" + out = _sanitize_fts_query(f"foo {op} bar") + # operator itself should not appear in the output + tokens = out.replace('"', '').split() + assert op not in [t.rstrip("*") for t in tokens] + + 
+def test_1bc695f_gateway_rejects_oversize_body(): + """1bc695f: 16 MiB body-size cap prevents memory-exhaustion DoS.""" + from tests.test_gateway_limits import FakeGatewayHandler + from gateway import MAX_BODY_SIZE + h = FakeGatewayHandler( + path="/add", method="POST", + headers={"Content-Length": str(MAX_BODY_SIZE + 1)}, + ) + h._forward("POST") + assert h._captured["error"] and h._captured["error"][0] == 413 + + +def test_1bc695f_mesh_rejects_non_whitelisted_paths(): + """1bc695f: Reticulum callers are limited to GET /api/sites; CSRF cannot + authenticate mesh callers.""" + resp = app_module.rns_request_handler( + path="/tinyweb", + data={"method": "POST", "path": "/add", "query": {}, "body": {}, "gateway_host": ""}, + request_id="x", link_id="y", remote_identity=None, requested_at=0, + ) + assert resp["status"] == 403 + + +def test_1bc695f_pool_returns_clean_connection(temp_db, monkeypatch): + """1bc695f: uncommitted transactions on a pooled connection used to leak + into the next consumer.""" + from db import get_db, return_db + db = get_db() + db.execute( + "INSERT INTO pages (url, title, body) VALUES (?, ?, ?)", + ("https://leak.example.com/", "should not persist", "body"), + ) + return_db(db) # no commit + db2 = get_db() + try: + urls = {r["url"] for r in db2.execute("SELECT url FROM pages").fetchall()} + finally: + return_db(db2) + assert "https://leak.example.com/" not in urls + + +def test_8dffd8c_bulk_delete_requires_confirmation(seeded_db, csrf_session): + """8dffd8c: bulk delete without confirmed=1 must render a confirm page + instead of deleting — the JS confirm on /pages is a first-line filter only.""" + from db import get_db, return_db + db = get_db() + try: + pid = db.execute("SELECT id FROM pages LIMIT 1").fetchone()["id"] + count_before = db.execute("SELECT count(*) FROM pages").fetchone()[0] + finally: + return_db(db) + + resp = handle_bulk_action({"ids": [str(pid)], "action": ["delete"]}) + assert "confirm delete" in resp["body"].lower() + + 
db = get_db() + try: + count_after = db.execute("SELECT count(*) FROM pages").fetchone()[0] + finally: + return_db(db) + assert count_before == count_after, "bulk delete ran without confirmation" diff --git a/tests/test_sharing_logic.py b/tests/test_sharing_logic.py new file mode 100644 index 0000000..c9c06d4 --- /dev/null +++ b/tests/test_sharing_logic.py @@ -0,0 +1,38 @@ +"""Tests for `_page_is_shared`. + +This function decides whether a page is exposed over Reticulum to +subscribers. Getting it wrong means either a privacy leak or silently +hiding pages the user meant to share — both are worth a regression net. +""" +import pytest + +from handlers import _page_is_shared + + +@pytest.mark.parametrize("mode", ["exclude_private", "require_public"]) +def test_private_tag_always_excludes(mode): + """`private` tag overrides every mode — the most important invariant.""" + assert _page_is_shared(["private"], mode) is False + assert _page_is_shared(["public", "private"], mode) is False + + +def test_exclude_private_defaults_to_shared(): + assert _page_is_shared([], "exclude_private") is True + assert _page_is_shared(["random-tag"], "exclude_private") is True + + +def test_require_public_needs_public_tag(): + assert _page_is_shared([], "require_public") is False + assert _page_is_shared(["rust"], "require_public") is False + assert _page_is_shared(["public"], "require_public") is True + + +def test_require_public_still_vetoes_private(): + # public AND private → private wins. 
+ assert _page_is_shared(["public", "private"], "require_public") is False + + +def test_unknown_mode_treated_as_exclude_private(): + """The default mode is 'exclude_private'; unknown modes fall through to it.""" + assert _page_is_shared([], "totally-bogus-mode") is True + assert _page_is_shared(["private"], "totally-bogus-mode") is False diff --git a/tests/test_ssrf.py b/tests/test_ssrf.py new file mode 100644 index 0000000..807f9bd --- /dev/null +++ b/tests/test_ssrf.py @@ -0,0 +1,64 @@ +"""Tests for `_validate_url_target` — SSRF prevention. + +Any URL the app fetches must resolve to a public IP; private/internal/ +loopback addresses must be rejected so attacker-controlled URLs cannot +reach internal services via our HTTP client. +""" +import socket +from unittest.mock import patch + +import pytest + +from db import _validate_url_target + + +def _mock_getaddrinfo(address): + """Return a function suitable as a socket.getaddrinfo replacement.""" + def f(host, port, *args, **kwargs): + family = socket.AF_INET6 if ":" in address else socket.AF_INET + return [(family, socket.SOCK_STREAM, 0, "", (address, port or 80))] + return f + + +@pytest.mark.parametrize("blocked_ip", [ + "127.0.0.1", + "127.1.2.3", + "10.0.0.1", + "10.255.255.255", + "172.16.0.1", + "172.31.255.255", + "192.168.0.1", + "192.168.255.255", + "169.254.169.254", + "0.0.0.0", + "::1", + "fc00::1", + "fe80::1", +]) +def test_blocks_private_and_loopback(monkeypatch, blocked_ip): + monkeypatch.setattr(socket, "getaddrinfo", _mock_getaddrinfo(blocked_ip)) + with pytest.raises(ValueError, match="blocked"): + _validate_url_target("https://evil.example.com/internal") + + +def test_allows_public_ipv4(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", _mock_getaddrinfo("8.8.8.8")) + _validate_url_target("https://dns.example.com/") # does not raise + + +def test_allows_public_ipv6(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", _mock_getaddrinfo("2001:4860:4860::8888")) + 
_validate_url_target("https://v6.example.com/") # does not raise + + +def test_rejects_unresolvable_hostname(monkeypatch): + def boom(*args, **kwargs): + raise socket.gaierror("no such host") + monkeypatch.setattr(socket, "getaddrinfo", boom) + with pytest.raises(ValueError, match="Cannot resolve"): + _validate_url_target("https://does-not-exist.example.com/") + + +def test_rejects_missing_hostname(): + with pytest.raises(ValueError, match="No hostname"): + _validate_url_target("http:///path-only") diff --git a/tests/test_url_cleanup.py b/tests/test_url_cleanup.py new file mode 100644 index 0000000..1eef72b --- /dev/null +++ b/tests/test_url_cleanup.py @@ -0,0 +1,101 @@ +"""Tests for `clean_url` — URL normalization and tracking-param stripping. + +Clean URLs are the deduplication key in the pages table, so any change to +this function can silently cause duplicate rows or mask legitimate saves. +""" +import pytest + +from conftest import patch_dns_ok, patch_dns_fail +from db import clean_url, TRACKING_PARAMS + + +def test_strips_fragment(monkeypatch): + patch_dns_ok(monkeypatch) + assert clean_url("https://example.com/page#section") == "https://example.com/page" + + +def test_prefers_https(monkeypatch): + patch_dns_ok(monkeypatch) + assert clean_url("http://example.com/page") == "https://example.com/page" + + +def test_lowercases_hostname(monkeypatch): + patch_dns_ok(monkeypatch) + assert clean_url("https://EXAMPLE.COM/page") == "https://example.com/page" + + +def test_preserves_path_case(monkeypatch): + """Paths are case-sensitive and should not be lowercased.""" + patch_dns_ok(monkeypatch) + assert clean_url("https://example.com/Foo/Bar") == "https://example.com/Foo/Bar" + + +def test_strips_default_https_port(monkeypatch): + patch_dns_ok(monkeypatch) + assert clean_url("https://example.com:443/page") == "https://example.com/page" + + +@pytest.mark.xfail(reason="clean_url upgrades http->https before the port-default check, " + "so port 80 is not stripped. 
Minor dedup bug — harmless but worth fixing.") +def test_strips_http_port_80(monkeypatch): + """Expected: http://foo:80 → https://foo (both scheme-upgrade and port-strip). + + Currently fails because scheme is upgraded to https *before* the port check, + so `scheme == "http" and port == 80` is never true by the time the check runs. + """ + patch_dns_ok(monkeypatch) + assert clean_url("http://example.com:80/page") == "https://example.com/page" + + +def test_preserves_non_default_port(monkeypatch): + patch_dns_ok(monkeypatch) + assert clean_url("https://example.com:8443/page") == "https://example.com:8443/page" + + +def test_strips_trailing_slash(monkeypatch): + patch_dns_ok(monkeypatch) + assert clean_url("https://example.com/page/") == "https://example.com/page" + + +def test_root_slash_preserved(monkeypatch): + patch_dns_ok(monkeypatch) + assert clean_url("https://example.com/") == "https://example.com/" + + +@pytest.mark.parametrize("param", sorted(TRACKING_PARAMS)) +def test_tracking_params_stripped(monkeypatch, param): + patch_dns_ok(monkeypatch) + result = clean_url(f"https://example.com/page?{param}=value&keep=yes") + assert param not in result + assert "keep=yes" in result + + +def test_strips_www_when_nonwww_resolves(monkeypatch): + """Standard case: strip `www.` prefix to canonicalize.""" + patch_dns_ok(monkeypatch) + assert clean_url("https://www.example.com/page") == "https://example.com/page" + + +def test_preserves_www_when_nonwww_does_not_resolve(monkeypatch): + """Regression for 6ffd38d. + + Some sites only serve their content at `www.domain.tld`; the bare domain + doesn't resolve. Stripping `www.` in that case produced a URL that we could + never actually fetch or dedupe against the real one. 
+ """ + patch_dns_fail(monkeypatch) + assert clean_url("https://www.example.com/page") == "https://www.example.com/page" + + +def test_query_params_sorted_for_stable_ordering(monkeypatch): + """Same URL with different param orderings should produce the same clean URL.""" + patch_dns_ok(monkeypatch) + a = clean_url("https://example.com/page?b=2&a=1") + b = clean_url("https://example.com/page?a=1&b=2") + assert a == b + + +def test_path_and_query_preserved_through_cleanup(monkeypatch): + patch_dns_ok(monkeypatch) + result = clean_url("https://example.com/path/to/page?id=42&utm_source=twitter") + assert result == "https://example.com/path/to/page?id=42"