tinyweb/embeddings.py
lichenblankie 5ded9f1339 added hybrid semantic search with reranking
Implements a three-stage search pipeline:
1. BM25 keyword search via FTS5 with column weights
2. Semantic search via Snowflake arctic-embed-s bi-encoder + HNSW index
3. Optional cross-encoder reranking (on by default, toggleable in settings)

Top 20 results are reranked for precision; the next 10 are appended in RRF
order for coverage, giving 30 total results across 3 pages.

- New embeddings.py with ONNX Runtime inference, text chunking, HNSW
  index management, RRF fusion, and cross-encoder reranking
- Meta description extraction for authentic page snippets, with a
  centroid-based extractive fallback
- Stopword filtering in FTS5 queries to avoid overly strict matching
- /reindex page for batch embedding of existing pages
- Semantic embedding of remote pages during subscription sync
- ~125MB dependency footprint (onnxruntime, tokenizers, hnswlib, numpy)
- Models: 34MB bi-encoder + 22MB cross-encoder (downloaded on first use)
2026-06-05 05:29:35 +00:00

553 lines
19 KiB
Python

"""Semantic search using Snowflake arctic-embed-s via ONNX Runtime + hnswlib."""
import os
import re
import threading
import numpy as np
# Bi-encoder used for chunk and query embeddings (fetched by _ensure_model).
MODEL_ID = "Snowflake/snowflake-arctic-embed-s"
_HERE = os.path.dirname(__file__)
MODEL_DIR = os.path.join(_HERE, "models", "snowflake-arctic-embed-s")
# Cross-encoder used by rerank(); _get_reranker() returns (None, None) while
# its files are absent from this directory.
RERANKER_DIR = os.path.join(_HERE, "models", "cross-encoder")
HNSW_PATH = os.path.join(_HERE, "index.hnsw")
DIMS = 384  # embedding width; stored chunk blobs are reshaped to (n, DIMS)
MAX_TOKENS = 512  # tokenizer truncation length, also chunk_text's word budget
QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

# Lazily loaded bi-encoder (see _get_session), loaded under _lock.
_session = _tokenizer = None
_lock = threading.Lock()

# Lazily loaded cross-encoder (see _get_reranker), loaded under _reranker_lock.
_reranker_session = _reranker_tokenizer = None
_reranker_lock = threading.Lock()

# Live HNSW index: HNSW label i corresponds to chunks.id _hnsw_ids[i].
_hnsw_index = None
_hnsw_ids = []
_hnsw_lock = threading.Lock()
# ---------------------------------------------------------------------------
# Model download & loading
# ---------------------------------------------------------------------------
def _ensure_model():
    """Download the bi-encoder ONNX model and tokenizer files into MODEL_DIR if missing.

    Files come from the HuggingFace Hub repo MODEL_ID; the quantized ONNX
    export is stored locally as ``model.onnx``. Files that already exist are
    skipped. Each download is copied to a temporary ``.part`` file inside
    MODEL_DIR and then renamed into place. An interrupted copy therefore never
    leaves a truncated ``model.onnx`` that the existence check below would
    accept on the next start.
    """
    model_path = os.path.join(MODEL_DIR, "model.onnx")
    tokenizer_path = os.path.join(MODEL_DIR, "tokenizer.json")
    if os.path.exists(model_path) and os.path.exists(tokenizer_path):
        return
    import shutil
    from huggingface_hub import hf_hub_download
    os.makedirs(MODEL_DIR, exist_ok=True)
    files = {
        "onnx/model_quantized.onnx": "model.onnx",
        "tokenizer.json": "tokenizer.json",
        "tokenizer_config.json": "tokenizer_config.json",
    }
    for remote, local in files.items():
        target = os.path.join(MODEL_DIR, local)
        if os.path.exists(target):
            continue
        # hf_hub_download returns the path of the file in the HF cache.
        cached = hf_hub_download(repo_id=MODEL_ID, filename=remote)
        # The temp file lives next to the target, so os.replace stays on one
        # filesystem and the final rename is atomic.
        partial = target + ".part"
        shutil.copy2(cached, partial)
        os.replace(partial, target)
def _get_session():
    """Return the (onnxruntime.InferenceSession, tokenizers.Tokenizer) pair for the bi-encoder.

    On first call, ensures the model files exist (via _ensure_model) and then
    loads them. Later calls return the cached pair. An unlocked check,
    followed by a re-check under _lock, ensures only one thread performs the
    load.

    The tokenizer truncates to MAX_TOKENS and pads each batch with id 0 to
    the longest sequence.
    """
    global _session, _tokenizer
    if _session is not None:
        return _session, _tokenizer
    with _lock:
        if _session is not None:
            return _session, _tokenizer
        _ensure_model()
        import onnxruntime as ort
        from tokenizers import Tokenizer
        session = ort.InferenceSession(
            os.path.join(MODEL_DIR, "model.onnx"),
            providers=["CPUExecutionProvider"],
        )
        tokenizer = Tokenizer.from_file(os.path.join(MODEL_DIR, "tokenizer.json"))
        tokenizer.enable_truncation(max_length=MAX_TOKENS)
        tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=None)
        # Publish the fully configured tokenizer BEFORE the session. The
        # unlocked fast path above only tests `_session`, so setting it first
        # would let another thread return a None or unpadded tokenizer.
        _tokenizer = tokenizer
        _session = session
        return _session, _tokenizer
def _get_reranker():
    """Return the (onnxruntime.InferenceSession, tokenizers.Tokenizer) pair for the cross-encoder.

    Loads ``model.onnx`` and ``tokenizer.json`` from RERANKER_DIR on the first
    successful call and caches the pair. Returns (None, None) while either
    file is missing. Nothing is downloaded here, and the check is repeated on
    each later call until the files appear.
    """
    global _reranker_session, _reranker_tokenizer
    if _reranker_session is not None:
        return _reranker_session, _reranker_tokenizer
    with _reranker_lock:
        if _reranker_session is not None:
            return _reranker_session, _reranker_tokenizer
        model_path = os.path.join(RERANKER_DIR, "model.onnx")
        tok_path = os.path.join(RERANKER_DIR, "tokenizer.json")
        if not os.path.exists(model_path) or not os.path.exists(tok_path):
            return None, None
        import onnxruntime as ort
        from tokenizers import Tokenizer
        session = ort.InferenceSession(
            model_path, providers=["CPUExecutionProvider"],
        )
        tokenizer = Tokenizer.from_file(tok_path)
        tokenizer.enable_truncation(max_length=512)
        tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=None)
        # Publish the tokenizer first: the unlocked fast path above only tests
        # the session, so it must never see a session without a ready tokenizer.
        _reranker_tokenizer = tokenizer
        _reranker_session = session
        return _reranker_session, _reranker_tokenizer
def rerank(query, documents, limit=10):
    """Score query-document pairs with the cross-encoder and return reranked indices.

    Args:
        query: search query string.
        documents: list of document texts to score against the query.
        limit: max results to return.

    Returns: list of (original_index, score) sorted by score descending.
    An empty ``documents`` list yields []. If the reranker model is not
    available, the first ``limit`` indices are returned in their original
    order with score 0.0.
    """
    if not documents:
        # Nothing to score; also avoids feeding a 1-D empty array to the model.
        return []
    session, tokenizer = _get_reranker()
    if session is None:
        return [(i, 0.0) for i in range(min(limit, len(documents)))]
    # The cross-encoder scores (query, document) pairs, so encode as pair sequences.
    pairs = [[query, doc] for doc in documents]
    encodings = tokenizer.encode_batch(pairs)
    feeds = {
        "input_ids": np.array([e.ids for e in encodings], dtype=np.int64),
        "attention_mask": np.array([e.attention_mask for e in encodings], dtype=np.int64),
        "token_type_ids": np.array([e.type_ids for e in encodings], dtype=np.int64),
    }
    # Not every cross-encoder export declares token_type_ids, and ONNX Runtime
    # rejects feeds for inputs the graph does not have. Pass only declared inputs.
    declared = {inp.name for inp in session.get_inputs()}
    feeds = {name: arr for name, arr in feeds.items() if name in declared}
    outputs = session.run(None, feeds)
    # Output is logits: higher = more relevant.
    scores = outputs[0].flatten()
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(i, float(scores[i])) for i in ranked[:limit]]
# ---------------------------------------------------------------------------
# Text chunking
# ---------------------------------------------------------------------------
_SENTENCE_RE = re.compile(r'(?<=[.!?])\s+')


def chunk_text(title, body, max_tokens=None):
    """Split a page body into chunks, each prefixed with the title for context.

    Strategy: split on blank lines (paragraphs), skipping paragraphs shorter
    than 20 characters. A paragraph whose word count exceeds the budget is
    split at sentence boundaries, and sentences are packed greedily. A single
    sentence over the budget is cut into overlapping sliding windows.

    Args:
        title: page title; every chunk is prefixed with "<title>: ".
        body: page body text.
        max_tokens: word budget per chunk including the title prefix; defaults
            to MAX_TOKENS. Words only approximate tokenizer tokens.

    Returns: list of chunk strings. If the body is empty or yields no chunk,
    returns [title] when there is a title, otherwise [].
    """
    if max_tokens is None:
        max_tokens = MAX_TOKENS
    if not body or not body.strip():
        return [f"{title}"] if title else []
    prefix = f"{title}: " if title else ""
    # Rough word budget for the chunk body, leaving room for the prefix
    # (approximate; the tokenizer may count differently).
    max_words = max_tokens - len(prefix.split())
    if max_words < 1:
        # The title alone fills the budget, so no body text can fit.
        return [title] if title else []
    # Over-long sentences are cut into windows that overlap by up to 50 words.
    # Capping the overlap at half a window keeps the step >= 1; an uncapped
    # 50-word overlap gave range() a zero or negative step for small budgets.
    overlap = min(50, max_words // 2)
    step = max_words - overlap
    paragraphs = re.split(r'\n\s*\n', body.strip())
    chunks = []
    for para in paragraphs:
        para = para.strip()
        if len(para) < 20:
            continue
        if len(para.split()) <= max_words:
            chunks.append(prefix + para)
            continue
        # Split the paragraph into sentences and pack them up to max_words.
        current = []
        current_len = 0
        for sent in _SENTENCE_RE.split(para):
            sent_words = sent.split()
            n_words = len(sent_words)
            if current and current_len + n_words > max_words:
                chunks.append(prefix + " ".join(current))
                current = []
                current_len = 0
            if n_words > max_words:
                # Sentence too long on its own: sliding window.
                for start in range(0, n_words, step):
                    chunks.append(prefix + " ".join(sent_words[start:start + max_words]))
                    if start + max_words >= n_words:
                        break  # window reached the end; further ones would be redundant
            else:
                current.append(sent)
                current_len += n_words
        if current:
            chunks.append(prefix + " ".join(current))
    if not chunks and title:
        chunks = [title]
    return chunks
# ---------------------------------------------------------------------------
# Embedding
# ---------------------------------------------------------------------------
def embed(texts, is_query=False):
    """Return L2-normalized float32 embeddings of shape (len(texts), DIMS).

    Query-side texts get QUERY_PREFIX prepended. Texts are run through the
    model 32 at a time to bound memory use. Each sequence is represented by
    the hidden state of its first token (CLS pooling).
    """
    if not texts:
        return np.empty((0, DIMS), dtype=np.float32)
    session, tokenizer = _get_session()
    inputs = [QUERY_PREFIX + t for t in texts] if is_query else texts
    step = 32
    pooled = []
    for offset in range(0, len(inputs), step):
        encoded = tokenizer.encode_batch(inputs[offset:offset + step])
        ids = np.array([enc.ids for enc in encoded], dtype=np.int64)
        mask = np.array([enc.attention_mask for enc in encoded], dtype=np.int64)
        feed = {
            "input_ids": ids,
            "attention_mask": mask,
            "token_type_ids": np.zeros_like(ids),
        }
        hidden = session.run(None, feed)[0]
        pooled.append(hidden[:, 0, :])  # first-token (CLS) pooling
    stacked = np.concatenate(pooled, axis=0)
    # Unit-normalize each row; the floor avoids division by zero.
    lengths = np.maximum(np.linalg.norm(stacked, axis=1, keepdims=True), 1e-12)
    return (stacked / lengths).astype(np.float32)
# ---------------------------------------------------------------------------
# HNSW index management
# ---------------------------------------------------------------------------
def build_index(db=None):
    """Rebuild the in-memory HNSW index from every row of the chunks table.

    Chunk ids are read in ascending order and assigned HNSW labels 0..n-1, so
    label i refers to chunks.id _hnsw_ids[i]. An empty table clears the
    index. When ``db`` is None, a connection is obtained with get_db() and
    handed back with return_db() after the read.
    """
    import hnswlib
    global _hnsw_index, _hnsw_ids
    from db import get_db, return_db
    conn = get_db() if db is None else db
    try:
        rows = conn.execute("SELECT id, embedding FROM chunks ORDER BY id").fetchall()
    finally:
        if db is None:
            return_db(conn)
    with _hnsw_lock:
        if not rows:
            _hnsw_index, _hnsw_ids = None, []
            return
        count = len(rows)
        chunk_ids = [row["id"] for row in rows]
        blob = b"".join(row["embedding"] for row in rows)
        vectors = np.frombuffer(blob, dtype=np.float32).reshape(count, DIMS)
        fresh = hnswlib.Index(space="cosine", dim=DIMS)
        # ef_construction and M trade build time against recall.
        fresh.init_index(max_elements=max(count, 1024), ef_construction=200, M=16)
        fresh.add_items(vectors, list(range(count)))
        fresh.set_ef(50)  # query-time accuracy parameter
        _hnsw_index, _hnsw_ids = fresh, chunk_ids
def _add_to_index(chunk_ids, embeddings_matrix):
    """Append newly stored chunk embeddings to the live HNSW index.

    Creates an empty index (capacity 1024) on first use. When the new rows
    would not fit, capacity is grown to twice the larger of the required
    size and the current capacity. Labels are assigned sequentially, so each
    new label i maps to chunks.id _hnsw_ids[i]. Assumes row j of
    ``embeddings_matrix`` belongs to ``chunk_ids[j]``.
    """
    import hnswlib
    global _hnsw_index, _hnsw_ids
    with _hnsw_lock:
        if _hnsw_index is None:
            fresh = hnswlib.Index(space="cosine", dim=DIMS)
            fresh.init_index(max_elements=1024, ef_construction=200, M=16)
            fresh.set_ef(50)
            _hnsw_index, _hnsw_ids = fresh, []
        first_label = len(_hnsw_ids)
        required = first_label + len(chunk_ids)
        capacity = _hnsw_index.get_max_elements()
        if required > capacity:
            _hnsw_index.resize_index(2 * max(required, capacity))
        _hnsw_index.add_items(embeddings_matrix, list(range(first_label, required)))
        _hnsw_ids.extend(chunk_ids)
# ---------------------------------------------------------------------------
# Store embeddings for pages
# ---------------------------------------------------------------------------
def store_embeddings(page_id, title, body, db):
    """Chunk, embed, and store embeddings for a local page, replacing old ones.

    The new chunks are embedded before the page's existing rows are deleted,
    so an embedding failure leaves the previous chunks intact. New rows are
    committed and then appended to the live HNSW index. Labels of the deleted
    rows are not removed from the index here.

    If the page no longer yields any chunk text, its old chunks are still
    deleted (and committed) so they stop matching outdated content.
    """
    chunks = chunk_text(title, body)
    if not chunks:
        db.execute("DELETE FROM chunks WHERE page_id = ?", (page_id,))
        db.commit()
        return
    embeddings_matrix = embed(chunks)
    # Replace this page's old chunks with the new ones.
    db.execute("DELETE FROM chunks WHERE page_id = ?", (page_id,))
    new_ids = []
    for i, (text, emb) in enumerate(zip(chunks, embeddings_matrix)):
        cursor = db.execute(
            "INSERT INTO chunks (page_id, remote_page_id, chunk_index, chunk_text, embedding) "
            "VALUES (?, NULL, ?, ?, ?)",
            (page_id, i, text, emb.tobytes()),
        )
        new_ids.append(cursor.lastrowid)
    db.commit()
    _add_to_index(new_ids, embeddings_matrix)
def store_remote_embeddings(remote_page_id, title, note, db):
    """Store a single embedding for a remote page, built from its title and note.

    The embedded text is "<title>: <note>" when both are present, otherwise
    whichever one is non-empty. A missing (None) title or note is treated as
    empty, so it never appears as "None" in the text. Any previous chunk for
    the page is replaced. If there is no content at all, the old chunk is
    deleted and nothing new is stored.
    """
    title = title or ""
    note = note or ""
    text = f"{title}: {note}" if title and note else (title or note)
    if not text.strip():
        db.execute("DELETE FROM chunks WHERE remote_page_id = ?", (remote_page_id,))
        db.commit()
        return
    embeddings_matrix = embed([text])
    db.execute("DELETE FROM chunks WHERE remote_page_id = ?", (remote_page_id,))
    cursor = db.execute(
        "INSERT INTO chunks (page_id, remote_page_id, chunk_index, chunk_text, embedding) "
        "VALUES (NULL, ?, 0, ?, ?)",
        (remote_page_id, text, embeddings_matrix[0].tobytes()),
    )
    db.commit()
    _add_to_index([cursor.lastrowid], embeddings_matrix)
# ---------------------------------------------------------------------------
# Search
# ---------------------------------------------------------------------------
def semantic_search(query_text, limit=100, db=None):
    """Rank local pages by embedding similarity to ``query_text``.

    Returns: [(page_id, score, best_chunk_text), ...] sorted by score
    descending, at most ``limit`` entries. Each page appears once, with its
    highest-scoring chunk. Hits whose chunk row has no page_id, or is no
    longer in the chunks table, are skipped.
    """
    if _hnsw_index is None or not _hnsw_ids:
        return []
    query_vec = embed([query_text], is_query=True)
    with _hnsw_lock:
        if _hnsw_index is None or not _hnsw_ids:
            return []
        # Oversample: several hits may collapse into the same page.
        top_k = min(limit * 3, len(_hnsw_ids))
        if top_k == 0:
            return []
        labels, distances = _hnsw_index.knn_query(query_vec, k=top_k)
        hit_ids = [_hnsw_ids[int(label)] for label in labels[0]]
    # hnswlib's "cosine" space reports 1 - cosine similarity.
    hit_scores = [1.0 - float(dist) for dist in distances[0]]

    from db import get_db, return_db
    conn = get_db() if db is None else db
    try:
        marks = ",".join("?" * len(hit_ids))
        fetched = conn.execute(
            f"SELECT id, page_id, chunk_text FROM chunks WHERE id IN ({marks})",
            hit_ids,
        ).fetchall()
    finally:
        if db is None:
            return_db(conn)
    by_id = {row["id"]: row for row in fetched}

    best = {}  # page_id -> (score, chunk_text); keeps first-seen page order
    for cid, score in zip(hit_ids, hit_scores):
        row = by_id.get(cid)
        if not row or row["page_id"] is None:
            continue
        page = row["page_id"]
        prev = best.get(page)
        if prev is None or score > prev[0]:
            best[page] = (score, row["chunk_text"])
    ranked = sorted(
        ((page, score, text) for page, (score, text) in best.items()),
        key=lambda item: item[1],
        reverse=True,
    )
    return ranked[:limit]
def hybrid_search(query_text, bm25_ranked_ids, limit=10, db=None, use_reranker=False):
    """Merge BM25 and semantic results via RRF, optionally rerank with the cross-encoder.

    Two-stage (default): BM25 and semantic rankings are fused with Reciprocal
    Rank Fusion. A page scores the sum of 1 / (60 + r) over the lists it
    appears in, where r is its 1-based rank. When BM25 returned nothing,
    semantic results are used directly in similarity order, because RRF rank
    positions would flatten nearly-equal scores.

    Three-stage (use_reranker=True): the top 20 fused pages are rescored by the
    cross-encoder, and the next 10 fused pages are appended in fused order. If
    reranking raises, the failure is logged and the fused order is kept.
    NOTE(review): this path returns up to 30 results and ignores ``limit``.

    Args:
        query_text: raw user query.
        bm25_ranked_ids: page ids from keyword search, best first.
        limit: number of results for the two-stage path.
        db: optional connection, passed to semantic_search. The rerank stage
            obtains its own via get_db() when None.
        use_reranker: enable the cross-encoder stage.

    Returns: [(page_id, best_chunk_text), ...] in ranked order. The chunk text
    is "" for pages that semantic search did not return.
    """
    k = 60  # RRF constant
    sem_results = semantic_search(query_text, limit=100, db=db)
    best_chunks = {}  # page_id -> best-matching chunk text from semantic search
    for pid, _score, text in sem_results:
        best_chunks.setdefault(pid, text)

    if not bm25_ranked_ids:
        fused = [(pid, score) for pid, score, _ in sem_results]
    else:
        rrf_scores = {}
        for rank, pid in enumerate(bm25_ranked_ids):
            rrf_scores[pid] = rrf_scores.get(pid, 0) + 1.0 / (k + rank + 1)
        for rank, (pid, _score, _text) in enumerate(sem_results):
            rrf_scores[pid] = rrf_scores.get(pid, 0) + 1.0 / (k + rank + 1)
        fused = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)

    all_ids = [pid for pid, _ in fused]
    if not all_ids:
        return []
    if not use_reranker:
        return [(pid, best_chunks.get(pid, "")) for pid in all_ids[:limit]]

    # --- Rerank top 20, append next 10 from fused order ---
    rerank_ids = all_ids[:20]
    tail_ids = all_ids[20:30]
    from db import get_db, return_db
    own_db = db is None
    if own_db:
        db = get_db()
    try:
        placeholders = ",".join("?" * len(rerank_ids))
        rows = db.execute(
            f"SELECT id, title, body FROM pages WHERE id IN ({placeholders})",
            rerank_ids,
        ).fetchall()
    finally:
        if own_db:
            return_db(db)
    page_map = {r["id"]: r for r in rows}

    doc_texts = []
    ordered_ids = []
    for pid in rerank_ids:
        page = page_map.get(pid)
        if not page:
            continue  # id not found in the pages table
        chunk = best_chunks.get(pid, "")
        # Prefer the matched chunk; otherwise fall back to the page body,
        # treating a NULL body/title as empty text.
        body_preview = chunk[:200] if chunk else (page["body"] or "")[:200]
        doc_texts.append(f"{page['title'] or ''}. {body_preview}")
        ordered_ids.append(pid)

    results = []
    if doc_texts:
        try:
            reranked = rerank(query_text, doc_texts, limit=20)
            results = [
                (ordered_ids[idx], best_chunks.get(ordered_ids[idx], ""))
                for idx, _score in reranked
            ]
        except Exception:
            # Reranking is best-effort: keep the fused order, but record why.
            import logging
            logging.getLogger(__name__).warning(
                "cross-encoder reranking failed; falling back to fused order",
                exc_info=True,
            )
            results = [(pid, best_chunks.get(pid, "")) for pid in ordered_ids[:20]]

    # Append the next 10 from the fused order (not reranked).
    reranked_set = {pid for pid, _ in results}
    for pid in tail_ids:
        if pid not in reranked_set:
            results.append((pid, best_chunks.get(pid, "")))
    return results[:30]
# ---------------------------------------------------------------------------
# Reindex
# ---------------------------------------------------------------------------
def reindex_all(db=None, progress_callback=None):
    """Embed pages lacking chunks, fill in missing summaries, and rebuild the HNSW index.

    Steps:
      1. Local pages with no rows in ``chunks`` are embedded via
         store_embeddings, and get a summary if they have none.
         ``progress_callback(done, total)`` is called after each such page.
      2. Any page still lacking a summary gets one.
      3. Remote pages with no chunk are embedded via store_remote_embeddings.
      4. The HNSW index is rebuilt from the chunks table.

    Args:
        db: optional connection. When None, one is obtained with get_db() and
            handed back with return_db() only after every step, including
            the index rebuild, has finished.
        progress_callback: optional callable taking (done, total).
    """
    from db import get_db, return_db, _generate_summary
    own_db = db is None
    if own_db:
        db = get_db()
    try:
        rows = db.execute(
            "SELECT p.id, p.title, p.body, p.summary FROM pages p "
            "WHERE p.id NOT IN (SELECT DISTINCT page_id FROM chunks WHERE page_id IS NOT NULL)"
        ).fetchall()
        total = len(rows)
        for i, row in enumerate(rows):
            store_embeddings(row["id"], row["title"], row["body"], db)
            # Generate summary if missing
            if not row["summary"]:
                summary = _generate_summary(row["title"], row["body"])
                db.execute("UPDATE pages SET summary = ? WHERE id = ?", (summary, row["id"]))
                db.commit()
            if progress_callback:
                progress_callback(i + 1, total)
        # Generate summaries for pages that already have chunks but no summary
        no_summary = db.execute(
            "SELECT id, title, body FROM pages WHERE summary = '' OR summary IS NULL"
        ).fetchall()
        for row in no_summary:
            summary = _generate_summary(row["title"], row["body"])
            db.execute("UPDATE pages SET summary = ? WHERE id = ?", (summary, row["id"]))
        if no_summary:
            db.commit()
        # Also handle remote pages
        remote_rows = db.execute(
            "SELECT rp.id, rp.title, rp.note FROM remote_pages rp "
            "WHERE rp.id NOT IN (SELECT DISTINCT remote_page_id FROM chunks WHERE remote_page_id IS NOT NULL)"
        ).fetchall()
        for rp in remote_rows:
            store_remote_embeddings(rp["id"], rp["title"], rp["note"], db)
        # Rebuild while the connection is still ours: once return_db() has
        # run, a connection we borrowed should not be used again.
        build_index(db)
    finally:
        if own_db:
            return_db(db)