Add hybrid semantic search with optional cross-encoder reranking
Implements a three-stage search pipeline: 1. BM25 keyword search via FTS5 with column weights 2. Semantic search via Snowflake arctic-embed-s bi-encoder + HNSW index 3. Optional cross-encoder reranking (on by default, toggleable in settings) Top 20 results are reranked for precision, next 10 appended from RRF for coverage, giving 30 total results across 3 pages. - New embeddings.py with ONNX Runtime inference, text chunking, HNSW index management, RRF fusion, and cross-encoder reranking - Meta description extraction for authentic page snippets with centroid extractive fallback - Stopword filtering in FTS5 queries to avoid overly strict matching - /reindex page for batch embedding of existing pages - Semantic embedding of remote pages during subscription sync - ~125MB dependency footprint (onnxruntime, tokenizers, hnswlib, numpy) - Models: 34MB bi-encoder + 22MB cross-encoder (downloaded on first use) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
2df92752b6
commit
395fc17092
6 changed files with 839 additions and 17 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -1,3 +1,7 @@
|
|||
__pycache__/
|
||||
tinyweb_identity
|
||||
index.db
|
||||
index.db-shm
|
||||
index.db-wal
|
||||
models/
|
||||
index.hnsw
|
||||
|
|
|
|||
17
app.py
17
app.py
|
|
@ -72,8 +72,25 @@ def ensure_rns_config(config_dir):
|
|||
print(f"Created Reticulum config at {config_file}")
|
||||
|
||||
|
||||
def _preload_embeddings():
|
||||
"""Pre-load the embedding model and build the HNSW index in background."""
|
||||
try:
|
||||
from embeddings import _get_session, _get_reranker, build_index
|
||||
_get_session() # downloads model on first run, loads ONNX session
|
||||
build_index() # builds HNSW index from existing chunks
|
||||
# Preload cross-encoder unless user has explicitly disabled it
|
||||
if get_setting("use_reranker", "1") == "1":
|
||||
_get_reranker()
|
||||
print("Semantic search ready (with reranker).")
|
||||
else:
|
||||
print("Semantic search ready.")
|
||||
except Exception as e:
|
||||
print(f"Semantic search unavailable: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
init_db()
|
||||
threading.Thread(target=_preload_embeddings, daemon=True).start()
|
||||
config_dir = os.environ.get("RNS_CONFIG_DIR")
|
||||
ensure_rns_config(config_dir)
|
||||
reticulum = RNS.Reticulum(configdir=config_dir)
|
||||
|
|
|
|||
108
db.py
108
db.py
|
|
@ -226,6 +226,27 @@ def init_db():
|
|||
db.execute("UPDATE pages SET last_modified = strftime('%Y-%m-%dT%H:%M:%S','now') WHERE last_modified = ''")
|
||||
db.commit()
|
||||
|
||||
# Migrate pages: add summary column if missing
|
||||
if "summary" not in page_cols:
|
||||
db.execute("ALTER TABLE pages ADD COLUMN summary TEXT DEFAULT ''")
|
||||
db.commit()
|
||||
|
||||
# Chunks table for semantic search embeddings
|
||||
db.execute(
|
||||
"CREATE TABLE IF NOT EXISTS chunks ("
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,"
|
||||
" page_id INTEGER,"
|
||||
" remote_page_id INTEGER,"
|
||||
" chunk_index INTEGER NOT NULL,"
|
||||
" chunk_text TEXT NOT NULL,"
|
||||
" embedding BLOB NOT NULL,"
|
||||
" FOREIGN KEY (page_id) REFERENCES pages(id) ON DELETE CASCADE,"
|
||||
" FOREIGN KEY (remote_page_id) REFERENCES remote_pages(id) ON DELETE CASCADE"
|
||||
")"
|
||||
)
|
||||
db.execute("CREATE INDEX IF NOT EXISTS idx_chunks_page ON chunks(page_id)")
|
||||
db.execute("CREATE INDEX IF NOT EXISTS idx_chunks_remote ON chunks(remote_page_id)")
|
||||
|
||||
db.execute("PRAGMA journal_mode=WAL")
|
||||
db.commit()
|
||||
db.close()
|
||||
|
|
@ -296,24 +317,96 @@ def fetch_page(url):
|
|||
label = a.get_text(strip=True) or href
|
||||
links.append((href, label[:200]))
|
||||
|
||||
# Extract meta description before stripping tags
|
||||
meta_desc = ""
|
||||
meta_tag = soup.find("meta", attrs={"name": "description"})
|
||||
if meta_tag and meta_tag.get("content"):
|
||||
meta_desc = meta_tag["content"].strip()
|
||||
if not meta_desc:
|
||||
# Try og:description as fallback
|
||||
og_tag = soup.find("meta", attrs={"property": "og:description"})
|
||||
if og_tag and og_tag.get("content"):
|
||||
meta_desc = og_tag["content"].strip()
|
||||
|
||||
for tag in soup(["script", "style", "nav", "footer", "header"]):
|
||||
tag.decompose()
|
||||
title = soup.title.string.strip() if soup.title and soup.title.string else url
|
||||
body = soup.get_text(separator=" ", strip=True)
|
||||
return title, body, links
|
||||
return title, body, links, meta_desc
|
||||
|
||||
|
||||
def _generate_summary(title, body):
|
||||
"""Generate a summary from body text using centroid extractive method.
|
||||
|
||||
Filters out UI debris, embeds remaining sentences, finds the one
|
||||
closest to the centroid (most representative of the page).
|
||||
"""
|
||||
import re
|
||||
# Split on sentence boundaries
|
||||
raw = re.split(r'(?<=[.!?])\s+', body)
|
||||
sentences = []
|
||||
noise_patterns = re.compile(
|
||||
r'arrow-|fedilink|message-square|link-external|'
|
||||
r'skip to|cookie|subscribe|sign up|log in|'
|
||||
r'privacy policy|terms of|©|\bads?\b',
|
||||
re.IGNORECASE
|
||||
)
|
||||
for s in raw:
|
||||
s = s.strip()
|
||||
if len(s) < 40:
|
||||
continue
|
||||
words = s.split()
|
||||
if len(words) < 7:
|
||||
continue
|
||||
# Skip if mostly non-alpha (icons, arrows, encoded chars)
|
||||
alpha_chars = sum(1 for c in s if c.isalpha() or c == ' ')
|
||||
if alpha_chars < len(s) * 0.6:
|
||||
continue
|
||||
# Skip nav/menu patterns
|
||||
if s.count('|') > 2 or s.count('·') > 2 or s.count('►') > 0:
|
||||
continue
|
||||
# Skip UI debris
|
||||
if noise_patterns.search(s):
|
||||
continue
|
||||
sentences.append(s)
|
||||
|
||||
if not sentences:
|
||||
# Last resort: take the first chunk of body that looks like prose
|
||||
clean = re.sub(r'\s+', ' ', body).strip()
|
||||
return clean[:160] + "..." if len(clean) > 160 else clean
|
||||
if len(sentences) == 1:
|
||||
s = sentences[0]
|
||||
return s[:200] if len(s) > 200 else s
|
||||
try:
|
||||
from embeddings import embed
|
||||
import numpy as np
|
||||
embs = embed(sentences[:50]) # cap to avoid embedding too many
|
||||
centroid = embs.mean(axis=0, keepdims=True)
|
||||
centroid = centroid / max(np.linalg.norm(centroid), 1e-12)
|
||||
scores = (embs @ centroid.T).flatten()
|
||||
best_idx = int(np.argmax(scores))
|
||||
result = sentences[best_idx]
|
||||
# Try to add a second sentence if it fits
|
||||
if best_idx + 1 < len(sentences) and len(result) + len(sentences[best_idx + 1]) + 1 <= 200:
|
||||
result += " " + sentences[best_idx + 1]
|
||||
return result[:200] if len(result) > 200 else result
|
||||
except Exception:
|
||||
return sentences[0][:200]
|
||||
|
||||
|
||||
def index_url(url, note=""):
|
||||
url = clean_url(url)
|
||||
title, body, links = fetch_page(url)
|
||||
title, body, links, meta_desc = fetch_page(url)
|
||||
# Use meta description if available, otherwise generate from body
|
||||
summary = meta_desc if meta_desc else _generate_summary(title, body)
|
||||
db = get_db()
|
||||
try:
|
||||
now = __import__("datetime").datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
|
||||
db.execute(
|
||||
"INSERT INTO pages (url, title, body, note, last_modified) VALUES (?, ?, ?, ?, ?) "
|
||||
"INSERT INTO pages (url, title, body, note, last_modified, summary) VALUES (?, ?, ?, ?, ?, ?) "
|
||||
"ON CONFLICT(url) DO UPDATE SET title=excluded.title, body=excluded.body, "
|
||||
"note=excluded.note, last_modified=excluded.last_modified",
|
||||
(url, title, body, note, now),
|
||||
"note=excluded.note, last_modified=excluded.last_modified, summary=excluded.summary",
|
||||
(url, title, body, note, now, summary),
|
||||
)
|
||||
page_id = db.execute("SELECT id FROM pages WHERE url = ?", (url,)).fetchone()[0]
|
||||
db.execute("DELETE FROM links WHERE page_id = ?", (page_id,))
|
||||
|
|
@ -323,6 +416,11 @@ def index_url(url, note=""):
|
|||
(page_id, href, label),
|
||||
)
|
||||
db.commit()
|
||||
try:
|
||||
from embeddings import store_embeddings
|
||||
store_embeddings(page_id, title, body, db)
|
||||
except Exception:
|
||||
pass # embedding generation is best-effort
|
||||
finally:
|
||||
return_db(db)
|
||||
return title
|
||||
|
|
|
|||
553
embeddings.py
Normal file
553
embeddings.py
Normal file
|
|
@ -0,0 +1,553 @@
|
|||
"""Semantic search using Snowflake arctic-embed-s via ONNX Runtime + hnswlib."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
MODEL_ID = "Snowflake/snowflake-arctic-embed-s"
|
||||
MODEL_DIR = os.path.join(os.path.dirname(__file__), "models", "snowflake-arctic-embed-s")
|
||||
RERANKER_DIR = os.path.join(os.path.dirname(__file__), "models", "cross-encoder")
|
||||
HNSW_PATH = os.path.join(os.path.dirname(__file__), "index.hnsw")
|
||||
DIMS = 384
|
||||
MAX_TOKENS = 512
|
||||
QUERY_PREFIX = "Represent this sentence for searching relevant passages: "
|
||||
|
||||
_session = None
|
||||
_tokenizer = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
_reranker_session = None
|
||||
_reranker_tokenizer = None
|
||||
_reranker_lock = threading.Lock()
|
||||
|
||||
# Live HNSW index and chunk-id mapping
|
||||
_hnsw_index = None
|
||||
_hnsw_ids = [] # maps internal HNSW label -> chunks.id
|
||||
_hnsw_lock = threading.Lock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model download & loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_model():
|
||||
"""Download the ONNX model and tokenizer from HuggingFace if not present."""
|
||||
model_path = os.path.join(MODEL_DIR, "model.onnx")
|
||||
tokenizer_path = os.path.join(MODEL_DIR, "tokenizer.json")
|
||||
if os.path.exists(model_path) and os.path.exists(tokenizer_path):
|
||||
return
|
||||
from huggingface_hub import hf_hub_download
|
||||
os.makedirs(MODEL_DIR, exist_ok=True)
|
||||
files = {
|
||||
"onnx/model_quantized.onnx": "model.onnx",
|
||||
"tokenizer.json": "tokenizer.json",
|
||||
"tokenizer_config.json": "tokenizer_config.json",
|
||||
}
|
||||
for remote, local in files.items():
|
||||
target = os.path.join(MODEL_DIR, local)
|
||||
if os.path.exists(target):
|
||||
continue
|
||||
cached = hf_hub_download(repo_id=MODEL_ID, filename=remote)
|
||||
# hf_hub_download returns the cached file path; copy to our model dir
|
||||
import shutil
|
||||
shutil.copy2(cached, target)
|
||||
|
||||
|
||||
def _get_session():
|
||||
"""Return (onnxruntime.InferenceSession, tokenizers.Tokenizer) singleton."""
|
||||
global _session, _tokenizer
|
||||
if _session is not None:
|
||||
return _session, _tokenizer
|
||||
with _lock:
|
||||
if _session is not None:
|
||||
return _session, _tokenizer
|
||||
_ensure_model()
|
||||
import onnxruntime as ort
|
||||
from tokenizers import Tokenizer
|
||||
_session = ort.InferenceSession(
|
||||
os.path.join(MODEL_DIR, "model.onnx"),
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
_tokenizer = Tokenizer.from_file(os.path.join(MODEL_DIR, "tokenizer.json"))
|
||||
_tokenizer.enable_truncation(max_length=MAX_TOKENS)
|
||||
_tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=None)
|
||||
return _session, _tokenizer
|
||||
|
||||
|
||||
def _get_reranker():
|
||||
"""Return (onnxruntime.InferenceSession, tokenizers.Tokenizer) for the cross-encoder reranker."""
|
||||
global _reranker_session, _reranker_tokenizer
|
||||
if _reranker_session is not None:
|
||||
return _reranker_session, _reranker_tokenizer
|
||||
with _reranker_lock:
|
||||
if _reranker_session is not None:
|
||||
return _reranker_session, _reranker_tokenizer
|
||||
model_path = os.path.join(RERANKER_DIR, "model.onnx")
|
||||
tok_path = os.path.join(RERANKER_DIR, "tokenizer.json")
|
||||
if not os.path.exists(model_path) or not os.path.exists(tok_path):
|
||||
return None, None
|
||||
import onnxruntime as ort
|
||||
from tokenizers import Tokenizer
|
||||
_reranker_session = ort.InferenceSession(
|
||||
model_path, providers=["CPUExecutionProvider"],
|
||||
)
|
||||
_reranker_tokenizer = Tokenizer.from_file(tok_path)
|
||||
_reranker_tokenizer.enable_truncation(max_length=512)
|
||||
_reranker_tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=None)
|
||||
return _reranker_session, _reranker_tokenizer
|
||||
|
||||
|
||||
def rerank(query, documents, limit=10):
|
||||
"""Score query-document pairs with the cross-encoder and return reranked indices.
|
||||
|
||||
Args:
|
||||
query: search query string
|
||||
documents: list of document texts to score against the query
|
||||
limit: max results to return
|
||||
|
||||
Returns: list of (original_index, score) sorted by score descending.
|
||||
"""
|
||||
session, tokenizer = _get_reranker()
|
||||
if session is None:
|
||||
return [(i, 0.0) for i in range(min(limit, len(documents)))]
|
||||
|
||||
# Cross-encoder takes (query, document) pairs — encode as pair sequences
|
||||
pairs = [[query, doc] for doc in documents]
|
||||
encodings = tokenizer.encode_batch(pairs)
|
||||
|
||||
input_ids = np.array([e.ids for e in encodings], dtype=np.int64)
|
||||
attention_mask = np.array([e.attention_mask for e in encodings], dtype=np.int64)
|
||||
token_type_ids = np.array([e.type_ids for e in encodings], dtype=np.int64)
|
||||
|
||||
outputs = session.run(
|
||||
None,
|
||||
{
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
},
|
||||
)
|
||||
# Output is logits — higher = more relevant
|
||||
scores = outputs[0].flatten()
|
||||
ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
|
||||
return [(i, float(scores[i])) for i in ranked[:limit]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text chunking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')
|
||||
|
||||
|
||||
def chunk_text(title, body):
|
||||
"""Split body into chunks, each prefixed with title for context.
|
||||
|
||||
Strategy: split on double newlines (paragraphs). If a paragraph exceeds
|
||||
MAX_TOKENS words, split at sentence boundaries. Each chunk is prefixed
|
||||
with the page title.
|
||||
"""
|
||||
if not body or not body.strip():
|
||||
return [f"{title}"] if title else []
|
||||
|
||||
prefix = f"{title}: " if title else ""
|
||||
# Rough word budget for chunk body (leave room for prefix)
|
||||
prefix_words = len(prefix.split())
|
||||
max_words = MAX_TOKENS - prefix_words # approximate; tokenizer may differ
|
||||
|
||||
paragraphs = re.split(r'\n\s*\n', body.strip())
|
||||
chunks = []
|
||||
|
||||
for para in paragraphs:
|
||||
para = para.strip()
|
||||
if len(para) < 20:
|
||||
continue
|
||||
words = para.split()
|
||||
if len(words) <= max_words:
|
||||
chunks.append(prefix + para)
|
||||
else:
|
||||
# Split paragraph into sentences
|
||||
sentences = _SENTENCE_RE.split(para)
|
||||
current = []
|
||||
current_len = 0
|
||||
for sent in sentences:
|
||||
sent_words = len(sent.split())
|
||||
if current_len + sent_words > max_words and current:
|
||||
chunks.append(prefix + " ".join(current))
|
||||
current = []
|
||||
current_len = 0
|
||||
if sent_words > max_words:
|
||||
# Sentence too long — use sliding window
|
||||
s_words = sent.split()
|
||||
for i in range(0, len(s_words), max_words - 50):
|
||||
window = s_words[i:i + max_words]
|
||||
chunks.append(prefix + " ".join(window))
|
||||
else:
|
||||
current.append(sent)
|
||||
current_len += sent_words
|
||||
if current:
|
||||
chunks.append(prefix + " ".join(current))
|
||||
|
||||
if not chunks and title:
|
||||
chunks = [title]
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Embedding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def embed(texts, is_query=False):
|
||||
"""Encode texts into L2-normalized float32 embeddings (N, 384).
|
||||
|
||||
For queries, prepend the model's query prefix.
|
||||
Processes in batches of 32 to limit memory usage.
|
||||
"""
|
||||
if not texts:
|
||||
return np.empty((0, DIMS), dtype=np.float32)
|
||||
|
||||
session, tokenizer = _get_session()
|
||||
|
||||
if is_query:
|
||||
texts = [QUERY_PREFIX + t for t in texts]
|
||||
|
||||
batch_size = 32
|
||||
all_embeddings = []
|
||||
|
||||
for start in range(0, len(texts), batch_size):
|
||||
batch = texts[start:start + batch_size]
|
||||
encodings = tokenizer.encode_batch(batch)
|
||||
input_ids = np.array([e.ids for e in encodings], dtype=np.int64)
|
||||
attention_mask = np.array([e.attention_mask for e in encodings], dtype=np.int64)
|
||||
token_type_ids = np.zeros_like(input_ids)
|
||||
|
||||
outputs = session.run(
|
||||
None,
|
||||
{
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
},
|
||||
)
|
||||
# CLS token pooling — take the first token's hidden state
|
||||
emb = outputs[0][:, 0, :]
|
||||
all_embeddings.append(emb)
|
||||
|
||||
embeddings = np.concatenate(all_embeddings, axis=0)
|
||||
# L2 normalize
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms = np.maximum(norms, 1e-12)
|
||||
embeddings = embeddings / norms
|
||||
return embeddings.astype(np.float32)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HNSW index management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_index(db=None):
|
||||
"""Load all embeddings from chunks table and build HNSW index."""
|
||||
import hnswlib
|
||||
global _hnsw_index, _hnsw_ids
|
||||
|
||||
from db import get_db, return_db
|
||||
own_db = db is None
|
||||
if own_db:
|
||||
db = get_db()
|
||||
try:
|
||||
rows = db.execute("SELECT id, embedding FROM chunks ORDER BY id").fetchall()
|
||||
finally:
|
||||
if own_db:
|
||||
return_db(db)
|
||||
|
||||
with _hnsw_lock:
|
||||
if not rows:
|
||||
_hnsw_index = None
|
||||
_hnsw_ids = []
|
||||
return
|
||||
|
||||
n = len(rows)
|
||||
ids = [r["id"] for r in rows]
|
||||
matrix = np.frombuffer(b"".join(r["embedding"] for r in rows), dtype=np.float32).reshape(n, DIMS)
|
||||
|
||||
index = hnswlib.Index(space="cosine", dim=DIMS)
|
||||
# ef_construction and M balance build speed vs recall;
|
||||
# these defaults give >99% recall at reasonable build time
|
||||
index.init_index(max_elements=max(n, 1024), ef_construction=200, M=16)
|
||||
index.add_items(matrix, list(range(n)))
|
||||
index.set_ef(50) # query-time accuracy parameter
|
||||
|
||||
_hnsw_index = index
|
||||
_hnsw_ids = ids
|
||||
|
||||
|
||||
def _add_to_index(chunk_ids, embeddings_matrix):
|
||||
"""Add new embeddings to the live HNSW index."""
|
||||
import hnswlib
|
||||
global _hnsw_index, _hnsw_ids
|
||||
|
||||
with _hnsw_lock:
|
||||
if _hnsw_index is None:
|
||||
index = hnswlib.Index(space="cosine", dim=DIMS)
|
||||
index.init_index(max_elements=1024, ef_construction=200, M=16)
|
||||
index.set_ef(50)
|
||||
_hnsw_index = index
|
||||
_hnsw_ids = []
|
||||
|
||||
current_max = _hnsw_index.get_max_elements()
|
||||
needed = len(_hnsw_ids) + len(chunk_ids)
|
||||
if needed > current_max:
|
||||
_hnsw_index.resize_index(max(needed * 2, current_max * 2))
|
||||
|
||||
labels = list(range(len(_hnsw_ids), len(_hnsw_ids) + len(chunk_ids)))
|
||||
_hnsw_index.add_items(embeddings_matrix, labels)
|
||||
_hnsw_ids.extend(chunk_ids)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Store embeddings for pages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def store_embeddings(page_id, title, body, db):
|
||||
"""Chunk, embed, and store embeddings for a page. Adds to HNSW index."""
|
||||
chunks = chunk_text(title, body)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
embeddings_matrix = embed(chunks)
|
||||
|
||||
# Delete old chunks for this page
|
||||
db.execute("DELETE FROM chunks WHERE page_id = ?", (page_id,))
|
||||
|
||||
new_ids = []
|
||||
for i, (text, emb) in enumerate(zip(chunks, embeddings_matrix)):
|
||||
cursor = db.execute(
|
||||
"INSERT INTO chunks (page_id, remote_page_id, chunk_index, chunk_text, embedding) "
|
||||
"VALUES (?, NULL, ?, ?, ?)",
|
||||
(page_id, i, text, emb.tobytes()),
|
||||
)
|
||||
new_ids.append(cursor.lastrowid)
|
||||
db.commit()
|
||||
|
||||
_add_to_index(new_ids, embeddings_matrix)
|
||||
|
||||
|
||||
def store_remote_embeddings(remote_page_id, title, note, db):
|
||||
"""Store a single embedding for a remote page (title + note)."""
|
||||
text = f"{title}: {note}" if note else (title or "")
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
embeddings_matrix = embed([text])
|
||||
|
||||
db.execute("DELETE FROM chunks WHERE remote_page_id = ?", (remote_page_id,))
|
||||
cursor = db.execute(
|
||||
"INSERT INTO chunks (page_id, remote_page_id, chunk_index, chunk_text, embedding) "
|
||||
"VALUES (NULL, ?, 0, ?, ?)",
|
||||
(remote_page_id, text, embeddings_matrix[0].tobytes()),
|
||||
)
|
||||
db.commit()
|
||||
|
||||
_add_to_index([cursor.lastrowid], embeddings_matrix)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Search
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def semantic_search(query_text, limit=100, db=None):
|
||||
"""Search for pages by semantic similarity.
|
||||
|
||||
Returns: [(page_id, score, best_chunk_text), ...] sorted by score desc.
|
||||
Groups by page_id, taking the max chunk score per page.
|
||||
"""
|
||||
if _hnsw_index is None or not _hnsw_ids:
|
||||
return []
|
||||
|
||||
query_emb = embed([query_text], is_query=True)
|
||||
|
||||
with _hnsw_lock:
|
||||
if _hnsw_index is None or not _hnsw_ids:
|
||||
return []
|
||||
k = min(limit * 3, len(_hnsw_ids)) # oversample to account for grouping
|
||||
if k == 0:
|
||||
return []
|
||||
labels, distances = _hnsw_index.knn_query(query_emb, k=k)
|
||||
|
||||
# Map HNSW labels back to chunk IDs
|
||||
chunk_ids = [_hnsw_ids[int(lbl)] for lbl in labels[0]]
|
||||
# cosine distance -> similarity: hnswlib returns 1-cosine for "cosine" space
|
||||
scores = [1.0 - float(d) for d in distances[0]]
|
||||
|
||||
# Fetch chunk details from DB
|
||||
from db import get_db, return_db
|
||||
own_db = db is None
|
||||
if own_db:
|
||||
db = get_db()
|
||||
try:
|
||||
placeholders = ",".join("?" * len(chunk_ids))
|
||||
rows = db.execute(
|
||||
f"SELECT id, page_id, chunk_text FROM chunks WHERE id IN ({placeholders})",
|
||||
chunk_ids,
|
||||
).fetchall()
|
||||
finally:
|
||||
if own_db:
|
||||
return_db(db)
|
||||
|
||||
chunk_map = {r["id"]: r for r in rows}
|
||||
|
||||
# Group by page_id, keep best score and chunk text per page
|
||||
page_best = {} # page_id -> (score, chunk_text)
|
||||
for cid, score in zip(chunk_ids, scores):
|
||||
chunk = chunk_map.get(cid)
|
||||
if not chunk or chunk["page_id"] is None:
|
||||
continue
|
||||
pid = chunk["page_id"]
|
||||
if pid not in page_best or score > page_best[pid][0]:
|
||||
page_best[pid] = (score, chunk["chunk_text"])
|
||||
|
||||
results = [(pid, score, text) for pid, (score, text) in page_best.items()]
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
return results[:limit]
|
||||
|
||||
|
||||
def hybrid_search(query_text, bm25_ranked_ids, limit=10, db=None, use_reranker=False):
|
||||
"""Merge BM25 and semantic results via RRF, optionally rerank with cross-encoder.
|
||||
|
||||
Default (two-stage): BM25 + semantic fused via RRF.
|
||||
With use_reranker=True (three-stage): rerank top 20 with cross-encoder.
|
||||
|
||||
Returns: [(page_id, best_chunk_text), ...] in ranked order.
|
||||
"""
|
||||
k = 60 # RRF constant
|
||||
|
||||
sem_results = semantic_search(query_text, limit=100, db=db)
|
||||
|
||||
best_chunks = {} # page_id -> chunk_text
|
||||
for _rank, (pid, _score, chunk_text) in enumerate(sem_results):
|
||||
if pid not in best_chunks:
|
||||
best_chunks[pid] = chunk_text
|
||||
|
||||
# When BM25 has no hits, use raw semantic similarity scores directly
|
||||
# (RRF rank positions distort nearly-equal scores)
|
||||
if not bm25_ranked_ids:
|
||||
fused_ids = [(pid, score) for pid, score, _ in sem_results]
|
||||
else:
|
||||
rrf_scores = {}
|
||||
for rank, pid in enumerate(bm25_ranked_ids):
|
||||
rrf_scores[pid] = rrf_scores.get(pid, 0) + 1.0 / (k + rank + 1)
|
||||
for rank, (pid, _score, chunk_text) in enumerate(sem_results):
|
||||
rrf_scores[pid] = rrf_scores.get(pid, 0) + 1.0 / (k + rank + 1)
|
||||
fused_ids = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
fused = fused_ids
|
||||
all_ids = [pid for pid, _ in fused]
|
||||
|
||||
if not all_ids:
|
||||
return []
|
||||
|
||||
if not use_reranker:
|
||||
return [(pid, best_chunks.get(pid, "")) for pid in all_ids[:limit]]
|
||||
|
||||
# --- Rerank top 20, append next 10 from RRF order ---
|
||||
rerank_ids = all_ids[:20]
|
||||
tail_ids = all_ids[20:30]
|
||||
|
||||
from db import get_db, return_db
|
||||
own_db = db is None
|
||||
if own_db:
|
||||
db = get_db()
|
||||
try:
|
||||
placeholders = ",".join("?" * len(rerank_ids))
|
||||
rows = db.execute(
|
||||
f"SELECT id, title, body FROM pages WHERE id IN ({placeholders})",
|
||||
rerank_ids,
|
||||
).fetchall()
|
||||
finally:
|
||||
if own_db:
|
||||
return_db(db)
|
||||
|
||||
page_map = {r["id"]: r for r in rows}
|
||||
|
||||
doc_texts = []
|
||||
ordered_ids = []
|
||||
for pid in rerank_ids:
|
||||
page = page_map.get(pid)
|
||||
if not page:
|
||||
continue
|
||||
chunk = best_chunks.get(pid, "")
|
||||
body_preview = chunk[:200] if chunk else page["body"][:200]
|
||||
doc = f"{page['title']}. {body_preview}"
|
||||
doc_texts.append(doc)
|
||||
ordered_ids.append(pid)
|
||||
|
||||
if not doc_texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
reranked = rerank(query_text, doc_texts, limit=20)
|
||||
results = [(ordered_ids[idx], best_chunks.get(ordered_ids[idx], "")) for idx, _score in reranked]
|
||||
except Exception:
|
||||
results = [(pid, best_chunks.get(pid, "")) for pid in ordered_ids[:20]]
|
||||
|
||||
# Append next 10 from RRF order (no reranking)
|
||||
reranked_set = {pid for pid, _ in results}
|
||||
for pid in tail_ids:
|
||||
if pid not in reranked_set:
|
||||
results.append((pid, best_chunks.get(pid, "")))
|
||||
|
||||
return results[:30]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reindex
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def reindex_all(db=None, progress_callback=None):
|
||||
"""Embed all pages that don't yet have chunks. Also generates missing summaries. Rebuilds HNSW index."""
|
||||
from db import get_db, return_db, _generate_summary
|
||||
own_db = db is None
|
||||
if own_db:
|
||||
db = get_db()
|
||||
try:
|
||||
rows = db.execute(
|
||||
"SELECT p.id, p.title, p.body, p.summary FROM pages p "
|
||||
"WHERE p.id NOT IN (SELECT DISTINCT page_id FROM chunks WHERE page_id IS NOT NULL)"
|
||||
).fetchall()
|
||||
|
||||
total = len(rows)
|
||||
for i, row in enumerate(rows):
|
||||
store_embeddings(row["id"], row["title"], row["body"], db)
|
||||
# Generate summary if missing
|
||||
if not row["summary"]:
|
||||
summary = _generate_summary(row["title"], row["body"])
|
||||
db.execute("UPDATE pages SET summary = ? WHERE id = ?", (summary, row["id"]))
|
||||
db.commit()
|
||||
if progress_callback:
|
||||
progress_callback(i + 1, total)
|
||||
|
||||
# Generate summaries for pages that already have chunks but no summary
|
||||
no_summary = db.execute(
|
||||
"SELECT id, title, body FROM pages WHERE summary = '' OR summary IS NULL"
|
||||
).fetchall()
|
||||
for row in no_summary:
|
||||
summary = _generate_summary(row["title"], row["body"])
|
||||
db.execute("UPDATE pages SET summary = ? WHERE id = ?", (summary, row["id"]))
|
||||
if no_summary:
|
||||
db.commit()
|
||||
|
||||
# Also handle remote pages
|
||||
remote_rows = db.execute(
|
||||
"SELECT rp.id, rp.title, rp.note FROM remote_pages rp "
|
||||
"WHERE rp.id NOT IN (SELECT DISTINCT remote_page_id FROM chunks WHERE remote_page_id IS NOT NULL)"
|
||||
).fetchall()
|
||||
|
||||
for rp in remote_rows:
|
||||
store_remote_embeddings(rp["id"], rp["title"], rp["note"], db)
|
||||
finally:
|
||||
if own_db:
|
||||
return_db(db)
|
||||
|
||||
build_index(db)
|
||||
169
handlers.py
169
handlers.py
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import re
|
||||
import secrets
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
|
@ -27,10 +28,41 @@ def _check_csrf(body):
|
|||
return secrets.compare_digest(token, expected)
|
||||
|
||||
|
||||
_STOPWORDS = frozenset({
|
||||
"a", "an", "the", "and", "or", "but", "is", "are", "was", "were",
|
||||
"in", "on", "at", "to", "for", "of", "with", "by", "from", "as",
|
||||
"into", "about", "how", "what", "which", "who", "where", "when",
|
||||
"do", "does", "did", "be", "been", "being", "have", "has", "had",
|
||||
"it", "its", "this", "that", "not", "no", "so", "if", "can", "will",
|
||||
"my", "your", "i", "me", "we", "you", "he", "she", "they",
|
||||
})
|
||||
|
||||
|
||||
def _sanitize_fts_query(query):
|
||||
"""Escape user input for safe use in FTS5 MATCH."""
|
||||
escaped = query.replace('"', '""')
|
||||
return f'"{escaped}"'
|
||||
"""Escape user input for safe use in FTS5 MATCH.
|
||||
|
||||
Splits into individual quoted tokens joined by implicit AND,
|
||||
so all words must appear but in any order. Appends * to the
|
||||
last token for prefix matching. Stopwords are dropped to avoid
|
||||
overly strict matching.
|
||||
"""
|
||||
words = query.split()
|
||||
if not words:
|
||||
return '""'
|
||||
tokens = []
|
||||
for i, w in enumerate(words):
|
||||
# Strip FTS5 special characters to prevent injection
|
||||
cleaned = re.sub(r'["\'\(\)\*\+\-\^~]', '', w).strip()
|
||||
if not cleaned:
|
||||
continue
|
||||
if cleaned.lower() in _STOPWORDS:
|
||||
continue
|
||||
if i == len(words) - 1:
|
||||
# Prefix match on the last token for partial word matching
|
||||
tokens.append(f"{cleaned}*")
|
||||
else:
|
||||
tokens.append(f'"{cleaned}"')
|
||||
return " ".join(tokens) if tokens else '""'
|
||||
|
||||
|
||||
def _get_bookmark_token():
|
||||
|
|
@ -155,20 +187,46 @@ def handle_search(query):
|
|||
result_html = ""
|
||||
trusted_html = ""
|
||||
if q:
|
||||
# BM25 keyword search with column weights: title=10, body=1, url=5, note=3
|
||||
try:
|
||||
total_results = db.execute(
|
||||
"SELECT count(*) FROM pages_fts WHERE pages_fts MATCH ?",
|
||||
(_sanitize_fts_query(q),),
|
||||
).fetchone()[0]
|
||||
rows = db.execute(
|
||||
fts_q = _sanitize_fts_query(q)
|
||||
bm25_rows = db.execute(
|
||||
"SELECT p.id, p.url, p.title, p.body, p.note "
|
||||
"FROM pages_fts f JOIN pages p ON f.rowid = p.id "
|
||||
"WHERE pages_fts MATCH ? ORDER BY rank LIMIT ? OFFSET ?",
|
||||
(_sanitize_fts_query(q), PER_PAGE, offset),
|
||||
"WHERE pages_fts MATCH ? "
|
||||
"ORDER BY bm25(pages_fts, 10.0, 1.0, 5.0, 3.0) LIMIT 100",
|
||||
(fts_q,),
|
||||
).fetchall()
|
||||
except Exception:
|
||||
bm25_rows = []
|
||||
|
||||
# Hybrid search: merge BM25 + semantic via RRF
|
||||
bm25_ids = [r["id"] for r in bm25_rows]
|
||||
chunk_snippets = {} # page_id -> best chunk text
|
||||
try:
|
||||
from embeddings import hybrid_search
|
||||
use_reranker = get_setting("use_reranker", "1") == "1"
|
||||
fused = hybrid_search(q, bm25_ids, limit=100, db=db, use_reranker=use_reranker)
|
||||
fused_ids = [pid for pid, _ in fused]
|
||||
chunk_snippets = {pid: text for pid, text in fused if text}
|
||||
except Exception:
|
||||
fused_ids = bm25_ids
|
||||
|
||||
total_results = len(fused_ids)
|
||||
page_ids = fused_ids[offset:offset + PER_PAGE]
|
||||
|
||||
if page_ids:
|
||||
# Fetch rows in fused order
|
||||
placeholders = ",".join("?" * len(page_ids))
|
||||
all_rows = db.execute(
|
||||
f"SELECT id, url, title, body, note, summary FROM pages WHERE id IN ({placeholders})",
|
||||
page_ids,
|
||||
).fetchall()
|
||||
row_map = {r["id"]: r for r in all_rows}
|
||||
rows = [row_map[pid] for pid in page_ids if pid in row_map]
|
||||
else:
|
||||
rows = []
|
||||
total_results = 0
|
||||
|
||||
if rows:
|
||||
for r in rows:
|
||||
note_html = ""
|
||||
|
|
@ -179,11 +237,13 @@ def handle_search(query):
|
|||
if tags:
|
||||
tag_links = " ".join(f'<a href="/tags/{esc(t)}" class="tag">[{esc(t)}]</a>' for t in tags)
|
||||
tags_html = f'<div class="tags">{tag_links}</div>'
|
||||
# Use page summary as snippet (meta description or centroid sentence)
|
||||
snip = r["summary"] if r["summary"] else snippet(r["body"], q)
|
||||
result_html += (
|
||||
f'<div class="result">'
|
||||
f'<a href="{esc(r["url"])}">{esc(r["title"])}</a><br>'
|
||||
f'<small>{esc(r["url"])}</small><br>'
|
||||
f'{esc(snippet(r["body"], q))}'
|
||||
f'{esc(snip)}'
|
||||
f'{note_html}{tags_html}'
|
||||
f'</div>'
|
||||
)
|
||||
|
|
@ -495,6 +555,8 @@ def handle_style_form(msg=""):
|
|||
name = get_site_name()
|
||||
sharing = get_setting("sharing_enabled", "0")
|
||||
checked = " checked" if sharing == "1" else ""
|
||||
reranker = get_setting("use_reranker", "1")
|
||||
reranker_checked = " checked" if reranker == "1" else ""
|
||||
return _respond(
|
||||
f"<h1>customize</h1>"
|
||||
f"<h2>name your search engine</h2>"
|
||||
|
|
@ -504,6 +566,10 @@ def handle_style_form(msg=""):
|
|||
f"<h2>sharing</h2>"
|
||||
f'<label><input type="checkbox" name="sharing_enabled" value="1"{checked}>'
|
||||
f" share your site list publicly at /api/sites</label><br><br>"
|
||||
f"<h2>search</h2>"
|
||||
f'<label><input type="checkbox" name="use_reranker" value="1"{reranker_checked}>'
|
||||
f" cross-encoder reranking (more accurate, on by default)</label><br>"
|
||||
f"<small>Uses a 22MB model. Adds ~50ms per search. Disable for faster results.</small><br><br>"
|
||||
f"<h2>custom html</h2>"
|
||||
f"<p>Edit the full page template. Use <code>{esc('{{content}}')}</code> "
|
||||
f"where page content should appear.</p>"
|
||||
|
|
@ -528,9 +594,11 @@ def handle_style_submit(body):
|
|||
template = body.get("template", [""])[0].replace("\r\n", "\n").replace("\r", "\n")
|
||||
name = body.get("site_name", ["tinyweb"])[0].strip()
|
||||
sharing = "1" if body.get("sharing_enabled") else "0"
|
||||
reranker = "1" if body.get("use_reranker") else "0"
|
||||
set_setting("custom_template", template if template.strip() != DEFAULT_TEMPLATE.strip() else "")
|
||||
set_setting("site_name", name or "tinyweb")
|
||||
set_setting("sharing_enabled", sharing)
|
||||
set_setting("use_reranker", reranker)
|
||||
return handle_style_form("Saved.")
|
||||
|
||||
|
||||
|
|
@ -904,6 +972,16 @@ def handle_subscription_sync(sub_id):
|
|||
"ON CONFLICT(subscription_id, url) DO UPDATE SET title=excluded.title, note=excluded.note, tags=excluded.tags",
|
||||
(sub_id, s["url"], s["title"], s.get("note", ""), tags_str),
|
||||
)
|
||||
# Embed remote page for semantic search
|
||||
try:
|
||||
from embeddings import store_remote_embeddings
|
||||
rp_id = db.execute(
|
||||
"SELECT id FROM remote_pages WHERE subscription_id = ? AND url = ?",
|
||||
(sub_id, s["url"]),
|
||||
).fetchone()["id"]
|
||||
store_remote_embeddings(rp_id, s["title"], s.get("note", ""), db)
|
||||
except Exception:
|
||||
pass
|
||||
synced += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -970,6 +1048,15 @@ def handle_subscription_syncall():
|
|||
"ON CONFLICT(subscription_id, url) DO UPDATE SET title=excluded.title, note=excluded.note, tags=excluded.tags",
|
||||
(sub["id"], s["url"], s["title"], s.get("note", ""), tags_str),
|
||||
)
|
||||
try:
|
||||
from embeddings import store_remote_embeddings
|
||||
rp_id = db.execute(
|
||||
"SELECT id FROM remote_pages WHERE subscription_id = ? AND url = ?",
|
||||
(sub["id"], s["url"]),
|
||||
).fetchone()["id"]
|
||||
store_remote_embeddings(rp_id, s["title"], s.get("note", ""), db)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
now = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
|
||||
|
|
@ -983,6 +1070,60 @@ def handle_subscription_syncall():
|
|||
return handle_subscriptions(f"Synced {total} subscription(s).")
|
||||
|
||||
|
||||
# --- Reindex (semantic search) ---
|
||||
|
||||
|
||||
_reindex_thread = None
|
||||
|
||||
|
||||
def handle_reindex_form():
|
||||
db = get_db()
|
||||
try:
|
||||
total_pages = db.execute("SELECT count(*) FROM pages").fetchone()[0]
|
||||
pages_with_chunks = db.execute(
|
||||
"SELECT count(DISTINCT page_id) FROM chunks WHERE page_id IS NOT NULL"
|
||||
).fetchone()[0]
|
||||
finally:
|
||||
return_db(db)
|
||||
progress = get_setting("reindex_progress", "")
|
||||
status_html = ""
|
||||
if progress:
|
||||
status_html = f'<p class="meta">Reindex in progress: {esc(progress)}</p>'
|
||||
elif _reindex_thread and _reindex_thread.is_alive():
|
||||
status_html = '<p class="meta">Reindex running...</p>'
|
||||
return _respond(
|
||||
f"<h2>semantic search index</h2>"
|
||||
f"<p>{pages_with_chunks} of {total_pages} pages have embeddings.</p>"
|
||||
f'{status_html}'
|
||||
f'<form method="post" action="/reindex">'
|
||||
f'{_csrf_field()}'
|
||||
f'<button type="submit">reindex all pages</button>'
|
||||
f'</form>'
|
||||
f'<p><a href="/">back to search</a></p>'
|
||||
)
|
||||
|
||||
|
||||
def handle_reindex_submit(body):
|
||||
global _reindex_thread
|
||||
if _reindex_thread and _reindex_thread.is_alive():
|
||||
return handle_reindex_form()
|
||||
|
||||
def _run():
|
||||
try:
|
||||
from embeddings import reindex_all
|
||||
def progress(current, total):
|
||||
set_setting("reindex_progress", f"{current}/{total}")
|
||||
reindex_all(progress_callback=progress)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
set_setting("reindex_progress", "")
|
||||
|
||||
_reindex_thread = threading.Thread(target=_run, daemon=True)
|
||||
_reindex_thread.start()
|
||||
return _redirect("/reindex")
|
||||
|
||||
|
||||
# --- Dispatcher ---
|
||||
|
||||
|
||||
|
|
@ -1027,6 +1168,8 @@ def _dispatch_inner(data):
|
|||
elif path.startswith("/tags/"):
|
||||
tag_name = unquote(path[len("/tags/"):])
|
||||
return handle_tag_browse(tag_name, query) if tag_name else _error(400)
|
||||
elif path == "/reindex":
|
||||
return handle_reindex_form()
|
||||
elif path == "/api/sites":
|
||||
return handle_api_sites(query)
|
||||
elif path == "/subscriptions":
|
||||
|
|
@ -1052,6 +1195,8 @@ def _dispatch_inner(data):
|
|||
return handle_style_form("Template reset to default.")
|
||||
elif path == "/import":
|
||||
return handle_import_submit(body)
|
||||
elif path == "/reindex":
|
||||
return handle_reindex_submit(body)
|
||||
elif path == "/subscriptions/add":
|
||||
return handle_subscription_add(body)
|
||||
elif path == "/subscriptions/pick":
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
requests
|
||||
beautifulsoup4
|
||||
rns
|
||||
onnxruntime
|
||||
tokenizers
|
||||
hnswlib
|
||||
numpy
|
||||
huggingface_hub
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue