This commit is contained in:
parent
552311b730
commit
8ecb963be4
4 changed files with 172 additions and 29 deletions
|
|
@ -233,24 +233,42 @@ def embed(texts, is_query=False):
|
|||
"token_type_ids": token_type_ids,
|
||||
},
|
||||
)
|
||||
# CLS token pooling — take the first token's hidden state
|
||||
emb = outputs[0][:, 0, :]
|
||||
all_embeddings.append(emb)
|
||||
|
||||
embeddings = np.concatenate(all_embeddings, axis=0)
|
||||
# L2 normalize
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms = np.maximum(norms, 1e-12)
|
||||
embeddings = embeddings / norms
|
||||
return embeddings.astype(np.float32)
|
||||
return _maybe_compress(embeddings.astype(np.float32))
|
||||
|
||||
|
||||
def _maybe_compress(embeddings):
|
||||
"""Compress embeddings to float16 if compression is enabled."""
|
||||
try:
|
||||
from db import get_setting
|
||||
if get_setting("compress_embeddings", "0") == "1":
|
||||
return embeddings.astype(np.float16)
|
||||
except Exception:
|
||||
pass
|
||||
return embeddings
|
||||
|
||||
|
||||
def _decompress(embeddings):
|
||||
"""Decompress float16 embeddings to float32 if needed."""
|
||||
if embeddings.dtype == np.float16:
|
||||
return embeddings.astype(np.float32)
|
||||
return embeddings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HNSW index management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BATCH_SIZE = 50000
|
||||
|
||||
def build_index(db=None):
|
||||
"""Load all embeddings from chunks table and build HNSW index."""
|
||||
"""Load all embeddings from chunks table and build HNSW index in batches."""
|
||||
import hnswlib
|
||||
global _hnsw_index, _hnsw_ids
|
||||
|
||||
|
|
@ -258,29 +276,49 @@ def build_index(db=None):
|
|||
own_db = db is None
|
||||
if own_db:
|
||||
db = get_db()
|
||||
|
||||
try:
|
||||
rows = db.execute("SELECT id, embedding FROM chunks ORDER BY id").fetchall()
|
||||
total = db.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
|
||||
if total == 0:
|
||||
with _hnsw_lock:
|
||||
_hnsw_index = None
|
||||
_hnsw_ids = []
|
||||
return
|
||||
|
||||
all_ids = []
|
||||
all_embeddings = []
|
||||
|
||||
for offset in range(0, total, BATCH_SIZE):
|
||||
rows = db.execute(
|
||||
"SELECT id, embedding FROM chunks ORDER BY id LIMIT ? OFFSET ?",
|
||||
(BATCH_SIZE, offset),
|
||||
).fetchall()
|
||||
for r in rows:
|
||||
emb = np.frombuffer(r["embedding"], dtype=np.float32)
|
||||
if emb.dtype == np.float16:
|
||||
emb = emb.astype(np.float32)
|
||||
all_ids.append(r["id"])
|
||||
all_embeddings.append(emb)
|
||||
finally:
|
||||
if own_db:
|
||||
return_db(db)
|
||||
|
||||
with _hnsw_lock:
|
||||
if not rows:
|
||||
if not all_ids:
|
||||
with _hnsw_lock:
|
||||
_hnsw_index = None
|
||||
_hnsw_ids = []
|
||||
return
|
||||
return
|
||||
|
||||
n = len(rows)
|
||||
ids = [r["id"] for r in rows]
|
||||
matrix = np.frombuffer(b"".join(r["embedding"] for r in rows), dtype=np.float32).reshape(n, DIMS)
|
||||
matrix = np.stack(all_embeddings)
|
||||
n = len(all_ids)
|
||||
ids = all_ids
|
||||
|
||||
index = hnswlib.Index(space="cosine", dim=DIMS)
|
||||
# ef_construction and M balance build speed vs recall;
|
||||
# these defaults give >99% recall at reasonable build time
|
||||
index.init_index(max_elements=max(n, 1024), ef_construction=200, M=16)
|
||||
index.add_items(matrix, list(range(n)))
|
||||
index.set_ef(50) # query-time accuracy parameter
|
||||
index = hnswlib.Index(space="cosine", dim=DIMS)
|
||||
index.init_index(max_elements=max(n, 1024), ef_construction=200, M=16)
|
||||
index.add_items(matrix, list(range(n)))
|
||||
index.set_ef(50)
|
||||
|
||||
with _hnsw_lock:
|
||||
_hnsw_index = index
|
||||
_hnsw_ids = ids
|
||||
|
||||
|
|
@ -319,8 +357,8 @@ def store_embeddings(page_id, title, body, db):
|
|||
return
|
||||
|
||||
embeddings_matrix = embed(chunks)
|
||||
embeddings_matrix = _decompress(embeddings_matrix)
|
||||
|
||||
# Delete old chunks for this page
|
||||
db.execute("DELETE FROM chunks WHERE page_id = ?", (page_id,))
|
||||
|
||||
new_ids = []
|
||||
|
|
@ -343,6 +381,7 @@ def store_remote_embeddings(remote_page_id, title, note, db):
|
|||
return
|
||||
|
||||
embeddings_matrix = embed([text])
|
||||
embeddings_matrix = _decompress(embeddings_matrix)
|
||||
|
||||
db.execute("DELETE FROM chunks WHERE remote_page_id = ?", (remote_page_id,))
|
||||
cursor = db.execute(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue