
Creating a Custom Embedding Function

You can create a custom embedding function by implementing the EmbeddingFunction protocol. This involves:

  • Implementing the __call__ method, which accepts Documents (str or List[str]) and returns Embeddings (List[List[float]]).

  • Optionally implementing a dimension property that returns the vector dimension.

Prerequisites

When creating a custom embedding function, make sure the following hold (a minimal sketch follows this list):

  • Implement the __call__ method:

    • Every vector must have the same dimension.
    • Input: a single document or a list of documents, typed str or List[str].
    • Output: the embedding vectors, typed List[List[float]].
  • (Recommended) Implement the dimension property:

    • Return: the dimension of the vectors this function produces, as an int.
    • This helps validate vector dimensions when creating collections.
  • Handle edge cases:

    • A single string input should be converted to a list.
    • Empty input should return an empty list.
    • All vectors in the output must have the same dimension.
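
Before the full examples, the sketch below shows only the protocol shape under the prerequisites above. The hash-derived vectors are a toy placeholder (the class name and the 16-value digest trick are illustrative assumptions, not part of pyseekdb); a real implementation would call an embedding model, as in the examples that follow.

import hashlib
from typing import List, Union

from pyseekdb import EmbeddingFunction

Documents = Union[str, List[str]]
Embeddings = List[List[float]]

class ToyHashEmbeddingFunction(EmbeddingFunction[Documents]):
    """Toy embedding function that only demonstrates the protocol shape."""

    @property
    def dimension(self) -> int:
        # Every vector this function returns has exactly 16 values
        return 16

    def __call__(self, input: Documents) -> Embeddings:
        # Single string -> list of one document
        if isinstance(input, str):
            input = [input]
        # Empty input -> empty list
        if not input:
            return []
        # Map each document to a fixed-length pseudo-vector (not semantically meaningful)
        return [
            [b / 255.0 for b in hashlib.md5(doc.encode("utf-8")).digest()]
            for doc in input
        ]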

Example 1: Sentence Transformers Custom Embedding Function

from typing import List, Union
from pyseekdb import EmbeddingFunction, Client, HNSWConfiguration

Documents = Union[str, List[str]]
Embeddings = List[List[float]]

class SentenceTransformerCustomEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    A custom embedding function using sentence-transformers with a specific model.
    """

    def __init__(self, model_name: str = "all-mpnet-base-v2", device: str = "cpu"):  # TODO: your own model name and device
        """
        Initialize the sentence-transformer embedding function.

        Args:
            model_name: Name of the sentence-transformers model to use
            device: Device to run the model on ('cpu' or 'cuda')
        """
        self.model_name = model_name
        self.device = device
        self._model = None
        self._dimension = None

    def _ensure_model_loaded(self):
        """Lazy load the embedding model"""
        if self._model is None:
            try:
                from sentence_transformers import SentenceTransformer
                self._model = SentenceTransformer(self.model_name, device=self.device)
                # Get dimension from model
                test_embedding = self._model.encode(["test"], convert_to_numpy=True)
                self._dimension = len(test_embedding[0])
            except ImportError:
                raise ImportError(
                    "sentence-transformers is not installed. "
                    "Please install it with: pip install sentence-transformers"
                )

    @property
    def dimension(self) -> int:
        """Get the dimension of embeddings produced by this function"""
        self._ensure_model_loaded()
        return self._dimension

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            input: Single document (str) or list of documents (List[str])

        Returns:
            List of embedding vectors
        """
        self._ensure_model_loaded()

        # Handle single string input
        if isinstance(input, str):
            input = [input]

        # Handle empty input
        if not input:
            return []

        # Generate embeddings
        embeddings = self._model.encode(
            input,
            convert_to_numpy=True,
            show_progress_bar=False
        )

        # Convert numpy arrays to lists
        return [embedding.tolist() for embedding in embeddings]

# Use the custom embedding function
client = Client()

# Initialize embedding function with all-mpnet-base-v2 model (768 dimensions)
ef = SentenceTransformerCustomEmbeddingFunction(
    model_name='all-mpnet-base-v2',  # TODO: your own model name
    device='cpu'  # TODO: your own device
)

# Get the dimension from the embedding function
dimension = ef.dimension
print(f"Embedding dimension: {dimension}")

# Create collection with matching dimension
collection_name = "my_collection"
if client.has_collection(collection_name):
    client.delete_collection(collection_name)

collection = client.create_collection(
    name=collection_name,
    configuration=HNSWConfiguration(dimension=dimension, distance='cosine'),
    embedding_function=ef
)

# Test the embedding function
print("\nTesting embedding function...")
test_documents = ["Hello world", "This is a test", "Sentence transformers are great"]
embeddings = ef(test_documents)
print(f"Generated {len(embeddings)} embeddings")
print(f"Each embedding has {len(embeddings[0])} dimensions")

# Add some documents to the collection
print("\nAdding documents to collection...")
collection.add(
    ids=["1", "2", "3"],
    documents=test_documents,
    metadatas=[{"source": "test1"}, {"source": "test2"}, {"source": "test3"}]
)

# Query the collection
print("\nQuerying collection...")
results = collection.query(
    query_texts="Hello",
    n_results=2
)

print("\nQuery results:")
for i in range(len(results['ids'][0])):
    print(f"ID: {results['ids'][0][i]}")
    print(f"Document: {results['documents'][0][i]}")
    print(f"Distance: {results['distances'][0][i]}")
    print()

# Clean up
client.delete_collection(name=collection_name)
print("Test completed successfully!")

Example 2: OpenAI Embedding Function

from typing import List, Union
import os
from openai import OpenAI
from pyseekdb import EmbeddingFunction
import pyseekdb

Documents = Union[str, List[str]]
Embeddings = List[List[float]]

class QWenEmbeddingFunction(EmbeddingFunction[Documents]):
    """
    A custom embedding function using OpenAI's embedding API.
    """

    def __init__(self, model_name: str = "", api_key: str = ""):  # TODO: your own model name and api key
        """
        Initialize the OpenAI embedding function.

        Args:
            model_name: Name of the OpenAI embedding model
            api_key: OpenAI API key (if not provided, uses OPENAI_API_KEY env var)
        """
        self.model_name = model_name
        self.api_key = api_key or os.environ.get('OPENAI_API_KEY')  # TODO: your own api key
        if not self.api_key:
            raise ValueError("OpenAI API key is required")

        self._dimension = 1024  # TODO: your own dimension

    @property
    def dimension(self) -> int:
        """Get the dimension of embeddings produced by this function"""
        if self._dimension is None:
            # Call API to get dimension (or use known values)
            raise ValueError("Dimension not set for this model")
        return self._dimension

    def __call__(self, input: Documents) -> Embeddings:
        """
        Generate embeddings using OpenAI API.

        Args:
            input: Single document (str) or list of documents (List[str])

        Returns:
            List of embedding vectors
        """
        # Handle single string input
        if isinstance(input, str):
            input = [input]

        # Handle empty input
        if not input:
            return []

        # Call OpenAI API
        client = OpenAI(
            api_key=self.api_key,
            base_url=""  # TODO: your own base url
        )
        response = client.embeddings.create(
            model=self.model_name,
            input=input
        )

        # Extract embeddings
        embeddings = [item.embedding for item in response.data]
        return embeddings

# Use the custom embedding function
collection_name = "my_collection"
ef = QWenEmbeddingFunction()
client = pyseekdb.Client()

if client.has_collection(collection_name):
    client.delete_collection(collection_name)

collection = client.create_collection(
    name=collection_name,
    embedding_function=ef
)

collection.add(
    ids=["1", "2", "3"],
    documents=["Hello", "World", "Hello World"],
    metadatas=[{"tag": "A"}, {"tag": "B"}, {"tag": "C"}]
)

results = collection.query(
    query_texts="Hello",
    n_results=2
)
for i in range(len(results['ids'][0])):
    print(results['ids'][0][i])
    print(results['documents'][0][i])
    print(results['metadatas'][0][i])
    print(results['distances'][0][i])
    print()

client.delete_collection(name=collection_name)
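
Many embedding services cap how many documents a single embeddings request may contain. If you run into such a limit, you can batch inside __call__. The sketch below reuses the QWenEmbeddingFunction and type aliases from this example; the batch size of 10 is an assumption, so check your provider's documentation for the actual limit.

from openai import OpenAI

class BatchedQWenEmbeddingFunction(QWenEmbeddingFunction):
    """Variant that splits large inputs into several smaller API requests."""

    BATCH_SIZE = 10  # assumption: adjust to your provider's per-request limit

    def __call__(self, input: Documents) -> Embeddings:
        # Handle single string and empty input as before
        if isinstance(input, str):
            input = [input]
        if not input:
            return []

        client = OpenAI(
            api_key=self.api_key,
            base_url=""  # TODO: your own base url
        )

        # Send the documents in chunks and concatenate the results in order
        embeddings: Embeddings = []
        for start in range(0, len(input), self.BATCH_SIZE):
            batch = input[start:start + self.BATCH_SIZE]
            response = client.embeddings.create(
                model=self.model_name,
                input=batch
            )
            embeddings.extend(item.embedding for item in response.data)
        return embeddings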