optimized storage, updated readme

This commit is contained in:
lichenblankie 2026-04-11 21:59:55 +00:00
parent 7946225030
commit 30bc61212f
4 changed files with 177 additions and 34 deletions

View file

@ -233,24 +233,42 @@ def embed(texts, is_query=False):
"token_type_ids": token_type_ids,
},
)
# CLS token pooling — take the first token's hidden state
emb = outputs[0][:, 0, :]
all_embeddings.append(emb)
embeddings = np.concatenate(all_embeddings, axis=0)
# L2 normalize
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms = np.maximum(norms, 1e-12)
embeddings = embeddings / norms
return embeddings.astype(np.float32)
return _maybe_compress(embeddings.astype(np.float32))
def _maybe_compress(embeddings):
"""Compress embeddings to float16 if compression is enabled."""
try:
from db import get_setting
if get_setting("compress_embeddings", "0") == "1":
return embeddings.astype(np.float16)
except Exception:
pass
return embeddings
def _decompress(embeddings):
"""Decompress float16 embeddings to float32 if needed."""
if embeddings.dtype == np.float16:
return embeddings.astype(np.float32)
return embeddings
# ---------------------------------------------------------------------------
# HNSW index management
# ---------------------------------------------------------------------------
BATCH_SIZE = 50000
def build_index(db=None):
"""Load all embeddings from chunks table and build HNSW index."""
"""Load all embeddings from chunks table and build HNSW index in batches."""
import hnswlib
global _hnsw_index, _hnsw_ids
@ -258,29 +276,49 @@ def build_index(db=None):
own_db = db is None
if own_db:
db = get_db()
try:
rows = db.execute("SELECT id, embedding FROM chunks ORDER BY id").fetchall()
total = db.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
if total == 0:
with _hnsw_lock:
_hnsw_index = None
_hnsw_ids = []
return
all_ids = []
all_embeddings = []
for offset in range(0, total, BATCH_SIZE):
rows = db.execute(
"SELECT id, embedding FROM chunks ORDER BY id LIMIT ? OFFSET ?",
(BATCH_SIZE, offset),
).fetchall()
for r in rows:
emb = np.frombuffer(r["embedding"], dtype=np.float32)
if emb.dtype == np.float16:
emb = emb.astype(np.float32)
all_ids.append(r["id"])
all_embeddings.append(emb)
finally:
if own_db:
return_db(db)
with _hnsw_lock:
if not rows:
if not all_ids:
with _hnsw_lock:
_hnsw_index = None
_hnsw_ids = []
return
return
n = len(rows)
ids = [r["id"] for r in rows]
matrix = np.frombuffer(b"".join(r["embedding"] for r in rows), dtype=np.float32).reshape(n, DIMS)
matrix = np.stack(all_embeddings)
n = len(all_ids)
ids = all_ids
index = hnswlib.Index(space="cosine", dim=DIMS)
# ef_construction and M balance build speed vs recall;
# these defaults give >99% recall at reasonable build time
index.init_index(max_elements=max(n, 1024), ef_construction=200, M=16)
index.add_items(matrix, list(range(n)))
index.set_ef(50) # query-time accuracy parameter
index = hnswlib.Index(space="cosine", dim=DIMS)
index.init_index(max_elements=max(n, 1024), ef_construction=200, M=16)
index.add_items(matrix, list(range(n)))
index.set_ef(50)
with _hnsw_lock:
_hnsw_index = index
_hnsw_ids = ids
@ -319,8 +357,8 @@ def store_embeddings(page_id, title, body, db):
return
embeddings_matrix = embed(chunks)
embeddings_matrix = _decompress(embeddings_matrix)
# Delete old chunks for this page
db.execute("DELETE FROM chunks WHERE page_id = ?", (page_id,))
new_ids = []
@ -343,6 +381,7 @@ def store_remote_embeddings(remote_page_id, title, note, db):
return
embeddings_matrix = embed([text])
embeddings_matrix = _decompress(embeddings_matrix)
db.execute("DELETE FROM chunks WHERE remote_page_id = ?", (remote_page_id,))
cursor = db.execute(