This commit is contained in:
2024-11-04 05:20:42 +08:00
parent e990473dcd
commit 07de4d5fd5
24 changed files with 1385 additions and 2 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
__pycache__

119
README.md
View File

@@ -1,3 +1,118 @@
# reverse_image_search_gpu # reverse_image_search
图片反向搜索, 利用GPU加速 图片反向检索
```urls
https://www.gameui.net/api/default/create_vector?count=1024 # 预先生成指定数量的向量
https://www.gameui.net/api/default/create_index # 手动重建索引
```
## Intall
```bash
# 先安装 milvus v2.1.4 矢量数据库
wget https://github.com/milvus-io/milvus/releases/download/v2.1.4/milvus_2.1.4-1_amd64.deb
sudo apt-get update
sudo dpkg -i milvus_2.1.4-1_amd64.deb
sudo apt-get -f install
# 查看 Milvus 及其依賴的狀態(etcd 和 MinIO)
sudo systemctl status milvus
sudo systemctl status milvus-etcd
sudo systemctl status milvus-minio
# 安装 etcdctl (管理工具)
wget https://github.com/etcd-io/etcd/releases/download/v3.4.22/etcd-v3.4.22-linux-amd64.tar.gz
# 解压 etcdctl 到 /bin/
# 向 ~/.bashrc 添加 export ETCDCTL_API=3
# 安装 python 包管理工具 pip
sudo apt install python3-pip
# 使用 venv 创建虚拟环境(注意安装路径是当前用户的 .local/bin)
python3 -m venv venv
source venv/bin/activate
python -m pip install --upgrade pip
# 为 .bashrc 设置快捷命令(使用echo或cat追加写)
vim ~/.bashrc
alias venv='source venv/bin/activate'
source ~/.bashrc
# 不使用虚拟环境的防止依赖冲突方法
pip freeze > requirements.txt
pip uninstall -r requirements.txt
# 安装依赖
pip3 install -r requirements.txt
# 注意 python 要求 V3.9.7 版本以上
python3 src/main.py
```
挂载对象存储服务作为存储盘
## 批量处理工具
- [x] 使项目不依赖平台提供的服务, 易于迁移
- [x] 操作读写数据库, 不使用mysql
- [x] /img/no_content.8dd8acce.png 冲突
- [x] 在上海区新建OSS实例作为磁盘挂载到服务器(用于存储缓存图) ...登录
- [x] 在服务器提供接口(设置CDN层分发) ...登录
- [x] 获取前端项目(修改列表图片三类分辨率)
- [x] 通过 path 获取缩略图
- [x] 通过 articleDetails_id 获取缩略图
- [x] 图片向量无缓存时重新生成
- [ ] 提前生成原图webp以提高首次打开速度
## RTT
1. OSS 中 gameui-webp Bucket 作为图片缓存层
2. 通过 /img/xxx.webp 接口提供图片, 由 OSS 作镜像回源
NGINX config
```config
server {
location /docs {
proxy_pass http://gameui_ai_server;
proxy_redirect off;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
location /api/ {
proxy_pass http://gameui_ai_server;
proxy_redirect off;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
location ~* /img/([0-9]+)\.(webp|jpeg) {
proxy_pass http://gameui_ai_server;
proxy_redirect off;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
location ~* /img/([0-9]+).*\.(webp|jpeg) {
proxy_pass http://gameui_ai_server;
proxy_redirect off;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
location ~* /img/article-([0-9]+).*\.(webp|jpeg) {
proxy_pass http://gameui_ai_server;
proxy_redirect off;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
}
```

61
configs/config.py Normal file
View File

@@ -0,0 +1,61 @@
import os
import json
import time
from apscheduler.schedulers.background import BackgroundScheduler
# 以 .env 文件中的环境变量为准, 检查是否存在 .env 文件, 没有则创建
if not os.path.exists('.env'):
print('请输入环境变量参数, 将会写入 .env 文件中')
OSS_HOST = input('OSS_HOST: ')
MYSQL_HOST = input('MYSQL_HOST: ')
MILVUS_HOST = input('MILVUS_HOST: ')
with open('.env', 'w') as f:
f.write(f'OSS_HOST={OSS_HOST}\n')
f.write(f'MYSQL_HOST={MYSQL_HOST}\n')
f.write(f'MILVUS_HOST={MILVUS_HOST}\n')
# 读取 .env 文件中的环境变量
with open('.env', 'r') as f:
env = f.readlines()
env = list(filter(lambda x: not x.startswith('#') and not x.startswith('\n') and len(x.split('=')) == 2, env))
env = list(map(lambda x: x.replace('\n', '').split('='), env))
env = {k: v for k, v in env}
print(json.dumps(env, indent=4, ensure_ascii=False))
OSS_HOST = env.get('OSS_HOST')
MYSQL_HOST = env.get('MYSQL_HOST')
MILVUS_HOST = env.get('MILVUS_HOST')
# 创建上传图片的临时目录
UPLOAD_PATH = '/tmp/search-images'
if not os.path.exists(UPLOAD_PATH):
os.makedirs(UPLOAD_PATH)
# 创建压缩图片的临时目录
IMAGES_PATH = '/tmp/images'
if not os.path.exists(IMAGES_PATH):
os.makedirs(IMAGES_PATH)
# 创建sqlite3数据库目录
SQLITE3_PATH = '/tmp/sqlite3'
if not os.path.exists(SQLITE3_PATH):
os.makedirs(SQLITE3_PATH)
def clear_images():
print('开始清理创建时间大于30分钟的图片缓存')
for file in os.listdir(IMAGES_PATH):
if os.path.getmtime(os.path.join(IMAGES_PATH, file)) < time.time() - 30 * 60:
print('清理图片缓存:', file)
os.remove(os.path.join(IMAGES_PATH, file))
# 构建一个定时任务每30分钟清理一次图片缓存
scheduler = BackgroundScheduler()
scheduler.add_job(clear_images, 'interval', minutes=30)
scheduler.start()

BIN
demo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 675 KiB

53
demo.py Executable file
View File

@@ -0,0 +1,53 @@
import timm
import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms
# 加载预训练模型
model = timm.create_model('resnet50', pretrained=True)
model = model.eval()
# 定义图片处理流程
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 读取图片
print("loading image")
img = Image.open("demo.png").convert("RGB")
tensor = preprocess(img).unsqueeze(0)
# 检查是否有可用的GPU
if torch.cuda.is_available():
input_batch = tensor.to('cuda')
model.to('cuda')
print("start run model")
with torch.no_grad():
output = model(input_batch)
for x in output:
print(x.shape)
# 输出2048维向量
# print(output[0])
'''
from towhee import pipe, ops, DataCollection
p = (
pipe.input('path')
.map('path', 'img', ops.image_decode())
.map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
.output('img', 'vec')
)
ea = DataCollection(p('demo.png')).to_list()
print(ea)
'''

41
download.py Normal file
View File

@@ -0,0 +1,41 @@
import io
import os
import requests
from PIL import Image, ImageFile
from models.oss import bucket_image2, bucket_webp
from configs.config import IMAGES_PATH
ImageFile.LOAD_TRUNCATED_IMAGES = True
# 下载图片(使用OSS下载)
def download_image(url:str) -> Image:
if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
try:
url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
obj = bucket_image2.get_object(url).read()
return Image.open(io.BytesIO(obj))
except Exception:
return None
else:
try:
response = requests.get(url)
return Image.open(io.BytesIO(response.content))
except Exception:
print('图片下载失败:', url)
return None
# 生成缩略图, 写入OSS
def generate_thumbnail(image:Image, id:int, version:str, n:int, w:int, ext:str):
path = f"{id}-{version}@{n}x{w}.{ext}"
if bucket_webp.object_exists(path):
print('缩略图已经存在:', path)
return
# 将 image 对象复制一份, 防止影响原图
image = image.copy()
image.thumbnail((n*w, image.size[1]))
image.save(f"{IMAGES_PATH}/{path}", ext, save_all=True)
bucket_webp.put_object_from_file(path, f"{IMAGES_PATH}/{path}")
os.remove(f"{IMAGES_PATH}/{path}")
print('缩略图写入 OSS:', path)

19
log.yaml Normal file
View File

@@ -0,0 +1,19 @@
version: 1
formatters:
simple:
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
handlers:
console:
class: logging.StreamHandler
level: DEBUG
formatter: simple
stream: ext://sys.stdout
loggers:
simpleExample:
level: DEBUG
handlers: [console]
propagate: no
root:
level: DEBUG
handlers: [console]

28
main.py Executable file
View File

@@ -0,0 +1,28 @@
# -*- coding:utf-8 -*-
import sys
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from functools import lru_cache
from routers import reverse, user, task, img, user_collect
# 初始化 FastAPI
app = FastAPI(title="GameUI", description="GameUI", version="1.5.0", openapi_url="/docs/openapi.json", docs_url="/docs", redoc_url="/redoc")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
# 导入路由
app.include_router(user.router, prefix='/api/user', tags=['用户'])
app.include_router(task.router, prefix='/api/task', tags=['任务'])
app.include_router(reverse.router, prefix='/api/default', tags=['搜图'])
app.include_router(user_collect.router, prefix='/api/user_collect', tags=['收藏'])
app.include_router(img.router, prefix='/imgs', tags=['图片'])
# 启动服务
if __name__ == '__main__':
port = 5002 if len(sys.argv) < 2 else int(sys.argv[1])
uvicorn.run(app='main:app', host='0.0.0.0', port=port, reload=True, workers=1)

26
models/milvus.py Normal file
View File

@@ -0,0 +1,26 @@
import pymilvus
from configs.config import MILVUS_HOST
# 连接 Milvus (开启 Milvus 服务)
collection_name = 'default'
# 获取 Milvus 连接
def get_collection(collection_name):
pymilvus.connections.connect(host=MILVUS_HOST, port='19530')
if not pymilvus.utility.has_collection(collection_name):
field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True)
field2 = pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, dim=2048)
field3 = pymilvus.FieldSchema(name="article_id", dtype=pymilvus.DataType.INT64)
schema = pymilvus.CollectionSchema(fields=[field1, field2, field3])
return pymilvus.Collection(name=collection_name, schema=schema)
return pymilvus.Collection(name=collection_name)
# 检查索引是否存在, 不存在则创建, 并加载
#collection = get_collection(collection_name)
#if not collection.has_index():
# default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
# collection.create_index(field_name="embedding", index_params=default_index)
#collection.load()

16
models/mysql.py Normal file
View File

@@ -0,0 +1,16 @@
import pymysql
from configs.config import MYSQL_HOST
# 连接 MySQL (开启 MySQL 服务)
conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
# 获取 MySQL 连接
def get_cursor():
try:
dx = conn.ping()
return conn.cursor()
except Exception:
conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
return conn.cursor()

9
models/oss.py Normal file
View File

@@ -0,0 +1,9 @@
import oss2
from configs.config import OSS_HOST
# 连接 OSS
oss2.defaults.connection_pool_size = 100
auth = oss2.Auth('LTAI4GH3qP6VA3QpmTYCgXEW', 'r2wz4bJty8iYfGIcFmEqlY1yon2Ruy')
bucket_image2 = oss2.Bucket(auth, f'http://{OSS_HOST}', 'gameui-image2')
bucket_webp = oss2.Bucket(auth, f'http://{OSS_HOST}', 'gameui-webp')

11
models/resnet.py Normal file
View File

@@ -0,0 +1,11 @@
import towhee
class Resnet50:
def resnet50_extract_feat(self, img_path):
feat = towhee.glob(img_path).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list()
print(feat[0])
return feat[0]
if __name__ == '__main__':
print('This script is running as the main program.')
#resnet =

184
models/task.py Normal file
View File

@@ -0,0 +1,184 @@
import os
import sqlite3
import datetime
from pydantic import BaseModel
from configs.config import SQLITE3_PATH
import uuid
import time
import requests
__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
__taskdb.row_factory = sqlite3.Row
# 创建任务表
__taskdb.execute("""
CREATE TABLE IF NOT EXISTS tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
user_id INTEGER NOT NULL,
description TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
)
""")
# 任务表单模型
class TaskForm(BaseModel):
name: str
description: str
class Task(BaseModel):
id: str
name: str
user_id: int
description: str
created_at: str
updated_at: str
status: str = "pending"
progress: int = 0
websockets: list = []
class TaskManager:
__tasks = {}
# 使用pydantic模型的写法
def add_task(self, name: str, user_id: int, description: str):
id = str(uuid.uuid4())
date = str(datetime.datetime.now())
task = Task(id=id, name=name, user_id=user_id, description=description, created_at=date, updated_at=date)
self.__tasks[id] = task
return task
def get_task(self, task_id: str):
task = self.__tasks.get(task_id)
if task:
return task.dict()
else:
return None
# 直接返回对象的写法
def get_tasks(self, user_id: int=None):
if user_id:
return [task for task in self.__tasks.values() if task.user_id == user_id]
else:
return [task for task in self.__tasks.values()]
def remove_task(self, task_id: str):
if self.__tasks.get(task_id):
self.__tasks.pop(task_id)
return True
else:
return False
# 正确的写法(不等待async函数直接返回结果)
async def add_websocket(self, task_id: str, websocket):
print(self.__tasks)
# 每次调用连接数都是1, 这是错误的
self.__tasks[task_id].websockets.append(websocket)
print(f"当前连接数: {len(self.__tasks[task_id].websockets)}")
#if len(task.websockets) > 0 and task.progress == 0 and task.status == "pending":
# print("任务未被监听, 开启监听")
# await self.start_listen_task(task)
# 移除 websocket
def remove_websocket(self, task_id: str, websocket):
task = self.__tasks.get(task_id)
if task:
print("连接已断开, 移除 websocket")
task.websockets.remove(websocket)
print(f"当前剩余连接数: {len(task.websockets)}")
# 正确的写法
async def start_listen_task(self, task: Task):
while True:
print(f"start listen task: {len(task.websockets)}")
time.sleep(2.5)
if len(task.websockets) == 0:
break
print(f"get task status:{task.id}")
break
#data = requests.get(f"http://localhost:5001/api/task/{task.id}").json()
#print(data)
## 通过 http 请求获取任务的状态
## 如果任务状态发生变化, 则将任务状态发送给所有监听该任务的 websocket
#if task.status != data.status:
# task.status = data.status
# print("send task status to websocket")
# for websocket in task.websockets:
# await websocket.send_json(task)
## 任务实体模型
#class TaskModel(BaseModel):
# id: str
# name: str
# user_id: int
# description: str
# created_at: str
# updated_at: str
# websockets: list=[]
#
## 向数据库中添加任务
#def add_task(name, user_id, description):
# print(name, user_id, description)
# date = datetime.datetime.now()
# cursor = __taskdb.cursor()
# cursor.execute("INSERT INTO tasks(name,user_id,description,created_at,updated_at) VALUES (?, ?, ?, ?, ?)", (name,user_id,description,date,date))
# id = cursor.lastrowid
# cursor.close()
# __taskdb.commit()
# return {"id": id,"name": name,"user_id": user_id,"description": description,"created_at": date,"updated_at": date}
#
#
## 从数据库中获取任务列表
#def get_tasks(user_id: int=None):
# cursor = __taskdb.cursor()
# if user_id:
# cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,))
# else:
# cursor.execute("SELECT * FROM tasks")
# tasks = cursor.fetchall()
# cursor.close()
# return [dict(task) for task in tasks]
#
#
## 从数据库中获取指定任务
#def get_task(task_id: int):
# cursor = __taskdb.cursor()
# cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
# task = cursor.fetchone()
# cursor.close()
# return task
#
#
## 更新数据库中的任务
#def update_task(task: TaskModel):
# cursor = __taskdb.cursor()
# cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", (
# task.name,
# task.description,
# datetime.datetime.now(),
# task.id
# ))
# cursor.close()
# __taskdb.commit()
# return task
#
#
## 从数据库中删除任务
#def delete_task(task_id: int):
# cursor = __taskdb.cursor()
# cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
# cursor.close()
# __taskdb.commit()
# return task_id

63
requirements.txt Normal file
View File

@@ -0,0 +1,63 @@
aliyun-python-sdk-core==2.13.36
aliyun-python-sdk-kms==2.16.0
anyio==3.6.2
bleach==6.0.0
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==3.0.1
click==8.1.3
crcmod==1.7
cryptography==39.0.1
docutils==0.19
fastapi==0.92.0
grpcio==1.37.1
grpcio-tools==1.37.1
h11==0.14.0
idna==3.4
importlib-metadata==6.0.0
jaraco.classes==3.2.3
jeepney==0.8.0
jmespath==0.10.0
keyring==23.13.1
markdown-it-py==2.1.0
mdurl==0.1.2
mmh3==3.0.0
more-itertools==9.0.0
numpy==1.24.2
opencv-python==4.6.0.66
oss2==2.16.0
pandas==1.5.3
pgzip==0.3.4
Pillow==9.4.0
pkginfo==1.9.6
protobuf==3.20.3
pyarrow==11.0.0
pycparser==2.21
pycryptodome==3.17
pydantic==1.10.5
pygit2==1.10.1
Pygments==2.14.0
pymilvus==2.0.2
PyMySQL==1.0.2
python-dateutil==2.8.2
python-multipart==0.0.5
pytz==2022.7.1
readme-renderer==37.3
requests==2.28.2
requests-toolbelt==0.10.1
rfc3986==2.0.0
rich==13.3.1
SecretStorage==3.3.3
six==1.16.0
sniffio==1.3.0
starlette==0.25.0
tabulate==0.9.0
towhee==0.9.0
tqdm==4.64.1
twine==4.0.2
typing_extensions==4.5.0
ujson==5.1.0
urllib3==1.26.14
uvicorn==0.20.0
webencodings==0.5.1
zipp==3.13.0

39
resize.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/bin/sh
# 定时整理 etcd 空间防止 etcd 服务挂掉
# 将此脚本加入定时任务(每日1次, 一次一片)
# crontab -e
# 15 * * * * /bin/sh /home/milvus/resize.sh
# 查看 Milvus 及其依賴的狀態(etcd 和 MinIO)
# sudo systemctl status milvus
# sudo systemctl status milvus-etcd
# sudo systemctl status milvus-minio
echo "使用API3:"
export ETCDCTL_API=3
echo "查看空间占用:"
etcdctl endpoint status --write-out table
echo "查看告警状态:"
etcdctl alarm list
echo "获取当前版本:"
rev=$(etcdctl --endpoints=http://127.0.0.1:2379 endpoint status --write-out="json" | egrep -o '"revision":[0-9]*' | egrep -o '[0-9].*')
echo "压缩掉所有旧版本:"
etcdctl --endpoints=http://127.0.0.1:2379 compact $rev
echo "整理多余的空间:"
etcdctl --endpoints=http://127.0.0.1:2379 defrag
echo "取消告警信息:"
etcdctl --endpoints=http://127.0.0.1:2379 alarm disarm
echo "再次查看空间占用:"
etcdctl endpoint status --write-out table
echo "再次查看告警状态:"
etcdctl alarm list

214
routers/img.py Normal file
View File

@@ -0,0 +1,214 @@
import os
import time
import base64
import psutil
import statistics
import _thread as thread
from fastapi import APIRouter, HTTPException, Response
from urllib.parse import unquote
from configs.config import IMAGES_PATH
from models.mysql import get_cursor, conn
from utilities.download import download_image, generate_thumbnail
router = APIRouter()
# 预热图片(获取一次图片, 遍历图片表, 检查OSS中所有被预定的尺寸是否存在, 不存在则生成)
@router.get("/warm", summary="预热图片", description="预热图片")
def warm_image(op:int=0, end:int=10, version:str='0'):
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` limit {op}, {end}")
for img in cursor.fetchall():
# 如果CPU使用率大于50%, 则等待, 直到CPU使用率小于50%
while statistics.mean(psutil.cpu_percent(interval=1, percpu=True)) > 50:
print(statistics.mean(psutil.cpu_percent(interval=1, percpu=True)), '等待CPU释放...')
time.sleep(2)
# 如果内存剩余小于1G, 则等待, 直到内存剩余大于1G
while psutil.virtual_memory().available < 1024 * 1024 * 1024:
print(psutil.virtual_memory().available, '等待内存释放...')
time.sleep(2)
# CPU使用率已降低, 开始处理图片
image = download_image(img['content']) # 从OSS下载原图
if not image:
print('跳过不存在的图片:', img['content'])
continue
# 创建新线程处理图片
try:
print('开始处理图片:', img['content'])
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 3, 328, 'webp'))
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 2, 328, 'webp'))
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 1, 328, 'webp'))
except:
print('无法启动线程')
cursor.close()
return Response('预热成功', status_code=200, media_type='text/plain', headers={'Content-Type': 'text/plain; charset=utf-8'})
# 获取非标准类缩略图
@router.get("/{type}-{id}-{version}@{n}x{w}.{ext}", summary="获取非标准类缩略图", description="/img/article-233-version@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:str):
img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
if type == 'ad' or type == 'article' or type == 'article_attribute':
cursor = get_cursor()
count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', count)
return Response('图片不存在', status_code=404)
url = img['image']
elif type == 'url':
id = unquote(id, 'utf-8')
id = id.replace(' ','+')
url = unquote(base64.b64decode(id))
print(url)
elif type == 'avatar':
cursor = get_cursor()
count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
user = cursor.fetchone()
cursor.close()
if user is None:
print('用户不存在:', count)
return Response('用户不存在', status_code=404)
url = user['avatar']
else:
print('图片类型不存在:', type)
return Response('图片类型不存在', status_code=404)
image = download_image(url)
if not image:
return Response('图片不存在', status_code=404)
# 如果是 avatar, 则裁剪为正方形
if type == 'avatar':
px = image.size[0] if image.size[0] < image.size[1] else image.size[1]
image = image.crop((0, 0, px, px))
image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取非标准类原尺寸图
@router.get("/{type}-{id}-{version}.{ext}", summary="获取文章缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_type(type:str, id:str, version:str, ext:str):
img_path = f"{IMAGES_PATH}/{type}-{id}-{version}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
if type == 'ad' or type == 'article' or type == 'article_attribute':
cursor = get_cursor()
count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', count)
return Response('图片不存在', status_code=404)
url = img['image']
elif type == 'url':
id = unquote(id, 'utf-8')
id = id.replace(' ','+')
url = unquote(base64.b64decode(id))
print("url:", url)
elif type == 'avatar':
cursor = get_cursor()
count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
user = cursor.fetchone()
cursor.close()
if user is None:
print('用户不存在:', count)
return Response('用户不存在', status_code=404)
url = user['avatar']
else:
print('图片类型不存在:', type)
return Response('图片类型不存在', status_code=404)
image = download_image(url)
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 通过url获取图片
@router.get("/url-{url}@{n}x{w}.{ext}", summary="通过url获取图片", description="/img/article-233.webp")
def get_image_url(url:str, n:int, w:int, ext:str):
img_path = f"{IMAGES_PATH}/{type}-{url}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
url = unquote(url, 'utf-8').replace(' ','+')
url = unquote(base64.b64decode(url))
image = download_image(url)
if not image:
return Response('图片不存在', status_code=404)
image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取标准缩略图(带版本号)
@router.get("/{id}-{version}@{n}x{w}.{ext}", summary="获取缩略图(带版本号)", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_thumbnail(id:int, version:str, n:int, w:int, ext:str):
# 判断图片是否已经生成
img_path = f"{IMAGES_PATH}/{id}-{version}@{n}x{w}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 从数据库获取原图地址
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', id)
return Response('图片不存在', status_code=404)
image = download_image(img['content'])
if not image:
return Response('图片不存在', status_code=404)
image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取标准缩略图
@router.get("/{id}@{n}x{w}.{ext}", summary="获取缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_thumbnail(id:int, n:int, w:int, ext:str):
# 判断图片是否已经生成
img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 从数据库获取原图地址
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', id)
return Response('图片不存在', status_code=404)
image = download_image(img['content'])
if not image:
return Response('图片不存在', status_code=404)
image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取标准原尺寸图
@router.get("/{id}.{ext}", summary="获取标准原尺寸图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 无后缀获取原图")
def get_image(id: int = 824, ext: str = 'webp'):
# 判断图片是否已经生成
img_path = f"{IMAGES_PATH}/{id}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 从数据库获取原图地址
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', id)
return Response('图片不存在', status_code=404)
image = download_image(img['content'])
if not image:
return Response('图片不存在', status_code=404)
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")

231
routers/reverse.py Normal file
View File

@@ -0,0 +1,231 @@
import os
import random
import string
import sqlite3
import numpy as np
import time
import io
from PIL import Image
from towhee import pipe, ops
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
from models.milvus import get_collection, collection_name
from models.mysql import get_cursor, conn
#from models.resnet import Resnet50
from configs.config import UPLOAD_PATH
from utilities.download import download_image
router = APIRouter()
#MODEL = Resnet50()
RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec'))
# 获取状态统计
@router.get('', summary='状态统计', description='通过表名获取状态统计')
def count_images():
collection = get_collection(collection_name)
return {'status': True, 'count': collection.num_entities}
# 手动重建索引
@router.get('/create_index', summary='重建索引', description='手动重建索引', include_in_schema=False)
async def create_index():
collection = get_collection(collection_name)
default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
collection.create_index(field_name="embedding", index_params=default_index)
collection.load()
return {'status': True, 'count': collection.num_entities}
'''
# 批量生成向量
@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False)
async def create_vector(count: int = 10):
cursor = get_cursor()
cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}")
images = cursor.fetchall()
cursor.close()
for item in images:
print(item)
# 先查询 milvus 中是否存在
collection = get_collection(collection_name)
data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit
if len(data) > 0:
cursor = get_cursor()
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
conn.commit()
cursor.close()
continue
try:
img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image']))
download_image(item['thumbnail_image']).save(img_path, 'png', save_all=True)
feat = MODEL.resnet50_extract_feat(img_path)
collection.insert([[item['id']], [feat], [item['article_id']]])
cursor = get_cursor()
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
conn.commit()
cursor.close()
except Exception as e:
print(e)
print('END')
return images
'''
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
async def rewrite_image(image_id: int):
print('START', image_id, '重建向量')
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('mysql中原始图片不存在:', image_id)
return Response('图片不存在', status_code=404)
img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
print('img_path', img_path)
image = download_image(img['content'])
if image is None:
print('图片下载失败:', img['content'])
return Response('图片下载失败', status_code=404)
image.save(img_path, 'png', save_all=True)
with Image.open(img_path) as imgx:
feat = RESNET50(imgx).get()[0]
collection = get_collection(collection_name)
collection.delete(expr=f'id in [{image_id}]')
rest = collection.insert([[image_id], [feat], [img['article_id']]])
os.remove(img_path)
print('END', image_id, '重建向量', rest.primary_keys)
return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()}
# 获取相似(废弃)
@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片')
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
collection = get_collection(collection_name)
result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1)
# 如果没有结果, 则重新生成记录
if len(result) == 0:
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('mysql 中图片不存在:', image_id)
return Response('图片不存在', status_code=404)
img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
image = download_image(img['content'])
if image is None:
print('图片下载失败:', img['content'])
return Response('图片下载失败', status_code=404)
image.save(img_path, 'png', save_all=True)
with Image.open(img_path) as imgx:
feat = RESNET50(imgx).get()[0]
# 移除可能存在的旧记录, 换上新的
collection.delete(expr=f'id in [{image_id}]')
collection.insert([[image_id], [feat], [img['article_id']]])
os.remove(img_path)
print('生成')
else:
print('通过')
feat = result[0]['embedding']
res = collection.search([feat],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0]
# 翻页(截取有效范围, page * pageize)
ope = page*pageSize-pageSize
end = page*pageSize
next = False
# 为数据附加信息
ids = [i.id for i in res] # 获取所有ID
if len(res) <= end:
ids = ids[ope:]
next = False
else:
ids = ids[ope:end]
next = True
str_ids = str(ids).replace('[', '').replace(']', '')
if str_ids == '':
print('没有更多数据了')
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
cursor = get_cursor()
cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})")
imgs = cursor.fetchall()
if len(imgs) == 0:
return imgs
# 获取用户ID和文章ID
uids = list(set([x['user_id'] for x in imgs]))
tids = list(set([x['article_id'] for x in imgs]))
# 获取用户信息
cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
users = cursor.fetchall()
# 获取文章信息
cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
articles = cursor.fetchall()
cursor.close()
# 合并信息
user, article = {}, {}
for x in users: user[x['id']] = x
for x in articles: article[x['id']] = x
for x in imgs:
x['article'] = article[x['article_id']]
x['user'] = user[x['user_id']]
x['distance'] = [i.distance for i in res if i.id == x['id']][0]
if x['praise_count'] == None: x['praise_count'] = 0
if x['collect_count'] == None: x['collect_count'] = 0
# 将字段名转换为驼峰
x['createTime'] = x.pop('create_time')
x['updateTime'] = x.pop('update_time')
# 对 imgs 重新排序(按照 distance 字段)
imgs = sorted(imgs, key=lambda x: x['distance'])
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': next, 'list': imgs}
@router.post(path='', summary='以图搜图', description='上传图片进行搜索')
async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: int = 20):
content = await image.read()
img = Image.open(image.file)
embeddig = RESNET50(img).get()[0]
collection = get_collection('default')
res = collection.search([embeddig],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=500)[0]
ope, end = (page - 1) * pageSize, page * pageSize
ids, nextx = [x.id for x in res][ope:end], len(res) > end
if not ids:
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
cursor = get_cursor()
cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
imgs = cursor.fetchall()
if not imgs:
return imgs
uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs))
cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
users = cursor.fetchall()
cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
articles = cursor.fetchall()
cursor.close()
user, article = {x['id']: x for x in users}, {x['id']: x for x in articles}
for x in imgs:
x.update({
'article': article[x['article_id']],
'user': user[x['user_id']],
#'distance': next(i.distance for i in res if i.id == x['id'], 0),
'distance': [i.distance for i in res if i.id == x['id']][0],
'praise_count': x.get('praise_count', 0),
'collect_count': x.get('collect_count', 0)
})
imgs.sort(key=lambda x: x['distance'])
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs}
@router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像")
async def delete_images(thread_id: str):
collection = get_collection(collection_name)
collection.delete(expr="article_id in ["+thread_id+"]")
collection.load()
default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
collection.create_index(field_name="embedding", index_params=default_index)
return {"status": True, 'msg': '删除完毕'}

94
routers/task.py Normal file
View File

@@ -0,0 +1,94 @@
from fastapi import APIRouter, HTTPException, Request, WebSocket
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from models.task import TaskForm, TaskManager
router = APIRouter()
task_manager = TaskManager()
# 创建新任务
@router.post("", summary="创建新任务")
def create_task(form: TaskForm):
task = task_manager.add_task(name=form.name, user_id=123456, description=form.description)
return task
# 获取任务列表
@router.get("", summary="获取任务列表", description="可使用user_id参数筛选指定用户的任务列表")
def get_task_list(user_id: int=None):
return task_manager.get_tasks(user_id)
# websocket demo
@router.get("/demo", response_class=HTMLResponse)
async def websocket_demo(request: Request):
task_list = task_manager.get_tasks()
templates = Jinja2Templates(directory="templates")
return templates.TemplateResponse("websocket.html", {"request": request, "task_list": task_list})
# 监听指定任务的变化事件, 通知前端(不使用pydantic模型, 正确的写法)
@router.websocket("/{task_id}", name="监听任务变化")
async def websocket_endpoint(task_id: str, websocket: WebSocket):
await websocket.accept()
await task_manager.add_websocket(task_id, websocket)
async for data in websocket.iter_text():
await websocket.send_text(f"Message text was: {data}")
task_manager.remove_websocket(task_id, websocket)
print("websocket 连接已自动关闭")
#await websocket.send_json({"message": "Hello WebSocket!"})
#task = task_manager.get_task(task_id)
#print(task)
#if not task:
# print("task 不存在, 结束连接")
# return await websocket.close() # 任务不存在, 结束连接
#await websocket.send_json(task) # 将任务的状态发送给客户端
#await task_manager.add_websocket(task_id, websocket)
# 正确的写法, 使用 async for, 并且处理意外断开的情况
#try:
# async for data in websocket.iter_text():
# if data == "close":
# print("客户端主动关闭连接")
# task_manager.remove_websocket(task_id, websocket)
# await websocket.close()
# break
# else:
# print(f"接收到客户端消息: {data}")
# await websocket.send_text(f"Message text was: {data}")
#except Exception as e:
# print(f"客户端意外断开连接: {e}")
# task_manager.remove_websocket(task_id, websocket)
# #await websocket.close()
# 监听客户端的状态, 如果客户端主动关闭连接或意外断开连接, 都从任务的websocket列表中移除
#while True:
# try:
# data = await websocket.receive_text()
# if data == "close":
# print("客户端主动关闭连接")
# task_manager.remove_websocket(task_id, websocket)
# await websocket.close()
# break
# else:
# print(f"接收到客户端消息: {data}")
# await websocket.send_text(f"Message text was: {data}")
# except Exception as e:
# print(f"客户端意外断开连接: {e}")
# task_manager.remove_websocket(task_id, websocket)
# #await websocket.close()
# break
# 获取任务详情
@router.get("/{task_id}", summary="获取任务详情")
def get_task(task_id: int):
return get_task(task_id)
# 删除任务
@router.delete("/{task_id}", summary="删除指定任务")
def delete_task(task_id: int):
return delete_task(task_id)

19
routers/user.py Normal file
View File

@@ -0,0 +1,19 @@
from fastapi import APIRouter, HTTPException
from models.mysql import conn, get_cursor
router = APIRouter()
@router.get('/{user_id}/collect', summary="用户收藏记录", description="获取指定用户收藏记录")
def get_user_collect(user_id:int):
# TODO: 需要验证权限
cursor = get_cursor()
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'")
data = cursor.fetchall()
if not data:
return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
data = [str(item['content']) for item in data]
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
data = cursor.fetchall()
data = [str(item['id']) for item in data]
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }

45
routers/user_collect.py Normal file
View File

@@ -0,0 +1,45 @@
from typing import Optional
from fastapi import APIRouter, HTTPException, Header
from models.mysql import conn, get_cursor
router = APIRouter()
# 获取当前用户收藏记录
@router.get('', summary='自己的收藏记录', description='获取自己的收藏记录, 用于判断是否收藏(headers中必须附带token)')
def get_self_collect(token: Optional[str] = Header()):
print('token: ', token)
cursor = get_cursor()
# 查询用户ID
cursor.execute(f"SELECT user_id FROM web_auth WHERE token={token} limit 1")
data = cursor.fetchone()
print('auth: ', data)
if not data:
raise HTTPException(status_code=401, detail="用户未登录")
user_id = data['user_id']
# 查询收藏记录
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'")
data = cursor.fetchall() # 获取所有记录列表
data = [str(item['content']) for item in data] # 转换为数组
# 查询图片ID(对特殊字符安全转义)
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
data = cursor.fetchall()
data = [str(item['id']) for item in data]
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }
# 获取指定用户收藏记录
@router.get('/{user_id}', summary='指定用的户收藏记录', description='获取指定用户收藏记录(仅测试用)')
def get_user_collect(user_id: int):
cursor = get_cursor()
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type=1")
data = cursor.fetchall() # 获取所有记录列表
data = [str(item['content']) for item in data] # 转换为数组
if not data:
return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
# 查询图片ID(对特殊字符安全转义)
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
data = cursor.fetchall()
data = [str(item['id']) for item in data]
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data}

25
start.sh Executable file
View File

@@ -0,0 +1,25 @@
#!/bin/sh
# pip freeze > requirements.txt
#pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
#python3 -m pip install -U pyOpenSSL cryptography
# 使用 uvicorn 启动两个工作进程, 并支持自动重载
#uvicorn main:app --host 0.0.0.0 --port 5001 --workers 2 --reload >/dev/null 2>&1 &
#uvicorn main:app --host 0.0.0.0 --port 5002 >/dev/null 2>&1 &
python main.py >/dev/null 2>&1 &
# 为容器设定自动重启
#docker update --restart=always 51bb94aa2726
#cd ~/reverse_image_search
#pm2 start python3 --name reverse-1 -- main.py 5001
#pm2 start python3 --name reverse-2 -- main.py 5002
#pm2 start python3 --name reverse-3 -- main.py 5003
# 某些系统下,需要手动安装指定版本 opencv-python (4.7引用zlib错误)
# pip install opencv-python==4.6.0.66 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 如果网络不通, 从自定义服务器下载模型参数
#scp root@172.21.216.33:~/root/.cache/torch/hub/checkpoints/resnet50_a1_0-14fe96d1.pth resnet50_a1_0-14fe96d1.pth

44
templates/websocket.html Normal file
View File

@@ -0,0 +1,44 @@
<!DOCTYPE html>
<html>
<head>
<title>WebSocket Demo</title>
</head>
<body>
<div>
<p>{{ task_list }}</p>
<button id="create_task">Create Task</button>
</div>
<h1>WebSocket Demo</h1>
<div id="message"></div>
<button id="create">Create WebSocket</button>
<button id="close">Close WebSocket</button>
<script>
function createWebSocket(id) {
var ws = new WebSocket("ws://localhost:5001/api/task/"+id);
ws.onmessage = function(event) {
var message = event.data;
document.getElementById("message").innerHTML += message + "<br>";
};
ws.onclose = function(event) {
document.getElementById("message").innerHTML += "WebSocket closed";
};
return ws;
}
document.getElementById("create_task").onclick = function() {
fetch("/api/task", {
method: "POST",
headers: {
"Content-Type": "application/json"
},
body: JSON.stringify({
"name": "Task 1",
"description": "Task 1 description"
})
}).then(res => res.json()).then(data => {
console.log(data)
var ws = createWebSocket(data.id);
});
};
</script>
</body>
</html>

4
update.sh Executable file
View File

@@ -0,0 +1,4 @@
#!/bin/sh
# 服务器跟随更新
ssh ai "cd ~/reverse_image_search_gpu; git pull;"

41
utilities/download.py Normal file
View File

@@ -0,0 +1,41 @@
import io
import os
import requests
from PIL import Image, ImageFile
from models.oss import bucket_image2, bucket_webp
from configs.config import IMAGES_PATH
ImageFile.LOAD_TRUNCATED_IMAGES = True
# 下载图片(使用OSS下载)
def download_image(url:str) -> Image:
if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
try:
url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
obj = bucket_image2.get_object(url).read()
return Image.open(io.BytesIO(obj))
except Exception:
return None
else:
try:
response = requests.get(url)
return Image.open(io.BytesIO(response.content))
except Exception:
print('图片下载失败:', url)
return None
# 生成缩略图, 写入OSS
def generate_thumbnail(image:Image, id:int, version:str, n:int, w:int, ext:str):
path = f"{id}-{version}@{n}x{w}.{ext}"
if bucket_webp.object_exists(path):
print('缩略图已经存在:', path)
return
# 将 image 对象复制一份, 防止影响原图
image = image.copy()
image.thumbnail((n*w, image.size[1]))
image.save(f"{IMAGES_PATH}/{path}", ext, save_all=True)
bucket_webp.put_object_from_file(path, f"{IMAGES_PATH}/{path}")
os.remove(f"{IMAGES_PATH}/{path}")
print('缩略图写入 OSS:', path)