ai/text_splitter/ali_text_splitter.py

35 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
class AliTextSplitter(CharacterTextSplitter):
    """Text splitter that segments documents with a ModelScope pipeline.

    Splitting is delegated to the ``document-segmentation`` pipeline using the
    ``damo/nlp_bert_document-segmentation_chinese-base`` model released by
    DAMO Academy (paper: https://arxiv.org/abs/2107.09278).

    Using the model requires the ``modelscope[nlp]`` extra, e.g.::

        pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html

    NOTE(review): the original comment refers to a ``use_document_segmentation``
    parameter that selects semantic splitting; it is not defined in this class
    and presumably lives in the calling code.
    """

    def __init__(self, pdf: bool = False, **kwargs):
        """Create the splitter.

        Args:
            pdf: When true, whitespace in the input is normalised before
                segmentation (see ``split_text``).
            **kwargs: Forwarded unchanged to ``CharacterTextSplitter``.
        """
        super().__init__(**kwargs)
        self.pdf = pdf
        # The pipeline is built lazily on first use and then reused, so the
        # model is loaded once per splitter instead of once per call.
        self._segmenter = None

    def _get_segmenter(self):
        """Return the cached segmentation pipeline, building it on first use.

        Raises:
            ImportError: If the ``modelscope`` package is not installed.
        """
        # getattr keeps this working for instances created without __init__.
        segmenter = getattr(self, "_segmenter", None)
        if segmenter is None:
            try:
                from modelscope.pipelines import pipeline
            except ImportError as err:
                raise ImportError(
                    "Could not import modelscope python package. "
                    "Please install modelscope with `pip install modelscope`. "
                ) from err
            # The original note says several models are in use and that this
            # may be unfriendly to low-spec GPUs, so the model is loaded on the
            # CPU; replace ``device`` with a GPU id if one is available.
            segmenter = pipeline(
                task="document-segmentation",
                model='damo/nlp_bert_document-segmentation_chinese-base',
                device="cpu",
            )
            self._segmenter = segmenter
        return segmenter

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into segments using the segmentation pipeline.

        Args:
            text: The document text to split.

        Returns:
            The non-empty pieces obtained by splitting the pipeline's
            ``"text"`` output on ``"\\n\\t"``. Empty or whitespace-only input
            yields an empty list without loading the model.

        Raises:
            ImportError: If the ``modelscope`` package is not installed.
        """
        if self.pdf:
            # Collapse runs of three or more newlines first so that each run
            # becomes a single space rather than several.
            text = re.sub(r"\n{3,}", r"\n", text)
            # Replace every whitespace character (newlines included) with a
            # space. No newline survives this step, which is why the former
            # follow-up removal of "\n\n" was dead code and has been dropped.
            text = re.sub(r"\s", " ", text)

        if not text.strip():
            return []

        result = self._get_segmenter()(documents=text)
        return [segment for segment in result["text"].split("\n\t") if segment]