"""Semantic search using Snowflake arctic-embed-s via ONNX Runtime + hnswlib.""" import os import re import threading import numpy as np DATA_DIR = os.path.expanduser("~/.tinyweb") MODEL_ID = "Snowflake/snowflake-arctic-embed-s" MODEL_DIR = os.path.join(DATA_DIR, "models", "snowflake-arctic-embed-s") RERANKER_DIR = os.path.join(DATA_DIR, "models", "cross-encoder") HNSW_PATH = os.path.join(DATA_DIR, "index.hnsw") DIMS = 384 MAX_TOKENS = 512 QUERY_PREFIX = "Represent this sentence for searching relevant passages: " _session = None _tokenizer = None _lock = threading.Lock() _reranker_session = None _reranker_tokenizer = None _reranker_lock = threading.Lock() # Live HNSW index and chunk-id mapping _hnsw_index = None _hnsw_ids = [] # maps internal HNSW label -> chunks.id _hnsw_lock = threading.Lock() # --------------------------------------------------------------------------- # Model download & loading # --------------------------------------------------------------------------- def _ensure_model(): """Download the ONNX model and tokenizer from HuggingFace if not present.""" os.makedirs(MODEL_DIR, exist_ok=True) model_path = os.path.join(MODEL_DIR, "model.onnx") tokenizer_path = os.path.join(MODEL_DIR, "tokenizer.json") if os.path.exists(model_path) and os.path.exists(tokenizer_path): return from huggingface_hub import hf_hub_download os.makedirs(MODEL_DIR, exist_ok=True) files = { "onnx/model_quantized.onnx": "model.onnx", "tokenizer.json": "tokenizer.json", "tokenizer_config.json": "tokenizer_config.json", } for remote, local in files.items(): target = os.path.join(MODEL_DIR, local) if os.path.exists(target): continue cached = hf_hub_download(repo_id=MODEL_ID, filename=remote) # hf_hub_download returns the cached file path; copy to our model dir import shutil shutil.copy2(cached, target) def _get_session(): """Return (onnxruntime.InferenceSession, tokenizers.Tokenizer) singleton.""" global _session, _tokenizer if _session is not None: return _session, _tokenizer with _lock: if _session is not None: return _session, _tokenizer _ensure_model() import onnxruntime as ort from tokenizers import Tokenizer _session = ort.InferenceSession( os.path.join(MODEL_DIR, "model.onnx"), providers=["CPUExecutionProvider"], ) _tokenizer = Tokenizer.from_file(os.path.join(MODEL_DIR, "tokenizer.json")) _tokenizer.enable_truncation(max_length=MAX_TOKENS) _tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=None) return _session, _tokenizer def _get_reranker(): """Return (onnxruntime.InferenceSession, tokenizers.Tokenizer) for the cross-encoder reranker.""" global _reranker_session, _reranker_tokenizer if _reranker_session is not None: return _reranker_session, _reranker_tokenizer with _reranker_lock: if _reranker_session is not None: return _reranker_session, _reranker_tokenizer model_path = os.path.join(RERANKER_DIR, "model.onnx") tok_path = os.path.join(RERANKER_DIR, "tokenizer.json") if not os.path.exists(model_path) or not os.path.exists(tok_path): return None, None import onnxruntime as ort from tokenizers import Tokenizer _reranker_session = ort.InferenceSession( model_path, providers=["CPUExecutionProvider"], ) _reranker_tokenizer = Tokenizer.from_file(tok_path) _reranker_tokenizer.enable_truncation(max_length=512) _reranker_tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=None) return _reranker_session, _reranker_tokenizer def rerank(query, documents, limit=10): """Score query-document pairs with the cross-encoder and return reranked indices. Args: query: search query string documents: list of document texts to score against the query limit: max results to return Returns: list of (original_index, score) sorted by score descending. """ session, tokenizer = _get_reranker() if session is None: return [(i, 0.0) for i in range(min(limit, len(documents)))] # Cross-encoder takes (query, document) pairs — encode as pair sequences pairs = [[query, doc] for doc in documents] encodings = tokenizer.encode_batch(pairs) input_ids = np.array([e.ids for e in encodings], dtype=np.int64) attention_mask = np.array([e.attention_mask for e in encodings], dtype=np.int64) token_type_ids = np.array([e.type_ids for e in encodings], dtype=np.int64) outputs = session.run( None, { "input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, }, ) # Output is logits — higher = more relevant scores = outputs[0].flatten() ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True) return [(i, float(scores[i])) for i in ranked[:limit]] # --------------------------------------------------------------------------- # Text chunking # --------------------------------------------------------------------------- _SENTENCE_RE = re.compile(r'(?<=[.!?])\s+') def chunk_text(title, body): """Split body into chunks, each prefixed with title for context. Strategy: split on double newlines (paragraphs). If a paragraph exceeds MAX_TOKENS words, split at sentence boundaries. Each chunk is prefixed with the page title. """ if not body or not body.strip(): return [f"{title}"] if title else [] prefix = f"{title}: " if title else "" # Rough word budget for chunk body (leave room for prefix) prefix_words = len(prefix.split()) max_words = MAX_TOKENS - prefix_words # approximate; tokenizer may differ paragraphs = re.split(r'\n\s*\n', body.strip()) chunks = [] for para in paragraphs: para = para.strip() if len(para) < 20: continue words = para.split() if len(words) <= max_words: chunks.append(prefix + para) else: # Split paragraph into sentences sentences = _SENTENCE_RE.split(para) current = [] current_len = 0 for sent in sentences: sent_words = len(sent.split()) if current_len + sent_words > max_words and current: chunks.append(prefix + " ".join(current)) current = [] current_len = 0 if sent_words > max_words: # Sentence too long — use sliding window s_words = sent.split() for i in range(0, len(s_words), max_words - 50): window = s_words[i:i + max_words] chunks.append(prefix + " ".join(window)) else: current.append(sent) current_len += sent_words if current: chunks.append(prefix + " ".join(current)) if not chunks and title: chunks = [title] return chunks # --------------------------------------------------------------------------- # Embedding # --------------------------------------------------------------------------- def embed(texts, is_query=False): """Encode texts into L2-normalized float32 embeddings (N, 384). For queries, prepend the model's query prefix. Processes in batches of 32 to limit memory usage. """ if not texts: return np.empty((0, DIMS), dtype=np.float32) session, tokenizer = _get_session() if is_query: texts = [QUERY_PREFIX + t for t in texts] batch_size = 32 all_embeddings = [] for start in range(0, len(texts), batch_size): batch = texts[start:start + batch_size] encodings = tokenizer.encode_batch(batch) input_ids = np.array([e.ids for e in encodings], dtype=np.int64) attention_mask = np.array([e.attention_mask for e in encodings], dtype=np.int64) token_type_ids = np.zeros_like(input_ids) outputs = session.run( None, { "input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, }, ) emb = outputs[0][:, 0, :] all_embeddings.append(emb) embeddings = np.concatenate(all_embeddings, axis=0) norms = np.linalg.norm(embeddings, axis=1, keepdims=True) norms = np.maximum(norms, 1e-12) embeddings = embeddings / norms return _maybe_compress(embeddings.astype(np.float32)) def _maybe_compress(embeddings): """Compress embeddings to float16 if compression is enabled.""" try: from db import get_setting if get_setting("compress_embeddings", "0") == "1": return embeddings.astype(np.float16) except Exception: pass return embeddings def _decompress(embeddings): """Decompress float16 embeddings to float32 if needed.""" if embeddings.dtype == np.float16: return embeddings.astype(np.float32) return embeddings def _blob_to_vec(buf): """Decode a stored embedding blob to a float32 vector, inferring dtype from length.""" if len(buf) == DIMS * 2: return np.frombuffer(buf, dtype=np.float16).astype(np.float32) return np.frombuffer(buf, dtype=np.float32) # --------------------------------------------------------------------------- # HNSW index management # --------------------------------------------------------------------------- BATCH_SIZE = 50000 def build_index(db=None): """Load all embeddings from chunks table and build HNSW index in batches.""" import hnswlib global _hnsw_index, _hnsw_ids from db import get_db, return_db own_db = db is None if own_db: db = get_db() try: total = db.execute("SELECT COUNT(*) FROM chunks").fetchone()[0] if total == 0: with _hnsw_lock: _hnsw_index = None _hnsw_ids = [] return all_ids = [] all_embeddings = [] for offset in range(0, total, BATCH_SIZE): rows = db.execute( "SELECT id, embedding FROM chunks ORDER BY id LIMIT ? OFFSET ?", (BATCH_SIZE, offset), ).fetchall() for r in rows: emb = _blob_to_vec(r["embedding"]) all_ids.append(r["id"]) all_embeddings.append(emb) finally: if own_db: return_db(db) if not all_ids: with _hnsw_lock: _hnsw_index = None _hnsw_ids = [] return matrix = np.stack(all_embeddings) n = len(all_ids) ids = all_ids index = hnswlib.Index(space="cosine", dim=DIMS) index.init_index(max_elements=max(n, 1024), ef_construction=200, M=16) index.add_items(matrix, list(range(n))) index.set_ef(50) with _hnsw_lock: _hnsw_index = index _hnsw_ids = ids def _add_to_index(chunk_ids, embeddings_matrix): """Add new embeddings to the live HNSW index.""" import hnswlib global _hnsw_index, _hnsw_ids with _hnsw_lock: if _hnsw_index is None: index = hnswlib.Index(space="cosine", dim=DIMS) index.init_index(max_elements=1024, ef_construction=200, M=16) index.set_ef(50) _hnsw_index = index _hnsw_ids = [] current_max = _hnsw_index.get_max_elements() needed = len(_hnsw_ids) + len(chunk_ids) if needed > current_max: _hnsw_index.resize_index(max(needed * 2, current_max * 2)) labels = list(range(len(_hnsw_ids), len(_hnsw_ids) + len(chunk_ids))) _hnsw_index.add_items(embeddings_matrix, labels) _hnsw_ids.extend(chunk_ids) # --------------------------------------------------------------------------- # Store embeddings for pages # --------------------------------------------------------------------------- def store_embeddings(page_id, title, body, db): """Chunk, embed, and store embeddings for a page. Adds to HNSW index.""" chunks = chunk_text(title, body) if not chunks: return embeddings_matrix = embed(chunks) embeddings_matrix = _decompress(embeddings_matrix) db.execute("DELETE FROM chunks WHERE page_id = ?", (page_id,)) new_ids = [] for i, (text, emb) in enumerate(zip(chunks, embeddings_matrix)): cursor = db.execute( "INSERT INTO chunks (page_id, remote_page_id, chunk_index, chunk_text, embedding) " "VALUES (?, NULL, ?, ?, ?)", (page_id, i, text, emb.tobytes()), ) new_ids.append(cursor.lastrowid) db.commit() _add_to_index(new_ids, embeddings_matrix) def store_remote_embeddings(remote_page_id, title, note, db): """Store a single embedding for a remote page (title + note).""" text = f"{title}: {note}" if note else (title or "") if not text.strip(): return embeddings_matrix = embed([text]) embeddings_matrix = _decompress(embeddings_matrix) db.execute("DELETE FROM chunks WHERE remote_page_id = ?", (remote_page_id,)) cursor = db.execute( "INSERT INTO chunks (page_id, remote_page_id, chunk_index, chunk_text, embedding) " "VALUES (NULL, ?, 0, ?, ?)", (remote_page_id, text, embeddings_matrix[0].tobytes()), ) db.commit() _add_to_index([cursor.lastrowid], embeddings_matrix) # --------------------------------------------------------------------------- # Search # --------------------------------------------------------------------------- def semantic_search(query_text, limit=100, db=None): """Search for pages by semantic similarity. Returns: [(page_id, score, best_chunk_text), ...] sorted by score desc. Groups by page_id, taking the max chunk score per page. """ if _hnsw_index is None or not _hnsw_ids: return [] query_emb = embed([query_text], is_query=True) with _hnsw_lock: if _hnsw_index is None or not _hnsw_ids: return [] k = min(limit * 3, len(_hnsw_ids)) # oversample to account for grouping if k == 0: return [] labels, distances = _hnsw_index.knn_query(query_emb, k=k) # Map HNSW labels back to chunk IDs chunk_ids = [_hnsw_ids[int(lbl)] for lbl in labels[0]] # cosine distance -> similarity: hnswlib returns 1-cosine for "cosine" space scores = [1.0 - float(d) for d in distances[0]] # Fetch chunk details from DB from db import get_db, return_db own_db = db is None if own_db: db = get_db() try: placeholders = ",".join("?" * len(chunk_ids)) rows = db.execute( f"SELECT id, page_id, chunk_text FROM chunks WHERE id IN ({placeholders})", chunk_ids, ).fetchall() finally: if own_db: return_db(db) chunk_map = {r["id"]: r for r in rows} # Group by page_id, keep best score and chunk text per page page_best = {} # page_id -> (score, chunk_text) for cid, score in zip(chunk_ids, scores): chunk = chunk_map.get(cid) if not chunk or chunk["page_id"] is None: continue pid = chunk["page_id"] if pid not in page_best or score > page_best[pid][0]: page_best[pid] = (score, chunk["chunk_text"]) results = [(pid, score, text) for pid, (score, text) in page_best.items()] results.sort(key=lambda x: x[1], reverse=True) return results[:limit] def hybrid_search(query_text, bm25_ranked_ids, limit=10, db=None, use_reranker=False): """Merge BM25 and semantic results via RRF, optionally rerank with cross-encoder. Default (two-stage): BM25 + semantic fused via RRF. With use_reranker=True (three-stage): rerank top 20 with cross-encoder. Returns: [(page_id, best_chunk_text), ...] in ranked order. """ k = 60 # RRF constant sem_results = semantic_search(query_text, limit=100, db=db) best_chunks = {} # page_id -> chunk_text for _rank, (pid, _score, chunk_text) in enumerate(sem_results): if pid not in best_chunks: best_chunks[pid] = chunk_text # When BM25 has no hits, use raw semantic similarity scores directly # (RRF rank positions distort nearly-equal scores) if not bm25_ranked_ids: fused_ids = [(pid, score) for pid, score, _ in sem_results] else: rrf_scores = {} for rank, pid in enumerate(bm25_ranked_ids): rrf_scores[pid] = rrf_scores.get(pid, 0) + 1.0 / (k + rank + 1) for rank, (pid, _score, chunk_text) in enumerate(sem_results): rrf_scores[pid] = rrf_scores.get(pid, 0) + 1.0 / (k + rank + 1) fused_ids = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True) fused = fused_ids all_ids = [pid for pid, _ in fused] if not all_ids: return [] if not use_reranker: return [(pid, best_chunks.get(pid, "")) for pid in all_ids[:limit]] # --- Rerank top 20, append next 10 from RRF order --- rerank_ids = all_ids[:20] tail_ids = all_ids[20:30] from db import get_db, return_db own_db = db is None if own_db: db = get_db() try: placeholders = ",".join("?" * len(rerank_ids)) rows = db.execute( f"SELECT id, title, body FROM pages WHERE id IN ({placeholders})", rerank_ids, ).fetchall() finally: if own_db: return_db(db) page_map = {r["id"]: r for r in rows} doc_texts = [] ordered_ids = [] for pid in rerank_ids: page = page_map.get(pid) if not page: continue chunk = best_chunks.get(pid, "") body_preview = chunk[:200] if chunk else page["body"][:200] doc = f"{page['title']}. {body_preview}" doc_texts.append(doc) ordered_ids.append(pid) if not doc_texts: return [] try: reranked = rerank(query_text, doc_texts, limit=20) results = [(ordered_ids[idx], best_chunks.get(ordered_ids[idx], "")) for idx, _score in reranked] except Exception: results = [(pid, best_chunks.get(pid, "")) for pid in ordered_ids[:20]] # Append next 10 from RRF order (no reranking) reranked_set = {pid for pid, _ in results} for pid in tail_ids: if pid not in reranked_set: results.append((pid, best_chunks.get(pid, ""))) return results[:30] # --------------------------------------------------------------------------- # Reindex # --------------------------------------------------------------------------- def reindex_all(db=None, progress_callback=None): """Re-embed all pages and regenerate all summaries. Rebuilds HNSW index.""" from db import get_db, return_db own_db = db is None if own_db: db = get_db() try: # Clear existing chunks so everything is regenerated db.execute("DELETE FROM chunks") db.commit() rows = db.execute( "SELECT p.id, p.title, p.body, p.summary FROM pages p" ).fetchall() total = len(rows) for i, row in enumerate(rows): store_embeddings(row["id"], row["title"], row["body"], db) if progress_callback: progress_callback(i + 1, total) # Also handle remote pages remote_rows = db.execute( "SELECT rp.id, rp.title, rp.note FROM remote_pages rp" ).fetchall() for rp in remote_rows: store_remote_embeddings(rp["id"], rp["title"], rp["note"], db) finally: if own_db: return_db(db) build_index(db)