"""Thread-safe caches of FAISS vector stores.

Two pools are exposed at module level:

* ``kb_faiss_pool``   -- vector stores backed by files on disk, keyed by
  ``(kb_name, vector_name)``.
* ``memo_faiss_pool`` -- purely in-memory vector stores, keyed by ``kb_name``.
"""
import os

from configs import CACHED_VS_NUM
from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.utils import load_local_embeddings
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores.faiss import FAISS
from langchain.schema import Document


class ThreadSafeFaiss(ThreadSafeObject):
    """A pooled FAISS vector store whose operations run under the object's lock."""

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

    def docs_count(self) -> int:
        """Return the number of documents held in the wrapped store's docstore."""
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Persist the wrapped vector store to ``path``.

        Args:
            path: Target directory handed to ``save_local``.
            create_path: When true, create ``path`` if it is not already a
                directory.

        Returns:
            Whatever the wrapped store's ``save_local`` returns.
        """
        with self.acquire():
            if create_path:
                # exist_ok avoids the check-then-create race between threads
                # that save the same store concurrently.
                os.makedirs(path, exist_ok=True)
            ret = self._obj.save_local(path)
            logger.info(f"已将向量库 {self.key} 保存到磁盘")
        return ret

    def clear(self):
        """Delete every document from the wrapped vector store.

        Returns:
            The result of the store's ``delete`` call, or an empty list when
            the store was already empty.
        """
        ret = []
        with self.acquire():
            ids = list(self._obj.docstore._dict.keys())
            if ids:
                ret = self._obj.delete(ids)
                # Sanity check: the docstore must be empty after deleting all ids.
                assert len(self._obj.docstore._dict) == 0
            logger.info(f"已将向量库 {self.key} 清空")
        return ret


class _FaissPool(CachePool):
    """Behaviour shared by the disk-backed and in-memory FAISS pools."""

    def new_vector_store(
        self,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> FAISS:
        """Create and return an empty FAISS vector store.

        Args:
            embed_model: Name of the embedding model passed to
                ``EmbeddingsFunAdapter``.
            embed_device: Accepted for interface compatibility; not used by
                this method.
        """
        # TODO: the overall Embeddings loading logic is somewhat messy; clean up.
        embeddings = EmbeddingsFunAdapter(embed_model)
        # The store is built from one placeholder document, which is then
        # deleted so the returned store is empty.
        doc = Document(page_content="init", metadata={})
        vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
        ids = list(vector_store.docstore._dict.keys())
        vector_store.delete(ids)
        return vector_store

    def save_vector_store(self, kb_name: str, path: str = None):
        """Save the cached store registered under ``kb_name``, if there is one.

        ``path`` is forwarded unchanged to ``ThreadSafeFaiss.save``. Returns
        ``None`` when nothing is cached under ``kb_name``.
        """
        if cache := self.get(kb_name):
            return cache.save(path)

    def unload_vector_store(self, kb_name: str):
        """Drop the store cached under ``kb_name`` from the pool, if present."""
        if self.get(kb_name):
            self.pop(kb_name)
            logger.info(f"成功释放向量库:{kb_name}")

    def _get_or_load(self, key, loader) -> ThreadSafeFaiss:
        """Return the item cached under ``key``, building it with ``loader`` if absent.

        ``loader`` is a zero-argument callable returning the vector store. The
        pool-wide ``atomic`` lock is held only until the new item's own lock
        has been taken, so a slow load does not hold the pool lock.

        If ``loader`` raises, the half-initialised placeholder is removed from
        the pool so that a later call retries instead of receiving an item
        with no vector store; the exception then propagates to the caller.
        """
        self.atomic.acquire()
        atomic_held = True
        try:
            if self.get(key) is None:
                item = ThreadSafeFaiss(key, pool=self)
                self.set(key, item)
                with item.acquire(msg="初始化"):
                    # The item lock is held now; release the pool lock before
                    # the potentially slow load.
                    self.atomic.release()
                    atomic_held = False
                    loaded = False
                    try:
                        item.obj = loader()
                        loaded = True
                    finally:
                        if not loaded:
                            self.pop(key)
                        # Always mark loading as finished, on failure too.
                        # NOTE(review): presumably this wakes threads waiting
                        # on the item -- confirm against ThreadSafeObject.
                        item.finish_loading()
        finally:
            # Covers the cache-hit path and any error raised before the
            # hand-over to the item lock.
            if atomic_held:
                self.atomic.release()
        return self.get(key)


class KBFaissPool(_FaissPool):
    """Pool of vector stores that are persisted on disk."""

    def load_vector_store(
        self,
        kb_name: str,
        vector_name: str = None,
        create: bool = True,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached store for ``(kb_name, vector_name)``, loading it if needed.

        Args:
            kb_name: Knowledge base name.
            vector_name: Name of the vector store; defaults to ``embed_model``.
            create: When no ``index.faiss`` exists on disk, create and save an
                empty store instead of raising.
            embed_model: Embedding model name.
            embed_device: Device passed to the embedding loader.

        Raises:
            RuntimeError: No index exists on disk and ``create`` is false.
        """
        vector_name = vector_name or embed_model
        # A tuple key is preferable to a concatenated string.
        key = (kb_name, vector_name)

        def load() -> FAISS:
            logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
            vs_path = get_vs_path(kb_name, vector_name)
            if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                embeddings = self.load_kb_embeddings(
                    kb_name=kb_name,
                    embed_device=embed_device,
                    default_embed_model=embed_model,
                )
                return FAISS.load_local(vs_path, embeddings, normalize_L2=True)
            if not create:
                raise RuntimeError(f"knowledge base {kb_name} not exist.")
            # Create an empty vector store and persist it right away.
            os.makedirs(vs_path, exist_ok=True)
            vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
            vector_store.save_local(vs_path)
            return vector_store

        return self._get_or_load(key, load)


class MemoFaissPool(_FaissPool):
    """Pool of vector stores that live only in memory."""

    def load_vector_store(
        self,
        kb_name: str,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached store for ``kb_name``, creating an empty one if absent."""

        def load() -> FAISS:
            logger.info(f"loading vector store in '{kb_name}' to memory.")
            # Create an empty vector store.
            return self.new_vector_store(embed_model=embed_model, embed_device=embed_device)

        return self._get_or_load(kb_name, load)


kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
memo_faiss_pool = MemoFaissPool()


if __name__ == "__main__":
    # Ad-hoc stress demo: many threads randomly add to, search, or clear a store.
    import random
    import threading
    import time
    from pprint import pprint

    kb_names = ["vs1", "vs2", "vs3"]
    # for name in kb_names:
    #     memo_faiss_pool.load_vector_store(name)

    def worker(vs_name: str, name: str):
        # The requested name is overridden, so every worker uses "samples"
        # (presumably so that all workers contend on the same store).
        vs_name = "samples"
        time.sleep(random.randint(1, 5))
        embeddings = load_local_embeddings()
        r = random.randint(1, 3)

        with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
            if r == 1:  # add docs
                ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
                pprint(ids)
            elif r == 2:  # search docs
                docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
                pprint(docs)
        if r == 3:  # delete docs
            logger.warning(f"清除 {vs_name} by {name}")
            # KBFaissPool keys are (kb_name, vector_name) tuples, so a bare
            # string lookup with get() misses; go through load_vector_store.
            kb_faiss_pool.load_vector_store(vs_name).clear()

    threads = []
    for n in range(1, 30):
        t = threading.Thread(
            target=worker,
            kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
            daemon=True,
        )
        t.start()
        threads.append(t)

    for t in threads:
        t.join()