前面我们已完成在Qdrant创建了startups集合,导入了startups_demo.json数据,让我们开始构建神经搜索类。
为了处理传入请求,神经搜索需要两件事:1)将查询转换为向量的模型,2)Qdrant 客户端来执行搜索查询。
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
class NeuralSearcher:
def __init__(self, collection_name):
self.collection_name = collection_name
# Initialize encoder model
self.model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
# initialize Qdrant client
self.qdrant_client = QdrantClient("http://localhost:6333")
def search(self, text: str):
vector = self.model.encode(text).tolist()
search_result = self.qdrant_client.search(
collection_name=self.collection_name,
query_vector=vector,
query_filter=None,
limit=5
)
payloads = [hit.payload for hit in search_result]
return payloads
现在已经创建了一个用于神经搜索查询的类。现在将其包装到服务中
要构建该服务,您将使用 FastAPI 框架。
要安装它,请使用命令
pip install fastapi uvicorn
创建一个名为的文件service.py并指定以下内容。
该服务只有一个 API 端点,如下所示:
from fastapi import FastAPI
# The file where NeuralSearcher is stored
from neural_searcher import NeuralSearcher
app = FastAPI()
# Create a neural searcher instance
neural_searcher = NeuralSearcher(collection_name='startups')
@app.get("/api/search")
def search_startup(q: str):
return {
"result": neural_searcher.search(text=q)
# "result": neural_searcher.async_search(text=q) # 异步非阻塞
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
python service.py
打开浏览器http://localhost:8000/docs
就可以看到服务的调试界面
请随意使用它,查询语料库中的公司,并查看结果
http://127.0.0.1:8000/api/search?q=Artificial%20intelligence%20machine%20learning
调用QdrantClient替换为AsyncQdrantClient,AsyncQdrantClient提供与同步对应项相同的方法QdrantClient
注:AsyncQdrantClient提供与同步对应项相同的方法QdrantClient,异步是在qdrant-client1.6.1版本中引入
from qdrant_client import QdrantClient, AsyncQdrantClient
class NeuralSearcher:
# .../ init()
# 异步查询
async def async_search(self, text: str):
# AsyncQdrantClient提供与同步对应项相同的方法QdrantClient,异步客户端是在qdrant-client1.6.1版本中引入
client = AsyncQdrantClient("http://localhost:6333")
vector = self.model.encode(text).tolist()
search_result = await client.search(
collection_name=self.collection_name,
query_vector=vector,
query_filter=None,
limit=5
)
payloads = [hit.payload for hit in search_result]
return payloads