Large-Scale Image Search with VectorDB and OpenAI CLIP
As the volume of image data grows, the need for efficient and powerful search methods becomes critical. Traditional keyword- or metadata-based search often fails to capture the full semantic meaning of an image. A vector database enables semantic search: instead of relying on text labels, you can find images that conceptually match a query (e.g., "a photo of a cloud").
In particular, vector databases offer:
Scalability: Modern applications can involve millions or even billions of images, which makes typical database solutions slower or harder to manage.
Flexibility: Storing image embeddings as vectors lets you adapt to different search use cases, from "find similar products" to "find images with a specific object or style".
Performance: Vector databases are optimized for nearest-neighbor queries in high-dimensional spaces, enabling real-time or near real-time search over large datasets.
SkyPilot simplifies running such large-scale jobs in the cloud. It abstracts away much of the complexity of managing infrastructure and helps you run compute-intensive tasks efficiently and cost-effectively through managed jobs.
The full blog post can be found here.
Step 0: Set up the environment
Install the following prerequisites:
SkyPilot: Make sure SkyPilot is installed and that sky check succeeds. See the SkyPilot documentation for instructions.
Hugging Face token: To download the dataset from the Hugging Face Hub, you need your token. Follow the steps below to configure it.
Set the Hugging Face token in ~/.env:
HF_TOKEN=hf_xxxxx
or set the HF_TOKEN environment variable.
Step 1: Compute vectors for the image data with OpenAI CLIP
You need to convert images into vector representations (embeddings) so they can be stored in a vector database. Models like OpenAI's CLIP learn powerful representations that map images and text into the same embedding space. This makes semantic similarity computation possible, so a query like "a photo of a cloud" can match relevant images.
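As a rough illustration of what the batch job computes per image, here is a minimal sketch of embedding one image and one text query with open_clip and comparing them by cosine similarity. The image path is a placeholder, and any smaller OpenCLIP checkpoint works the same way as the ViT-bigG-14 model used by the included scripts:

import open_clip
import torch
from PIL import Image

# Same model family as the included scripts; the bigG checkpoint is large,
# so a smaller OpenCLIP model can be substituted for a quick local test.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-bigG-14', pretrained='laion2b_s39b_b160k')
tokenizer = open_clip.get_tokenizer('ViT-bigG-14')

image = preprocess(Image.open('cloud.jpg')).unsqueeze(0)  # placeholder image path
text = tokenizer(['a photo of a cloud'])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Normalize so the dot product below is cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

print('cosine similarity:', (image_features @ text_features.T).item())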
Launch a job that processes your image dataset and computes the CLIP embeddings with the following command:
python3 batch_compute_vectors.py
This automatically finds available machines to compute the vectors. Expect output like:
...
(clip-batch-compute-vectors, pid=2523) 2025-01-27 23:57:27,387 - root - INFO - Saved partition 2 to /output/embeddings_90000_100000.parquet_part_2/data.parquet
(clip-batch-compute-vectors, pid=2523) 2025-01-27 23:59:39,720 - root - INFO - Saved partition 3 to /output/embeddings_90000_100000.parquet_part_3/data.parquet
(clip-batch-compute-vectors, pid=2523) 2025-01-28 00:01:56,707 - root - INFO - Saved partition 4 to /output/embeddings_90000_100000.parquet_part_4/data.parquet
(clip-batch-compute-vectors, pid=2523) 2025-01-28 00:04:12,200 - root - INFO - Saved partition 5 to /output/embeddings_90000_100000.parquet_part_5/data.parquet
(clip-batch-compute-vectors, pid=2523) 2025-01-28 00:06:25,009 - root - INFO - Saved partition 6 to /output/embeddings_90000_100000.parquet_part_6/data.parquet
...
You can also check the status of the jobs with sky jobs queue and sky jobs dashboard. The figure below shows our jobs being launched across different regions.
Step 2: Build the vector database from the computed embeddings
Once you have the image embeddings, you need a specialized engine to run fast similarity searches at scale. In this example, we use ChromaDB to store and query the embeddings. This step ingests the embeddings from Step 1 into a vector database, enabling real-time or near real-time search over millions of vectors.
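At its core, this step adds (id, embedding, document) triples to a Chroma collection and then serves nearest-neighbor queries against it. Here is a minimal sketch with toy data: a throwaway path and dummy 1280-dimensional vectors (the embedding size of ViT-bigG-14):

import chromadb

# Toy version of what scripts/build_vectordb.py does: store embeddings with
# IDs and the relative image path as the document.
client = chromadb.PersistentClient(path='/tmp/chroma-demo')  # throwaway path for this sketch
collection = client.get_or_create_collection(name='clip_embeddings')

collection.add(
    ids=['0', '1'],
    embeddings=[[0.1] * 1280, [0.2] * 1280],  # dummy vectors; real ones come from CLIP
    documents=['0000/0.jpg', '0000/1.jpg'],   # relative image paths, as in the real script
)

# Nearest-neighbor query with a dummy query embedding.
results = collection.query(query_embeddings=[[0.1] * 1280], n_results=1)
print(results['documents'][0])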
Build the database from the embeddings:
sky jobs launch build_vectordb.yaml
This processes the generated CLIP embeddings in batches and produces output like:
(vectordb-build, pid=2457) INFO:__main__:Processing /clip_embeddings/embeddings_0_500.parquet_part_0/data.parquet
Processing batches: 100%|██████████| 1/1 [00:00<00:00, 1.19it/s]
Processing files: 92%|█████████▏| 11/12 [00:02<00:00, 5.36it/s]INFO:__main__:Processing /clip_embeddings/embeddings_500_1000.parquet_part_0/data.parquet
Processing batches: 100%|██████████| 1/1 [00:02<00:00, 2.39s/it]
Processing files: 100%|██████████| 12/12 [00:05<00:00, 2.04it/s]/1 [00:00<?, ?it/s]
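To sanity-check the result, you can open the persisted database and count the stored documents. A minimal sketch, assuming the sky-vectordb bucket (or a copy of it) is mounted at /vectordb on the machine where you run it:

import chromadb

# Open the persisted ChromaDB built above and report how many embeddings it holds.
client = chromadb.PersistentClient(path='/vectordb/chroma')
collection = client.get_collection(name='clip_embeddings')
print('documents in collection:', collection.count())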
Step 3: Serve the constructed vector database
To serve the constructed database, you expose an API endpoint that other applications (or your local client) can call to perform semantic search. Querying it lets you confirm that the database is working and retrieve semantic matches for a given text query. You can integrate this endpoint into a larger application (e.g., an image search engine or a recommendation system).
Serve the constructed database:
sky launch -c vecdb_serve serve_vectordb.yaml
This hosts the vector database service on a cluster. Alternatively, you can run
sky serve up serve_vectordb.yaml -n vectordb
which deploys your vector database as a service on cloud instances and lets you interact with it through a public endpoint. SkyServe provides automatic health checks and scaling for the service.
To query the constructed database:
If you launched it with sky launch, use sky status --ip vecdb_serve to get the IP of the deployed cluster.
If you launched it with sky serve, run sky serve status vectordb --endpoint to get the endpoint address of the service.
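With the endpoint in hand, you can call the /search API defined in scripts/serve_vectordb.py directly. A minimal sketch using requests, with a placeholder address:

import requests

# For `sky launch`: http://<ip>:8000, with the IP from `sky status --ip vecdb_serve`.
# For `sky serve`: use the endpoint printed by `sky serve status vectordb --endpoint`.
endpoint = 'http://<ip-or-endpoint>:8000'  # placeholder

resp = requests.post(f'{endpoint}/search',
                     json={'text': 'a photo of a cloud', 'n_results': 5})
resp.raise_for_status()
for result in resp.json():
    print(result['image_path'], result['similarity'])

You can also open the endpoint in a browser to use the simple search page that the server exposes at /.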
Included files
batch_compute_vectors.py
"""
Use skypilot to launch managed jobs that will run the embedding calculation.
This script is responsible for splitting the input dataset up among several workers,
then using skypilot to launch managed jobs for each worker. We use compute_vectors.yaml
to define the managed job info.
"""
#!/usr/bin/env python3
import argparse
import os
import sky
def calculate_job_range(start_idx: int, end_idx: int, job_rank: int,
total_jobs: int) -> tuple[int, int]:
"""Calculate the range of indices this job should process.
Args:
start_idx: Global start index
end_idx: Global end index
job_rank: Current job's rank (0-based)
total_jobs: Total number of jobs
Returns:
Tuple of [job_start_idx, job_end_idx)
"""
total_range = end_idx - start_idx
chunk_size = total_range // total_jobs
remainder = total_range % total_jobs
# Distribute remainder across first few jobs
job_start = start_idx + (job_rank * chunk_size) + min(job_rank, remainder)
if job_rank < remainder:
chunk_size += 1
job_end = job_start + chunk_size
return job_start, job_end
def main():
parser = argparse.ArgumentParser(
description='Launch batch CLIP inference jobs')
parser.add_argument('--start-idx',
type=int,
default=0,
help='Global start index in dataset')
parser.add_argument('--end-idx',
type=int,
default=1000000,
help='Global end index in dataset, not inclusive')
parser.add_argument('--num-jobs',
type=int,
default=100,
help='Number of jobs to partition the work across')
parser.add_argument('--env-path',
type=str,
default='~/.env',
help='Path to the environment file')
args = parser.parse_args()
# Try to get HF_TOKEN from environment first, then ~/.env file
hf_token = os.environ.get('HF_TOKEN')
if not hf_token:
env_path = os.path.expanduser(args.env_path)
if os.path.exists(env_path):
with open(env_path) as f:
for line in f:
if line.startswith('HF_TOKEN='):
hf_token = line.strip().split('=')[1]
break
if not hf_token:
raise ValueError("HF_TOKEN not found in ~/.env or environment variable")
# Load the task template
task = sky.Task.from_yaml('compute_vectors.yaml')
# Launch jobs for each partition
for job_rank in range(args.num_jobs):
# Calculate index range for this job
job_start, job_end = calculate_job_range(args.start_idx, args.end_idx,
job_rank, args.num_jobs)
# Update environment variables for this job
task_copy = task.update_envs({
'START_IDX': job_start,
'END_IDX': job_end,
'HF_TOKEN': hf_token,
})
sky.jobs.launch(
task_copy,
name=f'vector-compute-{job_start}-{job_end}',
)
if __name__ == '__main__':
main()
build_vectordb.yaml
name: vectordb-build
workdir: .
file_mounts:
/clip_embeddings:
name: sky-demo-embedding
# this needs to be the same as the source in the compute_vectors.yaml
mode: MOUNT
/vectordb:
name: sky-vectordb
# this needs to be the same as the source in the serve_vectordb.yaml
mode: MOUNT
/images:
name: sky-demo-image
# this needs to be the same as the source in compute_vectors.yaml
mode: MOUNT
setup: |
pip install chromadb pandas tqdm pyarrow
run: |
python scripts/build_vectordb.py \
--collection-name clip_embeddings \
--persist-dir /vectordb/chroma \
--embeddings-dir /clip_embeddings \
--batch-size 1000
compute_vectors.yaml
name: clip-batch-compute-vectors
workdir: .
resources:
accelerators:
# ordered by pricing (cheapest to most expensive)
T4: 1
L4: 1
A10G: 1
A10: 1
V100: 1
memory: 32+
any_of:
- use_spot: true
- use_spot: false
num_nodes: 1
file_mounts:
/output:
name: sky-demo-embedding
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
/images:
name: sky-demo-image
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
envs:
# These env vars are required but should be passed in at launch time.
HF_TOKEN: ''
START_IDX: ''
END_IDX: ''
setup: |
pip install numpy==1.26.4
pip install torch==2.5.1 torchvision==0.20.1 ftfy regex tqdm
pip install datasets webdataset requests Pillow open_clip_torch
pip install fastapi uvicorn aiohttp pandas pyarrow tenacity
run: |
python scripts/compute_vectors.py \
--output-path "/output/embeddings_${START_IDX}_${END_IDX}.parquet" \
--start-idx ${START_IDX} \
--end-idx ${END_IDX} \
--batch-size 64 \
--checkpoint-size 1000
echo "Processing complete. Results saved in node-specific files under /output/"
scripts/build_vectordb.py
"""
This script is responsible for building the vector database from the mounted bucket and saving
it to another mounted bucket.
"""
import argparse
import base64
from concurrent.futures import as_completed
from concurrent.futures import ProcessPoolExecutor
import glob
import logging
import multiprocessing
import os
import pickle
import shutil
import tempfile
import chromadb
import numpy as np
import pandas as pd
from tqdm import tqdm
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def list_local_parquet_files(mount_path: str, prefix: str) -> list:
"""List all parquet files in the mounted S3 directory."""
search_path = os.path.join(mount_path, prefix, '**/*.parquet')
parquet_files = glob.glob(search_path, recursive=True)
return parquet_files
def process_parquet_file(args):
"""Process a single parquet file and return the processed data."""
parquet_file, batch_size = args
try:
results = []
df = pd.read_parquet(parquet_file)
# Process in batches
for i in range(0, len(df), batch_size):
batch_df = df.iloc[i:i + batch_size]
# Extract data from DataFrame and unpack the pickled data
ids = [str(idx) for idx in batch_df['idx']]
unpacked_data = [pickle.loads(row) for row in batch_df['output']]
images_base64, embeddings = zip(*unpacked_data)
results.append((ids, embeddings, images_base64))
return results
except Exception as e:
logger.error(f'Error processing file {parquet_file}: {str(e)}')
return None
def main():
parser = argparse.ArgumentParser(
description='Build ChromaDB from mounted S3 parquet files')
parser.add_argument('--collection-name',
type=str,
default='clip_embeddings',
help='ChromaDB collection name')
parser.add_argument('--persist-dir',
type=str,
default='/vectordb/chroma',
help='Directory to persist ChromaDB')
parser.add_argument(
'--batch-size',
type=int,
default=1000,
help='Batch size for processing, this needs to fit in memory')
parser.add_argument('--embeddings-dir',
type=str,
default='/clip_embeddings',
help='Path to mounted bucket containing parquet files')
parser.add_argument(
'--prefix',
type=str,
default='',
help='Prefix path within mounted bucket to search for parquet files')
args = parser.parse_args()
# Create a temporary directory for building the database. The
# mounted bucket does not support append operation, so build in
# the tmpdir and then copy it to the final location.
with tempfile.TemporaryDirectory() as temp_dir:
logger.info(f'Using temporary directory: {temp_dir}')
# Initialize ChromaDB in temporary directory
client = chromadb.PersistentClient(path=temp_dir)
# Create or get collection for chromadb
# it attempts to create a collection with the same name
# if it already exists, it will get the collection
try:
collection = client.create_collection(
name=args.collection_name,
metadata={'description': 'CLIP embeddings from dataset'})
logger.info(f'Created new collection: {args.collection_name}')
except ValueError:
collection = client.get_collection(name=args.collection_name)
logger.info(f'Using existing collection: {args.collection_name}')
# List parquet files from mounted directory
parquet_files = list_local_parquet_files(args.embeddings_dir,
args.prefix)
logger.info(f'Found {len(parquet_files)} parquet files')
# Process files in parallel
max_workers = max(1,
multiprocessing.cpu_count() - 1) # Leave one CPU free
logger.info(f'Processing files using {max_workers} workers')
with ProcessPoolExecutor(max_workers=max_workers) as executor:
# Submit all files for processing
future_to_file = {
executor.submit(process_parquet_file, (file, args.batch_size)):
file for file in parquet_files
}
# Process results as they complete
for future in tqdm(as_completed(future_to_file),
total=len(parquet_files),
desc='Processing files'):
file = future_to_file[future]
try:
results = future.result()
if results:
for ids, embeddings, images_paths in results:
collection.add(ids=list(ids),
embeddings=list(embeddings),
documents=list(images_paths))
except Exception as e:
logger.error(f'Error processing file {file}: {str(e)}')
continue
logger.info('Vector database build complete!')
logger.info(f'Total documents in collection: {collection.count()}')
# Copy the completed database to the final location
logger.info(f'Copying database to final location: {args.persist_dir}')
if os.path.exists(args.persist_dir):
logger.info('Removing existing database directory')
shutil.rmtree(args.persist_dir)
shutil.copytree(temp_dir, args.persist_dir)
logger.info('Database copy complete!')
if __name__ == '__main__':
main()
scripts/compute_vectors.py
"""
This script is responsible for computing the embeddings for the ImageNet dataset.
"""
import abc
import asyncio
import base64
from io import BytesIO
import logging
import os
from pathlib import Path
import pickle
import shutil
from typing import (Any, AsyncIterator, Dict, Generic, List, Optional, Tuple,
TypeVar)
import numpy as np
import pandas as pd
from PIL import Image
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
class BatchProcessor():
"""Process ImageNet images with CLIP.
This script is responsible for computing the embeddings for the ImageNet dataset.
1. setup_model initializes the model
2. get_dataset_iterator will yield individual items from the dataset
3. do_data_loading will get an item from the dataset iterator and do any preprocessing
4. the loaded items will be batched and handed to do_batch_processing for the ultimate processing
"""
def __init__(self,
output_path: str,
images_path: str = '/images',
model_name: str = 'ViT-bigG-14',
dataset_name: str = 'ILSVRC/imagenet-1k',
pretrained: str = 'laion2b_s39b_b160k',
device: Optional[str] = None,
split: str = 'train',
streaming: bool = True,
batch_size: int = 32,
checkpoint_size: int = 100,
start_idx: int = 0,
end_idx: Optional[int] = None):
self.output_path = Path(output_path) # Convert to Path object
self.images_path = Path(images_path) # Path to store images
self.batch_size = batch_size
self.checkpoint_size = checkpoint_size
self.start_idx = start_idx
self.end_idx = end_idx
self._current_batch = []
# Create images directory if it doesn't exist
self.images_path.mkdir(parents=True, exist_ok=True)
# CLIP-specific attributes
self.model_name = model_name
self.pretrained = pretrained
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
self.dataset_name = dataset_name
self.split = split
self.streaming = streaming
self.model = None
self.preprocess = None
self.partition_counter = 0
async def setup_model(self):
"""Set up the CLIP model."""
import open_clip
model, _, preprocess = open_clip.create_model_and_transforms(
self.model_name, pretrained=self.pretrained, device=self.device)
self.model = model
self.preprocess = preprocess
async def get_dataset_iterator(self) -> AsyncIterator[Tuple[int, Any]]:
"""Load data from a HuggingFace dataset."""
from datasets import load_dataset
dataset = load_dataset(self.dataset_name,
streaming=self.streaming,
trust_remote_code=True)[self.split]
if self.start_idx > 0:
dataset = dataset.skip(self.start_idx)
for idx, item in enumerate(dataset, start=self.start_idx):
if self.end_idx and idx >= self.end_idx:
break
yield idx, item
async def do_data_loading(
self) -> AsyncIterator[Tuple[int, Tuple[torch.Tensor, Any]]]:
"""Load and preprocess ImageNet images."""
async for idx, item in self.get_dataset_iterator():
try:
# ImageNet provides PIL Images directly
tensor = self.preprocess(item['image'])
if tensor is not None:
# Pass through both the tensor and original image
yield idx, (tensor, item['image'])
except Exception as e:
logging.debug(
f'Error preprocessing image at index {idx}: {str(e)}')
def save_image(self, idx: int, image: Image.Image) -> str:
"""Save image to the mounted bucket and return its path."""
# Create a subdirectory based on the first few digits of the index to avoid too many files in one directory
subdir = str(idx // 100000).zfill(4)
save_dir = self.images_path / subdir
save_dir.mkdir(parents=True, exist_ok=True)
# Save image with index as filename
image_path = save_dir / f'{idx}.jpg'
image.save(image_path, format='JPEG', quality=95)
# Return relative path from images root
return str(Path(subdir) / f'{idx}.jpg')
async def do_batch_processing(
self, batch: List[Tuple[int, Tuple[torch.Tensor, Any]]]
) -> List[Tuple[int, bytes]]:
"""Process a batch of images through CLIP."""
if self.model is None:
await self.setup_model()
# Unpack the batch
indices, batch_data = zip(*batch)
model_inputs, original_images = zip(*batch_data)
# Stack inputs into a batch
batch_tensor = torch.stack(model_inputs).to(self.device)
# Run inference
with torch.no_grad():
features = self.model.encode_image(batch_tensor)
features /= features.norm(dim=-1, keepdim=True)
# Convert to numpy arrays
embeddings = features.cpu().numpy()
# Save images and store their paths
image_paths = {}
for idx, img in zip(indices, original_images):
image_path = self.save_image(idx, img)
image_paths[idx] = image_path
# Return both embeddings and image paths
return [(idx, pickle.dumps((image_paths[idx], arr)))
for idx, arr in zip(indices, embeddings)]
async def find_existing_progress(self) -> Tuple[int, int]:
"""
Find the highest processed index and partition counter from existing files.
Returns:
Tuple[int, int]: (highest_index, next_partition_number)
"""
if not self.output_path.parent.exists():
self.output_path.parent.mkdir(parents=True, exist_ok=True)
return self.start_idx, 0
partition_files = list(
self.output_path.parent.glob(
f'{self.output_path.stem}_part_*.parquet'))
print(f'Partition files: {partition_files}')
if not partition_files:
return self.start_idx, 0
max_idx = self.start_idx
max_partition = -1
for file in partition_files:
# Extract partition number from filename
try:
partition_num = int(file.stem.split('_part_')[1])
max_partition = max(max_partition, partition_num)
# Read the file and find highest index
df = pd.read_parquet(file)
if not df.empty:
max_idx = max(max_idx, df['idx'].max())
except Exception as e:
logging.warning(f'Error processing file {file}: {e}')
return max_idx, max_partition + 1
def save_results_to_parquet(self, results: list):
"""Save results to a parquet file with atomic write."""
if not results:
return
df = pd.DataFrame(results, columns=['idx', 'output'])
final_path = f'{self.output_path}_part_{self.partition_counter}.parquet'
temp_path = f'/tmp/{self.partition_counter}.tmp'
# Write to temporary file first
df.to_parquet(temp_path, engine='pyarrow', index=False)
# Copy from temp to final destination
shutil.copy2(temp_path, final_path)
os.remove(temp_path) # Clean up temp file
logging.info(
f'Saved partition {self.partition_counter} to {final_path} with {len(df)} rows'
)
self.partition_counter += 1
async def run(self):
"""
Run the batch processing pipeline with recovery support.
"""
# Initialize the model
if self.model is None:
await self.setup_model()
# Find existing progress
resume_idx, self.partition_counter = await self.find_existing_progress()
self.start_idx = max(self.start_idx, resume_idx + 1)
logging.info(
f'Starting processing from index {self.start_idx} (partition {self.partition_counter})'
)
results = []
async for idx, input_data in self.do_data_loading():
self._current_batch.append((idx, input_data))
if len(self._current_batch) >= self.batch_size:
batch_results = await self.do_batch_processing(
self._current_batch)
results.extend(batch_results)
self._current_batch = []
if len(results) >= self.checkpoint_size:
self.save_results_to_parquet(results)
results.clear()
# Process any remaining items in the batch
if self._current_batch:
batch_results = await self.do_batch_processing(self._current_batch)
results.extend(batch_results)
# Write the final partition if there are any leftover results
if results:
self.save_results_to_parquet(results)
async def main():
"""Example usage of the batch processing framework."""
import argparse
# Parse command line arguments
parser = argparse.ArgumentParser(
description='Run CLIP batch processing on ImageNet')
parser.add_argument('--output-path',
type=str,
default='embeddings.parquet',
help='Path to output parquet file')
parser.add_argument('--start-idx',
type=int,
default=0,
help='Starting index in dataset')
parser.add_argument('--end-idx',
type=int,
default=10000,
help='Ending index in dataset')
parser.add_argument('--batch-size',
type=int,
default=50,
help='Batch size for processing')
parser.add_argument('--checkpoint-size',
type=int,
default=100,
help='Number of results before checkpointing')
parser.add_argument('--model-name',
type=str,
default='ViT-bigG-14',
help='CLIP model name')
parser.add_argument('--images-path',
type=str,
default='/images',
help='Path to store images')
args = parser.parse_args()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Initialize processor
processor = BatchProcessor(output_path=args.output_path,
start_idx=args.start_idx,
end_idx=args.end_idx,
batch_size=args.batch_size,
checkpoint_size=args.checkpoint_size,
model_name=args.model_name,
images_path=args.images_path)
# Run processing
await processor.run()
if __name__ == '__main__':
asyncio.run(main())
scripts/serve_vectordb.py
"""
This script is responsible for serving the vector database.
"""
import argparse
import base64
import logging
import os
from pathlib import Path
from typing import List, Optional
import chromadb
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import FileResponse
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import numpy as np
import open_clip
from pydantic import BaseModel
import torch
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(title='Vector Database Search API')
# Global variables for model and database
model = None
tokenizer = None
collection = None
device = None
images_dir = None
class SearchQuery(BaseModel):
text: str
n_results: Optional[int] = 5
class SearchResult(BaseModel):
image_path: str
similarity: float
def encode_text(text: str, model_name: str = 'ViT-bigG-14') -> np.ndarray:
"""Encode text using CLIP model."""
global model, tokenizer, device
# Tokenize and encode
text_tokens = tokenizer([text]).to(device)
with torch.no_grad():
text_features = model.encode_text(text_tokens)
# Normalize the features
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy()
def query_collection(query_embedding: np.ndarray,
n_results: int = 5) -> List[SearchResult]:
"""Query the collection and return top matches with scores."""
global collection
results = collection.query(query_embeddings=query_embedding.tolist(),
n_results=n_results,
include=['metadatas', 'distances', 'documents'])
# Get image paths and distances
image_paths = results['documents'][0]
distances = results['distances'][0]
# Convert distances to similarities (cosine similarity = 1 - distance/2)
similarities = [1 - (d / 2) for d in distances]
return [
SearchResult(image_path=img_path, similarity=similarity)
for img_path, similarity in zip(image_paths, similarities)
]
@app.post('/search', response_model=List[SearchResult])
async def search(query: SearchQuery):
"""Search endpoint that takes a text query and returns similar images."""
try:
# Encode the query text
query_embedding = encode_text(query.text)
# Query the collection
results = query_collection(query_embedding, query.n_results)
return results
except Exception as e:
logger.error(f'Error processing query: {str(e)}')
raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/{subpath:path}')
async def get_image(subpath: str):
"""Serve an image from the mounted bucket."""
image_path = os.path.join(images_dir, subpath)
if not os.path.exists(image_path):
raise HTTPException(status_code=404, detail='Image not found')
return FileResponse(image_path, media_type='image/jpeg')
@app.get('/health')
async def health_check():
"""Health check endpoint."""
return {
'status': 'healthy',
'collection_size': collection.count() if collection else 0
}
@app.get('/', response_class=HTMLResponse)
async def get_search_page():
"""Serve a simple search interface."""
return """
<html>
<head>
<title>Image Search</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
line-height: 1.6;
background-color: #f5f5f5;
color: #333;
min-height: 100vh;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
}
.search-container {
background: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
margin-bottom: 2rem;
text-align: center;
}
h1 {
color: #2c3e50;
margin-bottom: 1.5rem;
font-size: 2.5rem;
}
.search-box {
display: flex;
gap: 10px;
max-width: 600px;
margin: 0 auto;
}
input {
flex: 1;
padding: 12px 20px;
border: 2px solid #e0e0e0;
border-radius: 25px;
font-size: 16px;
transition: all 0.3s ease;
}
input:focus {
outline: none;
border-color: #3498db;
box-shadow: 0 0 5px rgba(52, 152, 219, 0.3);
}
button {
padding: 12px 30px;
background: #3498db;
color: white;
border: none;
border-radius: 25px;
cursor: pointer;
font-size: 16px;
transition: background 0.3s ease;
}
button:hover {
background: #2980b9;
}
.results {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
gap: 1.5rem;
padding: 1rem;
}
.result {
background: white;
border-radius: 10px;
overflow: hidden;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
transition: transform 0.3s ease;
}
.result:hover {
transform: translateY(-5px);
}
.result img {
width: 100%;
height: 200px;
object-fit: cover;
}
.result-info {
padding: 1rem;
}
.similarity-score {
color: #2c3e50;
font-weight: 600;
}
#loading {
display: none;
text-align: center;
margin: 2rem 0;
font-size: 1.2rem;
color: #666;
}
</style>
</head>
<body>
<div class="container">
<div class="search-container">
<h1>SkyPilot Image Search</h1>
<div class="search-box">
<input type="text" id="searchInput" placeholder="Enter your search query..."
onkeypress="if(event.key === 'Enter') search()">
<button onclick="search()">Search</button>
</div>
</div>
<div id="loading">Searching...</div>
<div id="results" class="results"></div>
</div>
<script>
async function search() {
const searchInput = document.getElementById('searchInput');
const loading = document.getElementById('loading');
const resultsDiv = document.getElementById('results');
if (!searchInput.value.trim()) return;
loading.style.display = 'block';
resultsDiv.innerHTML = '';
try {
const response = await fetch('/search', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'application/json'
},
body: JSON.stringify({
text: searchInput.value.trim(),
n_results: 12
})
});
if (!response.ok) {
const errorData = await response.json();
throw new Error(errorData.detail || 'Search failed');
}
const results = await response.json();
resultsDiv.innerHTML = results.map(result => `
<div class="result">
<img src="/image/${result.image_path}"
alt="Search result">
<div class="result-info">
<p class="similarity-score">
Similarity: ${(result.similarity * 100).toFixed(1)}%
</p>
</div>
</div>
`).join('');
} catch (error) {
resultsDiv.innerHTML = `
<p style="color: #e74c3c; text-align: center; width: 100%;">
Error: ${error.message}
</p>
`;
} finally {
loading.style.display = 'none';
}
}
</script>
</body>
</html>
"""
def main():
parser = argparse.ArgumentParser(
description='Serve Vector Database with FastAPI')
parser.add_argument('--host',
type=str,
default='0.0.0.0',
help='Host to serve on')
parser.add_argument('--port',
type=int,
default=8000,
help='Port to serve on')
parser.add_argument('--collection-name',
type=str,
default='clip_embeddings',
help='ChromaDB collection name')
parser.add_argument('--persist-dir',
type=str,
default='/vectordb/chroma',
help='Directory where ChromaDB is persisted')
parser.add_argument('--images-dir',
type=str,
default='/images',
help='Directory where images are stored')
parser.add_argument('--model-name',
type=str,
default='ViT-bigG-14',
help='CLIP model name')
args = parser.parse_args()
# Initialize global variables
global model, tokenizer, collection, device, images_dir
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f'Using device: {device}')
# Set images directory
images_dir = args.images_dir
# Load the model
import open_clip
model, _, _ = open_clip.create_model_and_transforms(
args.model_name, pretrained='laion2b_s39b_b160k', device=device)
tokenizer = open_clip.get_tokenizer(args.model_name)
# Initialize ChromaDB client
client = chromadb.PersistentClient(path=args.persist_dir)
try:
# Get the collection
collection = client.get_collection(name=args.collection_name)
logger.info(f'Connected to collection: {args.collection_name}')
logger.info(f'Total documents in collection: {collection.count()}')
except ValueError as e:
logger.error(f'Error: {str(e)}')
logger.error(
'Make sure the collection exists and the persist_dir is correct.')
raise
# Start the server
import uvicorn
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == '__main__':
main()
serve_vectordb.yaml
name: vectordb-serve
workdir: .
resources:
accelerators:
# ordered by pricing (cheapest to most expensive)
# skypilot will try to use the cheapest available accelerator
# serve requires a GPU to compute the embeddings
T4: 1
L4: 1
A10G: 1
A10: 1
V100: 1
memory: 32+
ports: 8000
use_spot: true
file_mounts:
/vectordb:
name: sky-vectordb
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
/images:
name: sky-demo-image
# this needs to be the same as the source in the build_vectordb.yaml
mode: MOUNT
setup: |
pip install numpy==1.26.4
pip install torch==2.5.1 torchvision==0.20.1 ftfy regex tqdm
pip install open_clip_torch chromadb pandas
pip install fastapi uvicorn pydantic
run: |
python scripts/serve_vectordb.py \
--collection-name clip_embeddings \
--persist-dir /vectordb/chroma \
--images-dir /images \
--host 0.0.0.0 \
--port 8000
service:
replicas: 1
readiness_probe:
path: /health