diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ed8ebf5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/README.md b/README.md index 025d6b5..13b12d1 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,118 @@ -# reverse_image_search_gpu +# reverse_image_search -图片反向搜索, 利用GPU加速 \ No newline at end of file +图片反向检索 + +```urls +https://www.gameui.net/api/default/create_vector?count=1024 # 预先生成指定数量的向量 +https://www.gameui.net/api/default/create_index # 手动重建索引 +``` + +## Install + +```bash +# 先安装 milvus v2.1.4 矢量数据库 +wget https://github.com/milvus-io/milvus/releases/download/v2.1.4/milvus_2.1.4-1_amd64.deb +sudo apt-get update +sudo dpkg -i milvus_2.1.4-1_amd64.deb +sudo apt-get -f install + +# 查看 Milvus 及其依賴的狀態(etcd 和 MinIO) +sudo systemctl status milvus +sudo systemctl status milvus-etcd +sudo systemctl status milvus-minio + +# 安装 etcdctl (管理工具) +wget https://github.com/etcd-io/etcd/releases/download/v3.4.22/etcd-v3.4.22-linux-amd64.tar.gz +# 解压 etcdctl 到 /bin/ +# 向 ~/.bashrc 添加 export ETCDCTL_API=3 + +# 安装 python 包管理工具 pip +sudo apt install python3-pip + +# 使用 venv 创建虚拟环境(注意安装路径是当前用户的 .local/bin) +python3 -m venv venv +source venv/bin/activate +python -m pip install --upgrade pip + +# 为 .bashrc 设置快捷命令(使用echo或cat追加写) +vim ~/.bashrc +alias venv='source venv/bin/activate' +source ~/.bashrc + +# 不使用虚拟环境的防止依赖冲突方法 +pip freeze > requirements.txt +pip uninstall -r requirements.txt + +# 安装依赖 +pip3 install -r requirements.txt + +# 注意 python 要求 V3.9.7 版本以上 +python3 src/main.py + +``` + +挂载对象存储服务作为存储盘 + +## 批量处理工具 + +- [x] 使项目不依赖平台提供的服务, 易于迁移 +- [x] 操作读写数据库, 不使用mysql +- [x] /img/no_content.8dd8acce.png 冲突 +- [x] 在上海区新建OSS实例作为磁盘挂载到服务器(用于存储缓存图) ...登录 +- [x] 在服务器提供接口(设置CDN层分发) ...登录 +- [x] 获取前端项目(修改列表图片三类分辨率) +- [x] 通过 path 获取缩略图 +- [x] 通过 articleDetails_id 获取缩略图 +- [x] 图片向量无缓存时重新生成 +- [ ] 提前生成原图webp以提高首次打开速度 + +## RTT + +1. OSS 中 gameui-webp Bucket 作为图片缓存层 +2. 
通过 /img/xxx.webp 接口提供图片, 由 OSS 作镜像回源 + +NGINX config + +```config +server { + location /docs { + proxy_pass http://gameui_ai_server; + proxy_redirect off; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + } + + location /api/ { + proxy_pass http://gameui_ai_server; + proxy_redirect off; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + } + + location ~* /img/([0-9]+)\.(webp|jpeg) { + proxy_pass http://gameui_ai_server; + proxy_redirect off; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + } + + location ~* /img/([0-9]+).*\.(webp|jpeg) { + proxy_pass http://gameui_ai_server; + proxy_redirect off; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + } + + location ~* /img/article-([0-9]+).*\.(webp|jpeg) { + proxy_pass http://gameui_ai_server; + proxy_redirect off; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + } +} +``` diff --git a/configs/config.py b/configs/config.py new file mode 100644 index 0000000..28de64e --- /dev/null +++ b/configs/config.py @@ -0,0 +1,61 @@ +import os +import json +import time + +from apscheduler.schedulers.background import BackgroundScheduler + + +# 以 .env 文件中的环境变量为准, 检查是否存在 .env 文件, 没有则创建 +if not os.path.exists('.env'): + print('请输入环境变量参数, 将会写入 .env 文件中') + OSS_HOST = input('OSS_HOST: ') + MYSQL_HOST = input('MYSQL_HOST: ') + MILVUS_HOST = input('MILVUS_HOST: ') + with open('.env', 'w') as f: + f.write(f'OSS_HOST={OSS_HOST}\n') + f.write(f'MYSQL_HOST={MYSQL_HOST}\n') + f.write(f'MILVUS_HOST={MILVUS_HOST}\n') + + +# 读取 .env 文件中的环境变量 +with open('.env', 'r') as f: + env = f.readlines() + env = 
def clear_images(directory=None, max_age_seconds=30 * 60):
    """Delete cached image files whose modification time is older than a cutoff.

    This is the job registered with the background scheduler to evict stale
    entries from the image cache directory.

    Args:
        directory: Directory to sweep. When ``None`` (the default, and what
            the scheduled job uses) the module-level ``IMAGES_PATH`` is swept.
        max_age_seconds: A file is removed when its mtime is more than this
            many seconds in the past. Defaults to 30 minutes.

    Returns:
        None. Files are removed as a side effect and each removal is printed.
    """
    target_dir = IMAGES_PATH if directory is None else directory
    # Compute the cutoff once so every file is judged against the same instant.
    cutoff = time.time() - max_age_seconds
    print('开始清理创建时间大于30分钟的图片缓存')
    for name in os.listdir(target_dir):
        path = os.path.join(target_dir, name)
        try:
            # Only regular files are cache entries; os.remove() on a directory
            # would raise and abort the rest of the sweep.
            if not os.path.isfile(path):
                continue
            if os.path.getmtime(path) >= cutoff:
                continue
            print('清理图片缓存:', name)
            os.remove(path)
        except FileNotFoundError:
            # Other code paths create and delete files in this directory
            # concurrently; a file vanishing between listdir() and
            # getmtime()/remove() is expected and must not stop the sweep.
            continue
# URL prefixes that are served from the OSS bucket and can be read through the
# OSS SDK instead of going over public HTTP.
_OSS_URL_PREFIXES = ('http://image.gameuiux.cn/', 'https://image.gameuiux.cn/')


def download_image(url: str, timeout: float = 10.0) -> "Image.Image | None":
    """Fetch an image and return it as a PIL image, or ``None`` on failure.

    URLs under ``image.gameuiux.cn`` are read directly from ``bucket_image2``
    using the part of the URL after the host as the object key; any other URL
    is fetched over HTTP.

    Args:
        url: Absolute URL of the image.
        timeout: Seconds to wait for the HTTP request before giving up. Only
            applies to the non-OSS branch. Without a timeout a stalled remote
            host would block the calling request handler indefinitely.

    Returns:
        The opened PIL image, or ``None`` if the download or decoding failed
        (best-effort: callers treat ``None`` as "image not available").
    """
    oss_prefix = next((p for p in _OSS_URL_PREFIXES if url.startswith(p)), None)
    if oss_prefix is not None:
        try:
            payload = bucket_image2.get_object(url[len(oss_prefix):]).read()
            return Image.open(io.BytesIO(payload))
        except Exception as exc:
            # Deliberately broad: any SDK or decoding error means "no image".
            print('图片下载失败:', url, exc)
            return None
    try:
        response = requests.get(url, timeout=timeout)
        # Do not try to decode an HTML error page as an image.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    except Exception as exc:
        print('图片下载失败:', url, exc)
        return None


def generate_thumbnail(image: "Image.Image", id: int, version: str, n: int, w: int, ext: str):
    """Create a width-limited thumbnail of *image* and upload it to ``bucket_webp``.

    The object key is ``"{id}-{version}@{n}x{w}.{ext}"``. Nothing is done when
    that key already exists in the bucket.

    Args:
        image: Source image; it is copied, so the caller's object is untouched.
        id: Image id used in the object key.
        version: Version tag used in the object key.
        n: Pixel-density multiplier; the maximum width is ``n * w``.
        w: Base width in pixels.
        ext: File extension, also passed to PIL as the output format name.

    Returns:
        None.
    """
    path = f"{id}-{version}@{n}x{w}.{ext}"
    if bucket_webp.object_exists(path):
        print('缩略图已经存在:', path)
        return
    local_path = f"{IMAGES_PATH}/{path}"
    # Work on a copy: thumbnail() resizes in place and the same source image
    # is shared between several thumbnail jobs.
    thumb = image.copy()
    # Height bound equals the current height, so only the width constrains.
    thumb.thumbnail((n*w, thumb.size[1]))
    try:
        thumb.save(local_path, ext, save_all=True)
        bucket_webp.put_object_from_file(path, local_path)
    finally:
        # Always drop the temporary file, including when save/upload failed.
        try:
            os.remove(local_path)
        except FileNotFoundError:
            pass
    print('缩略图写入 OSS:', path)
%(levelname)s - %(message)s' +handlers: + console: + class: logging.StreamHandler + level: DEBUG + formatter: simple + stream: ext://sys.stdout +loggers: + simpleExample: + level: DEBUG + handlers: [console] + propagate: no +root: + level: DEBUG + handlers: [console] + diff --git a/main.py b/main.py new file mode 100755 index 0000000..800885b --- /dev/null +++ b/main.py @@ -0,0 +1,28 @@ +# -*- coding:utf-8 -*- + +import sys +import uvicorn + +from fastapi import FastAPI +from starlette.middleware.cors import CORSMiddleware +from functools import lru_cache +from routers import reverse, user, task, img, user_collect + + +# 初始化 FastAPI +app = FastAPI(title="GameUI", description="GameUI", version="1.5.0", openapi_url="/docs/openapi.json", docs_url="/docs", redoc_url="/redoc") +app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]) + + +# 导入路由 +app.include_router(user.router, prefix='/api/user', tags=['用户']) +app.include_router(task.router, prefix='/api/task', tags=['任务']) +app.include_router(reverse.router, prefix='/api/default', tags=['搜图']) +app.include_router(user_collect.router, prefix='/api/user_collect', tags=['收藏']) +app.include_router(img.router, prefix='/imgs', tags=['图片']) + + +# 启动服务 +if __name__ == '__main__': + port = 5002 if len(sys.argv) < 2 else int(sys.argv[1]) + uvicorn.run(app='main:app', host='0.0.0.0', port=port, reload=True, workers=1) diff --git a/models/milvus.py b/models/milvus.py new file mode 100644 index 0000000..7234664 --- /dev/null +++ b/models/milvus.py @@ -0,0 +1,26 @@ +import pymilvus + +from configs.config import MILVUS_HOST + +# 连接 Milvus (开启 Milvus 服务) +collection_name = 'default' + + +# 获取 Milvus 连接 +def get_collection(collection_name): + pymilvus.connections.connect(host=MILVUS_HOST, port='19530') + if not pymilvus.utility.has_collection(collection_name): + field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True) + field2 = 
# Get a MySQL cursor
def get_cursor():
    """Return a cursor on the shared module-level MySQL connection.

    The connection is health-checked first. ``ping(reconnect=True)`` lets
    PyMySQL transparently re-establish a dropped connection; if that fails, a
    brand-new connection replaces the module-level ``conn``.

    Returns:
        A cursor created from the (possibly re-established) connection. The
        connection is configured with ``DictCursor``.
    """
    # ``conn`` is rebound below, so it must be declared global. Without this
    # declaration the name is local to the function, ``conn.ping()`` raises
    # UnboundLocalError on every call, and each call silently opens a new
    # connection that is never closed.
    global conn
    try:
        conn.ping(reconnect=True)
    except Exception:
        # NOTE(review): credentials are hard-coded in source; they should be
        # moved to configuration/secrets and rotated.
        conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
    return conn.cursor()
class TaskManager:
    """In-memory registry of tasks keyed by their UUID string.

    Tasks live only in process memory; nothing here touches the sqlite
    database opened at module level.
    """

    # Class-level dict: the registry is shared by every TaskManager instance.
    __tasks = {}

    def add_task(self, name: str, user_id: int, description: str):
        """Create a task with a fresh UUID, register it and return it."""
        task_id = str(uuid.uuid4())
        date = str(datetime.datetime.now())
        task = Task(id=task_id, name=name, user_id=user_id, description=description, created_at=date, updated_at=date)
        self.__tasks[task_id] = task
        return task

    def get_task(self, task_id: str):
        """Return the task as a plain dict, or ``None`` if the id is unknown."""
        task = self.__tasks.get(task_id)
        if task is None:
            return None
        return task.dict()

    def get_tasks(self, user_id: int = None):
        """Return task objects, optionally restricted to one user.

        Args:
            user_id: When not ``None``, only that user's tasks are returned.
                The check is ``is not None`` so that user id ``0`` filters
                instead of returning every task.
        """
        tasks = list(self.__tasks.values())
        if user_id is None:
            return tasks
        return [task for task in tasks if task.user_id == user_id]

    def remove_task(self, task_id: str):
        """Remove a task; return ``True`` if it existed, ``False`` otherwise."""
        return self.__tasks.pop(task_id, None) is not None

    async def add_websocket(self, task_id: str, websocket):
        """Attach a websocket to a task's listener list.

        Raises:
            KeyError: If ``task_id`` is not registered.
        """
        print(self.__tasks)
        task = self.__tasks[task_id]
        task.websockets.append(websocket)
        print(f"当前连接数: {len(task.websockets)}")
        # TODO: start listening once the first websocket attaches to a
        # pending task (the original left this disabled).

    # Detach a websocket
    def remove_websocket(self, task_id: str, websocket):
        """Detach a websocket from a task; unknown tasks/sockets are ignored."""
        task = self.__tasks.get(task_id)
        if task is None:
            return
        print("连接已断开, 移除 websocket")
        try:
            task.websockets.remove(websocket)
        except ValueError:
            # Already detached (e.g. disconnect reported twice); nothing to do.
            pass
        print(f"当前剩余连接数: {len(task.websockets)}")

    async def start_listen_task(self, task: "Task", interval: float = 2.5):
        """Poll a task while it still has listeners (currently a single pass).

        Args:
            task: The task being watched.
            interval: Seconds to wait before checking for listeners.
        """
        # Local import so this block does not depend on the file's import list.
        import asyncio

        while True:
            print(f"start listen task: {len(task.websockets)}")
            # asyncio.sleep yields to the event loop; time.sleep here would
            # freeze every other request and websocket for the whole wait.
            await asyncio.sleep(interval)
            if len(task.websockets) == 0:
                break
            print(f"get task status:{task.id}")
            # TODO: fetch the task status over HTTP and push changes to every
            # attached websocket; until then stop after one pass.
            break
int=None): +# cursor = __taskdb.cursor() +# if user_id: +# cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,)) +# else: +# cursor.execute("SELECT * FROM tasks") +# tasks = cursor.fetchall() +# cursor.close() +# return [dict(task) for task in tasks] +# +# +## 从数据库中获取指定任务 +#def get_task(task_id: int): +# cursor = __taskdb.cursor() +# cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)) +# task = cursor.fetchone() +# cursor.close() +# return task +# +# +## 更新数据库中的任务 +#def update_task(task: TaskModel): +# cursor = __taskdb.cursor() +# cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", ( +# task.name, +# task.description, +# datetime.datetime.now(), +# task.id +# )) +# cursor.close() +# __taskdb.commit() +# return task +# +# +## 从数据库中删除任务 +#def delete_task(task_id: int): +# cursor = __taskdb.cursor() +# cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,)) +# cursor.close() +# __taskdb.commit() +# return task_id + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..34d2b78 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,63 @@ +aliyun-python-sdk-core==2.13.36 +aliyun-python-sdk-kms==2.16.0 +anyio==3.6.2 +bleach==6.0.0 +certifi==2022.12.7 +cffi==1.15.1 +charset-normalizer==3.0.1 +click==8.1.3 +crcmod==1.7 +cryptography==39.0.1 +docutils==0.19 +fastapi==0.92.0 +grpcio==1.37.1 +grpcio-tools==1.37.1 +h11==0.14.0 +idna==3.4 +importlib-metadata==6.0.0 +jaraco.classes==3.2.3 +jeepney==0.8.0 +jmespath==0.10.0 +keyring==23.13.1 +markdown-it-py==2.1.0 +mdurl==0.1.2 +mmh3==3.0.0 +more-itertools==9.0.0 +numpy==1.24.2 +opencv-python==4.6.0.66 +oss2==2.16.0 +pandas==1.5.3 +pgzip==0.3.4 +Pillow==9.4.0 +pkginfo==1.9.6 +protobuf==3.20.3 +pyarrow==11.0.0 +pycparser==2.21 +pycryptodome==3.17 +pydantic==1.10.5 +pygit2==1.10.1 +Pygments==2.14.0 +pymilvus==2.0.2 +PyMySQL==1.0.2 +python-dateutil==2.8.2 +python-multipart==0.0.5 +pytz==2022.7.1 +readme-renderer==37.3 +requests==2.28.2 
+requests-toolbelt==0.10.1 +rfc3986==2.0.0 +rich==13.3.1 +SecretStorage==3.3.3 +six==1.16.0 +sniffio==1.3.0 +starlette==0.25.0 +tabulate==0.9.0 +towhee==0.9.0 +tqdm==4.64.1 +twine==4.0.2 +typing_extensions==4.5.0 +ujson==5.1.0 +urllib3==1.26.14 +uvicorn==0.20.0 +webencodings==0.5.1 +zipp==3.13.0 diff --git a/resize.sh b/resize.sh new file mode 100755 index 0000000..13fcf3f --- /dev/null +++ b/resize.sh @@ -0,0 +1,39 @@ +#!/bin/sh + +# 定时整理 etcd 空间防止 etcd 服务挂掉 +# 将此脚本加入定时任务(每日1次, 一次一片) + +# crontab -e +# 15 * * * * /bin/sh /home/milvus/resize.sh + +# 查看 Milvus 及其依賴的狀態(etcd 和 MinIO) +# sudo systemctl status milvus +# sudo systemctl status milvus-etcd +# sudo systemctl status milvus-minio + +echo "使用API3:" +export ETCDCTL_API=3 + +echo "查看空间占用:" +etcdctl endpoint status --write-out table + +echo "查看告警状态:" +etcdctl alarm list + +echo "获取当前版本:" +rev=$(etcdctl --endpoints=http://127.0.0.1:2379 endpoint status --write-out="json" | egrep -o '"revision":[0-9]*' | egrep -o '[0-9].*') + +echo "压缩掉所有旧版本:" +etcdctl --endpoints=http://127.0.0.1:2379 compact $rev + +echo "整理多余的空间:" +etcdctl --endpoints=http://127.0.0.1:2379 defrag + +echo "取消告警信息:" +etcdctl --endpoints=http://127.0.0.1:2379 alarm disarm + +echo "再次查看空间占用:" +etcdctl endpoint status --write-out table + +echo "再次查看告警状态:" +etcdctl alarm list diff --git a/routers/img.py b/routers/img.py new file mode 100644 index 0000000..1af387a --- /dev/null +++ b/routers/img.py @@ -0,0 +1,214 @@ +import os +import time +import base64 +import psutil +import statistics +import _thread as thread + +from fastapi import APIRouter, HTTPException, Response +from urllib.parse import unquote +from configs.config import IMAGES_PATH +from models.mysql import get_cursor, conn +from utilities.download import download_image, generate_thumbnail + + +router = APIRouter() + + +# 预热图片(获取一次图片, 遍历图片表, 检查OSS中所有被预定的尺寸是否存在, 不存在则生成) +@router.get("/warm", summary="预热图片", description="预热图片") +def warm_image(op:int=0, end:int=10, version:str='0'): + cursor = 
# Thumbnails for non-standard image sources (ad / article / article_attribute / url / avatar)
@router.get("/{type}-{id}-{version}@{n}x{w}.{ext}", summary="获取非标准类缩略图", description="/img/article-233-version@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:str):
    """Serve a width-limited thumbnail for a typed image reference.

    The result is cached on disk under ``IMAGES_PATH``; a cache hit is served
    without touching the database or the network.

    Args:
        type: Source kind. ``ad``/``article``/``article_attribute`` read the
            ``image`` column of ``web_{type}``; ``avatar`` reads
            ``web_member.avatar``; ``url`` treats ``id`` as a base64-encoded,
            URL-quoted address.
        id: Row id (digits) for database-backed types, or the base64 token
            for ``url``.
        version: Cache-busting tag; only used in the cache file name.
        n: Pixel-density multiplier; maximum width is ``n * w``.
        w: Base width in pixels.
        ext: Output extension, also passed to PIL as the format name.

    Returns:
        The image bytes, or a 404 response when the reference cannot be
        resolved or downloaded.
    """
    img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}"
    if os.path.exists(img_path):
        with open(img_path, 'rb') as cached:
            return Response(content=cached.read(), media_type=f"image/{ext}")

    if type in ('ad', 'article', 'article_attribute', 'avatar'):
        # ``id`` arrives as free text from the URL path; never splice it into
        # SQL. Reject non-numeric ids and bind the value as a parameter.
        if not (id.isascii() and id.isdigit()):
            return Response('图片不存在', status_code=404)
        if type == 'avatar':
            query, column, missing = "SELECT avatar FROM `web_member` WHERE `id`=%s", 'avatar', '用户不存在'
        else:
            # ``type`` is restricted to the whitelist above, so using it in
            # the table name is safe.
            query, column, missing = f"SELECT * FROM `web_{type}` WHERE `id`=%s", 'image', '图片不存在'
        cursor = get_cursor()
        try:
            count = cursor.execute(query, (int(id),))
            row = cursor.fetchone()
        finally:
            cursor.close()
        if row is None:
            print(missing + ':', count)
            return Response(missing, status_code=404)
        url = row[column]
    elif type == 'url':
        # '+' in the base64 token is turned into a space by URL decoding;
        # restore it before decoding.
        token = unquote(id, 'utf-8').replace(' ', '+')
        try:
            url = unquote(base64.b64decode(token).decode('utf-8'))
        except ValueError:
            # Covers malformed base64 and non-UTF-8 payloads.
            return Response('图片不存在', status_code=404)
        # NOTE(review): this fetches an arbitrary caller-supplied URL from the
        # server (SSRF risk); consider restricting the allowed hosts.
        print(url)
    else:
        print('图片类型不存在:', type)
        return Response('图片类型不存在', status_code=404)

    if not url:
        return Response('图片不存在', status_code=404)
    image = download_image(url)
    if image is None:
        return Response('图片不存在', status_code=404)
    # Avatars are cropped to a square anchored at the top-left corner.
    if type == 'avatar':
        px = min(image.size)
        image = image.crop((0, 0, px, px))
    # Height bound equals the current height, so only the width constrains.
    image.thumbnail((n*w, image.size[1]))
    image.save(img_path, ext, save_all=True)
    with open(img_path, 'rb') as generated:
        return Response(content=generated.read(), media_type=f"image/{ext}")
# Fetch an image by (base64-encoded) URL
@router.get("/url-{url}@{n}x{w}.{ext}", summary="通过url获取图片", description="/img/article-233.webp")
def get_image_url(url:str, n:int, w:int, ext:str):
    """Serve a width-limited thumbnail of a remote image given by URL.

    Args:
        url: Base64-encoded, URL-quoted address of the source image.
        n: Pixel-density multiplier; maximum width is ``n * w``.
        w: Base width in pixels.
        ext: Output extension, also passed to PIL as the format name.

    Returns:
        The image bytes, or a 404 response when the token cannot be decoded
        or the image cannot be downloaded.
    """
    # The cache key must include the requested size. The original path
    # interpolated the builtin ``type`` (yielding "<class 'type'>-...") and
    # omitted n/w, so every size of one URL shared a single cache file.
    img_path = f"{IMAGES_PATH}/url-{url}@{n}x{w}.{ext}"
    if os.path.exists(img_path):
        with open(img_path, 'rb') as cached:
            return Response(content=cached.read(), media_type=f"image/{ext}")
    # '+' in the base64 token is turned into a space by URL decoding.
    token = unquote(url, 'utf-8').replace(' ', '+')
    try:
        source = unquote(base64.b64decode(token).decode('utf-8'))
    except ValueError:
        # Covers malformed base64 and non-UTF-8 payloads.
        return Response('图片不存在', status_code=404)
    # NOTE(review): this fetches an arbitrary caller-supplied URL from the
    # server (SSRF risk); consider restricting the allowed hosts.
    image = download_image(source)
    if image is None:
        return Response('图片不存在', status_code=404)
    # Height bound equals the current height, so only the width constrains.
    image.thumbnail((n*w, image.size[1]))
    image.save(img_path, ext, save_all=True)
    with open(img_path, 'rb') as generated:
        return Response(content=generated.read(), media_type=f"image/{ext}")
通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图") +def get_image_thumbnail(id:int, n:int, w:int, ext:str): + # 判断图片是否已经生成 + img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}" + if os.path.exists(img_path): + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + # 从数据库获取原图地址 + cursor = get_cursor() + cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") + img = cursor.fetchone() + cursor.close() + if img is None: + print('图片不存在:', id) + return Response('图片不存在', status_code=404) + image = download_image(img['content']) + if not image: + return Response('图片不存在', status_code=404) + image.thumbnail((n*w, image.size[1])) + image.save(img_path, ext, save_all=True) + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + + +# 获取标准原尺寸图 +@router.get("/{id}.{ext}", summary="获取标准原尺寸图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 无后缀获取原图") +def get_image(id: int = 824, ext: str = 'webp'): + # 判断图片是否已经生成 + img_path = f"{IMAGES_PATH}/{id}.{ext}" + if os.path.exists(img_path): + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + # 从数据库获取原图地址 + cursor = get_cursor() + cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") + img = cursor.fetchone() + cursor.close() + if img is None: + print('图片不存在:', id) + return Response('图片不存在', status_code=404) + image = download_image(img['content']) + if not image: + return Response('图片不存在', status_code=404) + image.save(img_path, ext, save_all=True) + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") diff --git a/routers/reverse.py b/routers/reverse.py new file mode 100644 index 0000000..51a43e6 --- /dev/null +++ b/routers/reverse.py @@ -0,0 +1,231 @@ +import os +import random +import string +import sqlite3 +import numpy as np +import time +import io +from PIL import Image +from towhee import pipe, ops + +from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header +from 
models.milvus import get_collection, collection_name +from models.mysql import get_cursor, conn +#from models.resnet import Resnet50 + +from configs.config import UPLOAD_PATH +from utilities.download import download_image + +router = APIRouter() +#MODEL = Resnet50() + +RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec')) + +# 获取状态统计 +@router.get('', summary='状态统计', description='通过表名获取状态统计') +def count_images(): + collection = get_collection(collection_name) + return {'status': True, 'count': collection.num_entities} + + +# 手动重建索引 +@router.get('/create_index', summary='重建索引', description='手动重建索引', include_in_schema=False) +async def create_index(): + collection = get_collection(collection_name) + default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}} + collection.create_index(field_name="embedding", index_params=default_index) + collection.load() + return {'status': True, 'count': collection.num_entities} + +''' +# 批量生成向量 +@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False) +async def create_vector(count: int = 10): + cursor = get_cursor() + cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}") + images = cursor.fetchall() + cursor.close() + for item in images: + print(item) + # 先查询 milvus 中是否存在 + collection = get_collection(collection_name) + data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit + if len(data) > 0: + cursor = get_cursor() + cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}") + conn.commit() + cursor.close() + continue + try: + img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image'])) + download_image(item['thumbnail_image']).save(img_path, 'png', 
save_all=True) + feat = MODEL.resnet50_extract_feat(img_path) + collection.insert([[item['id']], [feat], [item['article_id']]]) + cursor = get_cursor() + cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}") + conn.commit() + cursor.close() + except Exception as e: + print(e) + print('END') + return images +''' + +@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量') +async def rewrite_image(image_id: int): + print('START', image_id, '重建向量') + cursor = get_cursor() + cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}") + img = cursor.fetchone() + cursor.close() + if img is None: + print('mysql中原始图片不存在:', image_id) + return Response('图片不存在', status_code=404) + img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content'])) + print('img_path', img_path) + image = download_image(img['content']) + if image is None: + print('图片下载失败:', img['content']) + return Response('图片下载失败', status_code=404) + image.save(img_path, 'png', save_all=True) + + with Image.open(img_path) as imgx: + feat = RESNET50(imgx).get()[0] + + collection = get_collection(collection_name) + collection.delete(expr=f'id in [{image_id}]') + rest = collection.insert([[image_id], [feat], [img['article_id']]]) + os.remove(img_path) + print('END', image_id, '重建向量', rest.primary_keys) + return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()} + + +# 获取相似(废弃) +@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片') +async def similar_images(image_id: int, page: int = 1, pageSize: int = 20): + collection = get_collection(collection_name) + result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1) + # 如果没有结果, 则重新生成记录 + if len(result) == 0: + cursor = get_cursor() + cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}") + img = cursor.fetchone() + cursor.close() + if img is None: + print('mysql 中图片不存在:', image_id) + return Response('图片不存在', 
status_code=404) + img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content'])) + image = download_image(img['content']) + if image is None: + print('图片下载失败:', img['content']) + return Response('图片下载失败', status_code=404) + image.save(img_path, 'png', save_all=True) + with Image.open(img_path) as imgx: + feat = RESNET50(imgx).get()[0] + # 移除可能存在的旧记录, 换上新的 + collection.delete(expr=f'id in [{image_id}]') + collection.insert([[image_id], [feat], [img['article_id']]]) + os.remove(img_path) + print('生成') + else: + print('通过') + feat = result[0]['embedding'] + res = collection.search([feat],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0] + # 翻页(截取有效范围, page * pageize) + ope = page*pageSize-pageSize + end = page*pageSize + next = False + # 为数据附加信息 + ids = [i.id for i in res] # 获取所有ID + if len(res) <= end: + ids = ids[ope:] + next = False + else: + ids = ids[ope:end] + next = True + str_ids = str(ids).replace('[', '').replace(']', '') + if str_ids == '': + print('没有更多数据了') + return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []} + cursor = get_cursor() + cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})") + imgs = cursor.fetchall() + if len(imgs) == 0: + return imgs + # 获取用户ID和文章ID + uids = list(set([x['user_id'] for x in imgs])) + tids = list(set([x['article_id'] for x in imgs])) + # 获取用户信息 + cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})") + users = cursor.fetchall() + # 获取文章信息 + cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})") + articles = cursor.fetchall() + cursor.close() + # 合并信息 + user, article = {}, {} + for x in users: user[x['id']] = x + for x in 
@router.post(path='', summary='以图搜图', description='上传图片进行搜索')
async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: int = 20):
    """Search for images similar to an uploaded picture.

    The upload is embedded with RESNET50, up to 500 nearest vectors (L2
    metric) are fetched from the 'default' collection, and the requested page
    of hits is enriched with image, author and article rows from MySQL.

    Args:
        image: Uploaded query picture.
        page: 1-based page number.
        pageSize: Number of results per page.

    Returns:
        A dict with 'code', 'pageSize', 'page', 'next' (True when more hits
        exist beyond this page) and 'list' (image rows sorted by ascending
        distance). 'list' is empty when the page has no hits or none of the
        hits exist in `web_images`.
    """
    empty_page = {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}

    # PIL reads the upload straight from the underlying file object.
    query_image = Image.open(image.file)
    embedding = RESNET50(query_image).get()[0]
    collection = get_collection('default')
    res = collection.search(
        [embedding],
        anns_field="embedding",
        param={"metric_type": 'L2', "params": {"nprobe": 16}},
        output_fields=["id", "article_id"],
        limit=500,
    )[0]

    # Slice out the requested page of hit ids.
    start, end = (page - 1) * pageSize, page * pageSize
    ids = [hit.id for hit in res][start:end]
    if not ids:
        return empty_page
    has_next = len(res) > end

    # One pass over the hits instead of rescanning them for every row;
    # setdefault keeps the first (closest) distance if an id repeats.
    distance_by_id = {}
    for hit in res:
        distance_by_id.setdefault(hit.id, hit.distance)

    def placeholders(values):
        return ','.join(['%s'] * len(values))

    cursor = get_cursor()
    try:
        cursor.execute(
            "SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,"
            "article_category_top_id,praise_count,collect_count,create_time,update_time "
            f"FROM `web_images` WHERE id IN ({placeholders(ids)})",
            ids,
        )
        imgs = cursor.fetchall()
        if not imgs:
            # Keep the response shape stable even when no hit has a DB row.
            return empty_page

        uids = list({row['user_id'] for row in imgs})
        tids = list({row['article_id'] for row in imgs})
        cursor.execute(
            f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({placeholders(uids)})",
            uids,
        )
        users = cursor.fetchall()
        cursor.execute(
            f"SELECT id,title,tags FROM `web_article` WHERE id IN ({placeholders(tids)})",
            tids,
        )
        articles = cursor.fetchall()
    finally:
        # Release the cursor on every path, including early returns and errors.
        cursor.close()

    user_by_id = {row['id']: row for row in users}
    article_by_id = {row['id']: row for row in articles}
    for row in imgs:
        # A deleted author/article yields None instead of a KeyError (HTTP 500).
        row['article'] = article_by_id.get(row['article_id'])
        row['user'] = user_by_id.get(row['user_id'])
        row['distance'] = distance_by_id[row['id']]
        # The columns exist but may hold NULL, so a .get() default never applies.
        row['praise_count'] = row.get('praise_count') or 0
        row['collect_count'] = row.get('collect_count') or 0

    imgs = sorted(imgs, key=lambda row: row['distance'])
    return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': has_next, 'list': imgs}


@router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像")
async def delete_images(thread_id: str):
    """Delete every stored vector that belongs to the given article(s).

    Args:
        thread_id: An article id, or several ids separated by commas.

    Returns:
        A status dict on success, or a 400 response when `thread_id` is not a
        comma-separated list of integers.
    """
    # The path parameter is untrusted: only accept integers so nothing else
    # can be smuggled into the delete expression.
    try:
        article_ids = [int(part) for part in thread_id.split(',')]
    except ValueError:
        return Response('参数 thread_id 无效', status_code=400)

    collection = get_collection(collection_name)
    # NOTE(review): the Milvus 2.1 docs describe delete expressions only for
    # primary-key `in` filters — confirm that filtering on article_id works on
    # the deployed version.
    collection.delete(expr=f"article_id in {article_ids}")
    collection.load()
    default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
    collection.create_index(field_name="embedding", index_params=default_index)
    return {"status": True, 'msg': '删除完毕'}
# Watch a task for change events and notify the front end (no pydantic model)
@router.websocket("/{task_id}", name="监听任务变化")
async def websocket_endpoint(task_id: str, websocket: WebSocket):
    """Register a websocket as a listener of a task and echo incoming text.

    The socket is registered with the task manager after the handshake and is
    always deregistered when the receive loop ends, whether the loop finished
    normally or an error interrupted it (for example a failed send).

    Args:
        task_id: Identifier of the task to watch.
        websocket: The client connection.
    """
    await websocket.accept()
    await task_manager.add_websocket(task_id, websocket)
    try:
        async for data in websocket.iter_text():
            await websocket.send_text(f"Message text was: {data}")
    finally:
        # Never leave a dead socket in the task's listener list.
        task_manager.remove_websocket(task_id, websocket)
        print("websocket 连接已自动关闭")


# Task details
@router.get("/{task_id}", summary="获取任务详情")
def get_task(task_id: int):
    """Return the details of one task.

    Args:
        task_id: Identifier of the task.

    Returns:
        Whatever the task manager returns for this task.

    Raises:
        HTTPException: 404 when the task manager has no such task.
    """
    # Delegate to the manager; calling get_task() here would recurse forever.
    task = task_manager.get_task(task_id)
    # NOTE(review): assumes the manager returns a falsy value for unknown ids
    # — confirm against TaskManager.get_task.
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")
    return task
# Collection records of the current user
@router.get('', summary='自己的收藏记录', description='获取自己的收藏记录, 用于判断是否收藏(headers中必须附带token)')
def get_self_collect(token: Optional[str] = Header()):
    """Return the ids of the images collected by the authenticated user.

    Args:
        token: Auth token taken from the request headers.

    Returns:
        A dict with 'code', 'user_id', 'total' and 'data' (image ids as
        strings).

    Raises:
        HTTPException: 401 when the token matches no row in `web_auth`.
    """
    cursor = get_cursor()
    try:
        # The token is client-controlled: bind it as a parameter instead of
        # interpolating it into the statement, and never write it to the log.
        cursor.execute("SELECT user_id FROM web_auth WHERE token=%s LIMIT 1", (token,))
        auth = cursor.fetchone()
        if not auth:
            raise HTTPException(status_code=401, detail="用户未登录")
        user_id = auth['user_id']

        cursor.execute("SELECT content FROM web_collect WHERE user_id=%s AND type='1'", (user_id,))
        contents = [str(item['content']) for item in cursor.fetchall()]
        if not contents:
            # An empty IN () list is a SQL syntax error, so answer directly.
            return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}

        # Map the collected contents to image ids (values bound as parameters).
        cursor.execute(
            f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(contents))})",
            contents,
        )
        image_ids = [str(item['id']) for item in cursor.fetchall()]
    finally:
        cursor.close()
    return {'code': 0, 'user_id': user_id, 'total': len(image_ids), 'data': image_ids}


# Collection records of a given user
@router.get('/{user_id}', summary='指定用的户收藏记录', description='获取指定用户收藏记录(仅测试用)')
def get_user_collect(user_id: int):
    """Return the ids of the images collected by the given user.

    Args:
        user_id: Id of the user whose collection is listed.

    Returns:
        A dict with 'code', 'user_id', 'total' and 'data' (image ids as
        strings).
    """
    cursor = get_cursor()
    try:
        cursor.execute("SELECT content FROM web_collect WHERE user_id=%s AND type=1", (user_id,))
        contents = [str(item['content']) for item in cursor.fetchall()]
        if not contents:
            return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}

        # Map the collected contents to image ids (values bound as parameters).
        cursor.execute(
            f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(contents))})",
            contents,
        )
        image_ids = [str(item['id']) for item in cursor.fetchall()]
    finally:
        cursor.close()
    return {'code': 0, 'user_id': user_id, 'total': len(image_ids), 'data': image_ids}
mode 100644 index 0000000..ae8d6fe --- /dev/null +++ b/templates/websocket.html @@ -0,0 +1,44 @@ + + + + WebSocket Demo + + +
+

{{ task_list }}

+ +
+

WebSocket Demo

+
# Download an image (through OSS when it is hosted on our own image domain)
def download_image(url: str, timeout: float = 10) -> "Image.Image | None":
    """Fetch an image and return it as a PIL image.

    URLs on image.gameuiux.cn are read from the `bucket_image2` OSS bucket,
    using the part of the URL after the host as the object key. Any other URL
    is fetched over HTTP.

    Args:
        url: Address of the image.
        timeout: Seconds to wait for the HTTP server before giving up.

    Returns:
        The opened image, or None when it could not be downloaded or decoded.
    """
    for prefix in ('http://image.gameuiux.cn/', 'https://image.gameuiux.cn/'):
        if url.startswith(prefix):
            try:
                # Strip only the leading host part; the rest is the object key.
                obj = bucket_image2.get_object(url[len(prefix):]).read()
                return Image.open(io.BytesIO(obj))
            except Exception:
                # Best effort: callers treat None as "download failed".
                return None

    try:
        # Without a timeout a stalled server would block the request forever.
        response = requests.get(url, timeout=timeout)
        # Do not try to decode the body of an error response.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    except Exception:
        print('图片下载失败:', url)
        return None


# Generate a thumbnail and write it to OSS
def generate_thumbnail(image:Image, id:int, version:str, n:int, w:int, ext:str):
    """Create a thumbnail of `image` and upload it to the `bucket_webp` bucket.

    The object key is "{id}-{version}@{n}x{w}.{ext}". Nothing is done when an
    object with that key already exists.

    Args:
        image: Source image; it is copied, so the caller's object is untouched.
        id: Image id used in the object key.
        version: Version tag used in the object key.
        n: Scale multiplier; the thumbnail is at most n*w pixels wide.
        w: Base width in pixels.
        ext: File extension, also passed to PIL as the output format.
    """
    path = f"{id}-{version}@{n}x{w}.{ext}"
    if bucket_webp.object_exists(path):
        print('缩略图已经存在:', path)
        return

    local_path = f"{IMAGES_PATH}/{path}"
    # Work on a copy so the original image is not resized in place.
    thumb = image.copy()
    thumb.thumbnail((n*w, thumb.size[1]))
    try:
        thumb.save(local_path, ext, save_all=True)
        bucket_webp.put_object_from_file(path, local_path)
    finally:
        # Remove the temporary file even when saving or uploading fails.
        if os.path.exists(local_path):
            os.remove(local_path)
    print('缩略图写入 OSS:', path)