
Integrating seekdb Vector with Hugging Face

seekdb provides vector-type storage, vector indexing, and embedding-based vector search. You can store vectorized data in seekdb and use it for subsequent searches.

Hugging Face is an open-source machine learning platform that offers pre-trained models, datasets, and tooling, making it easy for developers to use and deploy AI models.

Prerequisites

  • You have deployed seekdb.

  • A database and an account are available in your environment, and the account has been granted read and write privileges on that database.

  • Python 3.11 or later is installed.

  • Install the dependencies (the sentence-transformers package is required by the sample code below):

    python3 -m pip install cffi pyseekdb requests datasets sentence-transformers

Step 1: Obtain the database connection information

Contact the seekdb deployment engineer or administrator for the database connection string, for example:

mysql -h$host -P$port -u$user_name -p$password -D$database_name

Parameter descriptions:

  • $host: the IP address for connecting to seekdb.

  • $port: the port for connecting to seekdb. The default is 2881.

  • $database_name: the name of the database to access.

    Note

    The connecting user must have the CREATE, INSERT, DROP, and SELECT privileges on this database.

  • $user_name: the database account used for the connection.

  • $password: the password of the account.
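These connection parameters can also be collected in Python before building a client. The sketch below is illustrative: the SEEKDB_DATABASE_* environment variable names match those read by the sample code in Step 2, and the host fallback of 127.0.0.1 is an assumption for local testing.

```python
import os

def load_seekdb_settings():
    """Read seekdb connection settings from environment variables.

    Falls back to 127.0.0.1 and the default port 2881 when the
    corresponding variables are unset (assumed defaults for local testing).
    """
    return {
        "host": os.getenv("SEEKDB_DATABASE_HOST", "127.0.0.1"),
        "port": int(os.getenv("SEEKDB_DATABASE_PORT", 2881)),
        "user": os.getenv("SEEKDB_DATABASE_USER"),
        "database": os.getenv("SEEKDB_DATABASE_DB_NAME"),
        "password": os.getenv("SEEKDB_DATABASE_PASSWORD"),
    }

settings = load_seekdb_settings()
print(settings["host"], settings["port"])
```

Keeping the settings in one place makes it easy to validate them (for example, failing fast when the password is missing) before opening a connection.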

Step 2: Build your AI assistant

Set environment variables

Obtain a Hugging Face API key and configure it, together with the seekdb connection information, as environment variables:

export SEEKDB_DATABASE_HOST=YOUR_SEEKDB_DATABASE_HOST
export SEEKDB_DATABASE_PORT=YOUR_SEEKDB_DATABASE_PORT
export SEEKDB_DATABASE_USER=YOUR_SEEKDB_DATABASE_USER
export SEEKDB_DATABASE_DB_NAME=YOUR_SEEKDB_DATABASE_DB_NAME
export SEEKDB_DATABASE_PASSWORD=YOUR_SEEKDB_DATABASE_PASSWORD
export HUGGING_FACE_API_KEY=YOUR_HUGGING_FACE_API_KEY

Sample code

Prepare the data

Hugging Face offers a variety of embedding models; choose the one that fits your needs. This example uses sentence-transformers/all-MiniLM-L6-v2 and runs it locally through the sentence-transformers library:

import os
import shutil

import pyseekdb
from pyseekdb import HNSWConfiguration
from sentence_transformers import SentenceTransformer

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
from datasets import load_dataset

# Delete the cache directory if it exists.
if os.path.exists("./cache"):
    shutil.rmtree("./cache")

# Read the API key from the environment (not used by the local-inference path below).
HUGGING_FACE_API_KEY = os.getenv('HUGGING_FACE_API_KEY')
DATASET = "squad"      # Name of the dataset from Hugging Face Datasets
INSERT_RATIO = 0.001   # Ratio of the example dataset to be inserted
data = load_dataset(DATASET, split="validation", cache_dir="./cache")

# Generates a fixed subset. To generate a random subset, remove the seed.
data = data.train_test_split(test_size=INSERT_RATIO, seed=42)["test"]
# Clean up the data structure in the dataset.
data = data.map(
    lambda val: {"answer": val["answers"]["text"][0]},
    remove_columns=["id", "answers", "context"],
)

# Download the embedding model for local inference.
print("Downloading the model...")
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
print("Model downloaded!")

def encode_text(batch):
    questions = batch["question"]

    # Run inference with the local model.
    embeddings = model.encode(questions)

    # Format the embeddings: round each value to 6 decimal places.
    formatted_embeddings = []
    for embedding in embeddings:
        formatted_embedding = [round(float(val), 6) for val in embedding]
        formatted_embeddings.append(formatted_embedding)

    batch["embedding"] = formatted_embeddings
    return batch

INFERENCE_BATCH_SIZE = 64  # Batch size for model inference
data = data.map(encode_text, batched=True, batch_size=INFERENCE_BATCH_SIZE)
data_list = data.to_list()
ids = []
embeddings = []
documents = []
metadatas = []

for i, item in enumerate(data_list):
    ids.append(f"item{i+1}")
    embeddings.append(item["embedding"])
    documents.append(item["question"])
    metadatas.append({"answer": item["answer"]})
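The formatting step inside encode_text converts each model output into a list of plain Python floats rounded to six decimal places. A standalone illustration with a hypothetical three-value vector (real embeddings from this model have 384 dimensions):

```python
# Same rounding rule as in encode_text: each value is rounded to 6 decimal places.
raw_embedding = [0.123456789, -0.987654321, 0.5]
formatted = [round(float(val), 6) for val in raw_embedding]
print(formatted)  # [0.123457, -0.987654, 0.5]
```

Rounding trims storage size slightly and makes the stored vectors deterministic across runs, at the cost of a negligible loss of precision.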

Define the table and write the data to seekdb

Create a table named huggingface_seekdb_demo_documents and write the data to seekdb:

SEEKDB_DATABASE_HOST = os.getenv('SEEKDB_DATABASE_HOST')
SEEKDB_DATABASE_PORT = int(os.getenv('SEEKDB_DATABASE_PORT', 2881))
SEEKDB_DATABASE_USER = os.getenv('SEEKDB_DATABASE_USER')
SEEKDB_DATABASE_DB_NAME = os.getenv('SEEKDB_DATABASE_DB_NAME')
SEEKDB_DATABASE_PASSWORD = os.getenv('SEEKDB_DATABASE_PASSWORD')

client = pyseekdb.Client(
    host=SEEKDB_DATABASE_HOST,
    port=SEEKDB_DATABASE_PORT,
    database=SEEKDB_DATABASE_DB_NAME,
    user=SEEKDB_DATABASE_USER,
    password=SEEKDB_DATABASE_PASSWORD
)
table_name = "huggingface_seekdb_demo_documents"
# all-MiniLM-L6-v2 produces 384-dimensional vectors; use l2 distance.
config = HNSWConfiguration(dimension=384, distance='l2')

collection = client.create_collection(
    name=table_name,
    configuration=config,
    embedding_function=None
)

print('- Inserting data into seekdb...')
collection.add(
    ids=ids,
    embeddings=embeddings,
    documents=documents
)
print('- Inserting data into seekdb completed!')

Semantic search

Generate the query text embeddings with the same local model, then compute the l2 distance between each query vector and the vectors in the table to retrieve the most relevant documents:

questions = {
    "question": [
        "What is LGM?",
        "When did Massachusetts first mandate that children be educated in schools?",
    ]
}

query_embeddings = encode_text(questions)["embedding"]

res = collection.query(
    query_embeddings=query_embeddings,
    n_results=1
)

for i in range(len(questions["question"])):
    print(f"Question: {questions['question'][i]}")
    if i < len(res['ids']) and res['ids'][i]:
        for j in range(len(res['ids'][i])):
            result = {
                "id": res['ids'][i][j],
                "original question": res['documents'][i][j],
                "distance": res['distances'][i][j]
            }
            print(result)
    else:
        print("No results found")

Expected results

Question: What is LGM?
{'id': 'item10', 'original question': 'What does LGM stands for?', 'distance': 0.29572633579122415}

Question: When did Massachusetts first mandate that children be educated in schools?
{'id': 'item1', 'original question': 'In what year did Massachusetts first require children to be educated in schools?', 'distance': 0.24083293996160604}
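The distance values in the output are Euclidean (l2) distances, matching the distance='l2' metric configured for the collection; smaller values mean closer matches. A minimal sketch of the computation, using hypothetical three-dimensional vectors (real embeddings here are 384-dimensional):

```python
import math

def l2_distance(a, b):
    """Euclidean (l2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

# Hypothetical vectors for illustration only.
q = [1.0, 2.0, 3.0]
v = [1.0, 0.0, 3.0]
print(l2_distance(q, v))  # 2.0
```

Because l2 distance is sensitive to vector magnitude, results can differ from cosine-similarity search; the same model must be used for both indexing and querying, as this example does.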