# ai/server/knowledge_base/kb_service/milvus_kb_service.py
# (source-viewer metadata: 132 lines, 4.8 KiB, Python)

from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Milvus
from pymilvus import connections
from configs import kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
class MilvusKBService(KBService):
    """Knowledge-base service backed by a Milvus collection named after the KB.

    Most operations go through the LangChain ``Milvus`` store held in
    ``self.milvus`` (built by ``do_init`` / ``_load_milvus``); ``search`` and
    ``search_all`` talk to the raw pymilvus ``Collection`` instead.
    """

    milvus: Milvus

    @staticmethod
    def get_collection(milvus_name):
        """Return the pymilvus ``Collection`` called ``milvus_name``.

        No connection alias is passed, so this relies on an already-open
        pymilvus connection (``search_all`` opens one under ``"default"``).
        """
        from pymilvus import Collection
        return Collection(milvus_name)

    @staticmethod
    def _default_search_params() -> Dict:
        """Return a fresh L2 / ``nprobe=10`` parameter dict for ``Collection.search``."""
        return {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }

    def get_doc_by_id(self, id: str) -> Optional[Document]:
        """Fetch one stored chunk by its Milvus primary key.

        :param id: primary-key value. It is interpolated verbatim into the
            ``pk == ...`` expression, so it must be a valid literal for the pk
            field's type (presumably an integer for LangChain's auto-id ``pk``).
        :return: a ``Document`` whose page content is the stored text and whose
            metadata holds the remaining returned fields. Returns ``None`` when
            ``self.milvus.col`` is unset or no row matches.
        """
        if not self.milvus.col:
            return None
        data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
        if not data_list:
            return None
        data = data_list[0]
        # Use the store's configured field names (as do_add_doc does) rather than a literal "text".
        text = data.pop(self.milvus._text_field)
        # Keep the embedding out of the metadata, mirroring what do_add_doc stores.
        data.pop(self.milvus._vector_field, None)
        return Document(page_content=text, metadata=data)

    @staticmethod
    def search(milvus_name, content, limit=3, anns_field="embeddings", output_fields=None):
        """Run an L2 vector search directly against the collection ``milvus_name``.

        :param milvus_name: collection name.
        :param content: list of query vectors, as accepted by ``Collection.search``.
        :param limit: maximum number of hits per query vector.
        :param anns_field: vector field to search. Defaults to ``"embeddings"``
            as before; ``search_all`` passes ``"vector"`` for this KB's collection.
        :param output_fields: fields to return with each hit; defaults to ``["content"]``.
        :return: the raw pymilvus search result.
        """
        if output_fields is None:
            output_fields = ["content"]
        collection = MilvusKBService.get_collection(milvus_name)
        return collection.search(content, anns_field, MilvusKBService._default_search_params(),
                                 limit=limit, output_fields=output_fields)

    def search_all(self, query, limit=3, embeddings: Optional[Embeddings] = None):
        """Embed ``query`` and search this KB's collection, returning all fields.

        :param query: query text.
        :param limit: maximum number of hits.
        :param embeddings: embedding model to use. When omitted, falls back to
            ``self._load_embeddings()`` as before.
        :return: the raw pymilvus search result for the single query vector.
        """
        # Use the same Milvus settings as the LangChain store instead of a
        # hard-coded localhost:19530, which silently targets the wrong server
        # whenever the configured host/port differ.
        connection_args = kbs_config.get("milvus") or {"host": "localhost", "port": "19530"}
        connections.connect("default", **connection_args)
        if embeddings is None:
            embeddings = self._load_embeddings()
        embedding_function = EmbeddingsFunAdapter(embeddings)
        content = [embedding_function.embed_query(query)]
        return MilvusKBService.search(self.kb_name, content, limit,
                                      anns_field="vector", output_fields=["*"])

    def do_delete_one_doc(self, pk):
        """Delete the row whose primary key is ``pk`` when the collection is loaded."""
        if self.milvus.col:
            self.milvus.col.delete(expr=f'pk in [{pk}]')
            print("delete success")

    def do_create_kb(self):
        # Nothing to do up front: the collection is handled by the LangChain store.
        pass

    def vs_type(self) -> str:
        return SupportedVSType.MILVUS

    def _load_milvus(self, embeddings: Optional[Embeddings] = None):
        """(Re)build the LangChain ``Milvus`` store for this KB.

        :param embeddings: raw embedding model. It is wrapped in
            ``EmbeddingsFunAdapter`` here, so callers must not pre-wrap it.
            When ``None``, the KB's model is loaded via ``_load_embeddings``
            and cached on ``self.embeddings``.
        """
        if embeddings is None:
            self.embeddings = self._load_embeddings()
            embeddings = self.embeddings
        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
                             collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))

    def do_init(self):
        self._load_milvus()

    def do_drop_kb(self):
        """Release and drop the underlying collection, if there is one."""
        if self.milvus.col:
            self.milvus.col.release()
            self.milvus.col.drop()

    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
        """Similarity search over the KB, filtered by ``score_threshold``.

        :return: whatever ``score_threshold_process`` returns for the
            ``(top_k, similarity_search_with_score(query, top_k))`` results.
        """
        # _load_milvus applies EmbeddingsFunAdapter itself; wrapping here as well
        # would stack two adapters around the same model.
        self._load_milvus(embeddings=embeddings)
        # similarity_search_with_score searches and returns a score with each document.
        return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))

    def do_search_all(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
        """Raw collection search for ``query`` using the caller's embedding model.

        NOTE(review): ``score_threshold`` is not applied to these raw results.
        """
        return self.search_all(query, top_k, embeddings=embeddings)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Insert ``docs`` into the store and return ``{"id", "metadata"}`` per doc.

        Metadata is normalized in place first: values are stringified, every
        store field is given a default ``""``, and the text/vector fields are
        removed.
        """
        # TODO: workaround for bug #10492 in langchain
        for doc in docs:
            doc.metadata.update({k: str(v) for k, v in doc.metadata.items()})
            for field in self.milvus.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.milvus._text_field, None)
            doc.metadata.pop(self.milvus._vector_field, None)
        ids = self.milvus.add_documents(docs)
        return [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every row whose ``source`` field equals ``kb_file.filepath``."""
        if not self.milvus.col:
            return
        # Escape backslashes first, then double quotes, so the path stays a
        # single well-formed string literal inside the Milvus expression.
        filepath = kb_file.filepath.replace('\\', '\\\\').replace('"', '\\"')
        delete_list = [item.get("pk") for item in
                       self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
        # Nothing matched: skip issuing a delete with an empty ``pk in []`` expression.
        if delete_list:
            self.milvus.col.delete(expr=f'pk in {delete_list}')

    def do_clear_vs(self):
        """Drop the collection and rebuild an empty store for this KB."""
        if self.milvus.col:
            self.do_drop_kb()
            self.do_init()
if __name__ == '__main__':
    # Manual smoke test: create the relational tables, then look up one
    # document in the "test" knowledge base by its primary key.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    milvus_service = MilvusKBService("test")
    # milvus_service.add_doc(KnowledgeFile("README.md", "test"))
    found_doc = milvus_service.get_doc_by_id("445466355570849011")
    print(found_doc)
    # milvus_service.delete_doc(KnowledgeFile("README.md", "test"))
    # milvus_service.do_drop_kb()