转移
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__pycache__
|
119
README.md
119
README.md
@@ -1,3 +1,118 @@
|
|||||||
# reverse_image_search_gpu
|
# reverse_image_search
|
||||||
|
|
||||||
图片反向搜索, 利用GPU加速
|
图片反向检索
|
||||||
|
|
||||||
|
```urls
|
||||||
|
https://www.gameui.net/api/default/create_vector?count=1024 # 预先生成指定数量的向量
|
||||||
|
https://www.gameui.net/api/default/create_index # 手动重建索引
|
||||||
|
```
|
||||||
|
|
||||||
|
## Intall
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 先安装 milvus v2.1.4 矢量数据库
|
||||||
|
wget https://github.com/milvus-io/milvus/releases/download/v2.1.4/milvus_2.1.4-1_amd64.deb
|
||||||
|
sudo apt-get update
|
||||||
|
sudo dpkg -i milvus_2.1.4-1_amd64.deb
|
||||||
|
sudo apt-get -f install
|
||||||
|
|
||||||
|
# 查看 Milvus 及其依賴的狀態(etcd 和 MinIO)
|
||||||
|
sudo systemctl status milvus
|
||||||
|
sudo systemctl status milvus-etcd
|
||||||
|
sudo systemctl status milvus-minio
|
||||||
|
|
||||||
|
# 安装 etcdctl (管理工具)
|
||||||
|
wget https://github.com/etcd-io/etcd/releases/download/v3.4.22/etcd-v3.4.22-linux-amd64.tar.gz
|
||||||
|
# 解压 etcdctl 到 /bin/
|
||||||
|
# 向 ~/.bashrc 添加 export ETCDCTL_API=3
|
||||||
|
|
||||||
|
# 安装 python 包管理工具 pip
|
||||||
|
sudo apt install python3-pip
|
||||||
|
|
||||||
|
# 使用 venv 创建虚拟环境(注意安装路径是当前用户的 .local/bin)
|
||||||
|
python3 -m venv venv
|
||||||
|
source venv/bin/activate
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
|
||||||
|
# 为 .bashrc 设置快捷命令(使用echo或cat追加写)
|
||||||
|
vim ~/.bashrc
|
||||||
|
alias venv='source venv/bin/activate'
|
||||||
|
source ~/.bashrc
|
||||||
|
|
||||||
|
# 不使用虚拟环境的防止依赖冲突方法
|
||||||
|
pip freeze > requirements.txt
|
||||||
|
pip uninstall -r requirements.txt
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
pip3 install -r requirements.txt
|
||||||
|
|
||||||
|
# 注意 python 要求 V3.9.7 版本以上
|
||||||
|
python3 src/main.py
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
挂载对象存储服务作为存储盘
|
||||||
|
|
||||||
|
## 批量处理工具
|
||||||
|
|
||||||
|
- [x] 使项目不依赖平台提供的服务, 易于迁移
|
||||||
|
- [x] 操作读写数据库, 不使用mysql
|
||||||
|
- [x] /img/no_content.8dd8acce.png 冲突
|
||||||
|
- [x] 在上海区新建OSS实例作为磁盘挂载到服务器(用于存储缓存图) ...登录
|
||||||
|
- [x] 在服务器提供接口(设置CDN层分发) ...登录
|
||||||
|
- [x] 获取前端项目(修改列表图片三类分辨率)
|
||||||
|
- [x] 通过 path 获取缩略图
|
||||||
|
- [x] 通过 articleDetails_id 获取缩略图
|
||||||
|
- [x] 图片向量无缓存时重新生成
|
||||||
|
- [ ] 提前生成原图webp以提高首次打开速度
|
||||||
|
|
||||||
|
## RTT
|
||||||
|
|
||||||
|
1. OSS 中 gameui-webp Bucket 作为图片缓存层
|
||||||
|
2. 通过 /img/xxx.webp 接口提供图片, 由 OSS 作镜像回源
|
||||||
|
|
||||||
|
NGINX config
|
||||||
|
|
||||||
|
```config
|
||||||
|
server {
|
||||||
|
location /docs {
|
||||||
|
proxy_pass http://gameui_ai_server;
|
||||||
|
proxy_redirect off;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
}
|
||||||
|
|
||||||
|
location /api/ {
|
||||||
|
proxy_pass http://gameui_ai_server;
|
||||||
|
proxy_redirect off;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
}
|
||||||
|
|
||||||
|
location ~* /img/([0-9]+)\.(webp|jpeg) {
|
||||||
|
proxy_pass http://gameui_ai_server;
|
||||||
|
proxy_redirect off;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
}
|
||||||
|
|
||||||
|
location ~* /img/([0-9]+).*\.(webp|jpeg) {
|
||||||
|
proxy_pass http://gameui_ai_server;
|
||||||
|
proxy_redirect off;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
}
|
||||||
|
|
||||||
|
location ~* /img/article-([0-9]+).*\.(webp|jpeg) {
|
||||||
|
proxy_pass http://gameui_ai_server;
|
||||||
|
proxy_redirect off;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
61
configs/config.py
Normal file
61
configs/config.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
|
||||||
|
|
||||||
|
# 以 .env 文件中的环境变量为准, 检查是否存在 .env 文件, 没有则创建
|
||||||
|
if not os.path.exists('.env'):
|
||||||
|
print('请输入环境变量参数, 将会写入 .env 文件中')
|
||||||
|
OSS_HOST = input('OSS_HOST: ')
|
||||||
|
MYSQL_HOST = input('MYSQL_HOST: ')
|
||||||
|
MILVUS_HOST = input('MILVUS_HOST: ')
|
||||||
|
with open('.env', 'w') as f:
|
||||||
|
f.write(f'OSS_HOST={OSS_HOST}\n')
|
||||||
|
f.write(f'MYSQL_HOST={MYSQL_HOST}\n')
|
||||||
|
f.write(f'MILVUS_HOST={MILVUS_HOST}\n')
|
||||||
|
|
||||||
|
|
||||||
|
# 读取 .env 文件中的环境变量
|
||||||
|
with open('.env', 'r') as f:
|
||||||
|
env = f.readlines()
|
||||||
|
env = list(filter(lambda x: not x.startswith('#') and not x.startswith('\n') and len(x.split('=')) == 2, env))
|
||||||
|
env = list(map(lambda x: x.replace('\n', '').split('='), env))
|
||||||
|
env = {k: v for k, v in env}
|
||||||
|
print(json.dumps(env, indent=4, ensure_ascii=False))
|
||||||
|
OSS_HOST = env.get('OSS_HOST')
|
||||||
|
MYSQL_HOST = env.get('MYSQL_HOST')
|
||||||
|
MILVUS_HOST = env.get('MILVUS_HOST')
|
||||||
|
|
||||||
|
|
||||||
|
# 创建上传图片的临时目录
|
||||||
|
UPLOAD_PATH = '/tmp/search-images'
|
||||||
|
if not os.path.exists(UPLOAD_PATH):
|
||||||
|
os.makedirs(UPLOAD_PATH)
|
||||||
|
|
||||||
|
|
||||||
|
# 创建压缩图片的临时目录
|
||||||
|
IMAGES_PATH = '/tmp/images'
|
||||||
|
if not os.path.exists(IMAGES_PATH):
|
||||||
|
os.makedirs(IMAGES_PATH)
|
||||||
|
|
||||||
|
|
||||||
|
# 创建sqlite3数据库目录
|
||||||
|
SQLITE3_PATH = '/tmp/sqlite3'
|
||||||
|
if not os.path.exists(SQLITE3_PATH):
|
||||||
|
os.makedirs(SQLITE3_PATH)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_images():
|
||||||
|
print('开始清理创建时间大于30分钟的图片缓存')
|
||||||
|
for file in os.listdir(IMAGES_PATH):
|
||||||
|
if os.path.getmtime(os.path.join(IMAGES_PATH, file)) < time.time() - 30 * 60:
|
||||||
|
print('清理图片缓存:', file)
|
||||||
|
os.remove(os.path.join(IMAGES_PATH, file))
|
||||||
|
|
||||||
|
|
||||||
|
# 构建一个定时任务每30分钟清理一次图片缓存
|
||||||
|
scheduler = BackgroundScheduler()
|
||||||
|
scheduler.add_job(clear_images, 'interval', minutes=30)
|
||||||
|
scheduler.start()
|
53
demo.py
Executable file
53
demo.py
Executable file
@@ -0,0 +1,53 @@
|
|||||||
|
import timm
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as functional
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
# 加载预训练模型
|
||||||
|
model = timm.create_model('resnet50', pretrained=True)
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
# 定义图片处理流程
|
||||||
|
preprocess = transforms.Compose([
|
||||||
|
transforms.Resize(256),
|
||||||
|
transforms.CenterCrop(224),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
|
||||||
|
# 读取图片
|
||||||
|
print("loading image")
|
||||||
|
img = Image.open("demo.png").convert("RGB")
|
||||||
|
tensor = preprocess(img).unsqueeze(0)
|
||||||
|
|
||||||
|
# 检查是否有可用的GPU
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
input_batch = tensor.to('cuda')
|
||||||
|
model.to('cuda')
|
||||||
|
|
||||||
|
print("start run model")
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_batch)
|
||||||
|
|
||||||
|
for x in output:
|
||||||
|
print(x.shape)
|
||||||
|
|
||||||
|
# 输出2048维向量
|
||||||
|
# print(output[0])
|
||||||
|
|
||||||
|
'''
|
||||||
|
from towhee import pipe, ops, DataCollection
|
||||||
|
|
||||||
|
p = (
|
||||||
|
pipe.input('path')
|
||||||
|
.map('path', 'img', ops.image_decode())
|
||||||
|
.map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
|
||||||
|
.output('img', 'vec')
|
||||||
|
)
|
||||||
|
|
||||||
|
ea = DataCollection(p('demo.png')).to_list()
|
||||||
|
print(ea)
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
41
download.py
Normal file
41
download.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import io
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from PIL import Image, ImageFile
|
||||||
|
from models.oss import bucket_image2, bucket_webp
|
||||||
|
from configs.config import IMAGES_PATH
|
||||||
|
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
|
# 下载图片(使用OSS下载)
|
||||||
|
def download_image(url:str) -> Image:
|
||||||
|
if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
|
||||||
|
try:
|
||||||
|
url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
|
||||||
|
obj = bucket_image2.get_object(url).read()
|
||||||
|
return Image.open(io.BytesIO(obj))
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
response = requests.get(url)
|
||||||
|
return Image.open(io.BytesIO(response.content))
|
||||||
|
except Exception:
|
||||||
|
print('图片下载失败:', url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 生成缩略图, 写入OSS
|
||||||
|
def generate_thumbnail(image:Image, id:int, version:str, n:int, w:int, ext:str):
|
||||||
|
path = f"{id}-{version}@{n}x{w}.{ext}"
|
||||||
|
if bucket_webp.object_exists(path):
|
||||||
|
print('缩略图已经存在:', path)
|
||||||
|
return
|
||||||
|
# 将 image 对象复制一份, 防止影响原图
|
||||||
|
image = image.copy()
|
||||||
|
image.thumbnail((n*w, image.size[1]))
|
||||||
|
image.save(f"{IMAGES_PATH}/{path}", ext, save_all=True)
|
||||||
|
bucket_webp.put_object_from_file(path, f"{IMAGES_PATH}/{path}")
|
||||||
|
os.remove(f"{IMAGES_PATH}/{path}")
|
||||||
|
print('缩略图写入 OSS:', path)
|
||||||
|
|
19
log.yaml
Normal file
19
log.yaml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
version: 1
|
||||||
|
formatters:
|
||||||
|
simple:
|
||||||
|
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
handlers:
|
||||||
|
console:
|
||||||
|
class: logging.StreamHandler
|
||||||
|
level: DEBUG
|
||||||
|
formatter: simple
|
||||||
|
stream: ext://sys.stdout
|
||||||
|
loggers:
|
||||||
|
simpleExample:
|
||||||
|
level: DEBUG
|
||||||
|
handlers: [console]
|
||||||
|
propagate: no
|
||||||
|
root:
|
||||||
|
level: DEBUG
|
||||||
|
handlers: [console]
|
||||||
|
|
28
main.py
Executable file
28
main.py
Executable file
@@ -0,0 +1,28 @@
|
|||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from functools import lru_cache
|
||||||
|
from routers import reverse, user, task, img, user_collect
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化 FastAPI
|
||||||
|
app = FastAPI(title="GameUI", description="GameUI", version="1.5.0", openapi_url="/docs/openapi.json", docs_url="/docs", redoc_url="/redoc")
|
||||||
|
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
|
||||||
|
|
||||||
|
|
||||||
|
# 导入路由
|
||||||
|
app.include_router(user.router, prefix='/api/user', tags=['用户'])
|
||||||
|
app.include_router(task.router, prefix='/api/task', tags=['任务'])
|
||||||
|
app.include_router(reverse.router, prefix='/api/default', tags=['搜图'])
|
||||||
|
app.include_router(user_collect.router, prefix='/api/user_collect', tags=['收藏'])
|
||||||
|
app.include_router(img.router, prefix='/imgs', tags=['图片'])
|
||||||
|
|
||||||
|
|
||||||
|
# 启动服务
|
||||||
|
if __name__ == '__main__':
|
||||||
|
port = 5002 if len(sys.argv) < 2 else int(sys.argv[1])
|
||||||
|
uvicorn.run(app='main:app', host='0.0.0.0', port=port, reload=True, workers=1)
|
26
models/milvus.py
Normal file
26
models/milvus.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import pymilvus
|
||||||
|
|
||||||
|
from configs.config import MILVUS_HOST
|
||||||
|
|
||||||
|
# 连接 Milvus (开启 Milvus 服务)
|
||||||
|
collection_name = 'default'
|
||||||
|
|
||||||
|
|
||||||
|
# 获取 Milvus 连接
|
||||||
|
def get_collection(collection_name):
|
||||||
|
pymilvus.connections.connect(host=MILVUS_HOST, port='19530')
|
||||||
|
if not pymilvus.utility.has_collection(collection_name):
|
||||||
|
field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True)
|
||||||
|
field2 = pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, dim=2048)
|
||||||
|
field3 = pymilvus.FieldSchema(name="article_id", dtype=pymilvus.DataType.INT64)
|
||||||
|
schema = pymilvus.CollectionSchema(fields=[field1, field2, field3])
|
||||||
|
return pymilvus.Collection(name=collection_name, schema=schema)
|
||||||
|
return pymilvus.Collection(name=collection_name)
|
||||||
|
|
||||||
|
|
||||||
|
# 检查索引是否存在, 不存在则创建, 并加载
|
||||||
|
#collection = get_collection(collection_name)
|
||||||
|
#if not collection.has_index():
|
||||||
|
# default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
|
||||||
|
# collection.create_index(field_name="embedding", index_params=default_index)
|
||||||
|
#collection.load()
|
16
models/mysql.py
Normal file
16
models/mysql.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import pymysql
|
||||||
|
from configs.config import MYSQL_HOST
|
||||||
|
|
||||||
|
|
||||||
|
# 连接 MySQL (开启 MySQL 服务)
|
||||||
|
conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
|
||||||
|
|
||||||
|
|
||||||
|
# 获取 MySQL 连接
|
||||||
|
def get_cursor():
|
||||||
|
try:
|
||||||
|
dx = conn.ping()
|
||||||
|
return conn.cursor()
|
||||||
|
except Exception:
|
||||||
|
conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
|
||||||
|
return conn.cursor()
|
9
models/oss.py
Normal file
9
models/oss.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import oss2
|
||||||
|
from configs.config import OSS_HOST
|
||||||
|
|
||||||
|
|
||||||
|
# 连接 OSS
|
||||||
|
oss2.defaults.connection_pool_size = 100
|
||||||
|
auth = oss2.Auth('LTAI4GH3qP6VA3QpmTYCgXEW', 'r2wz4bJty8iYfGIcFmEqlY1yon2Ruy')
|
||||||
|
bucket_image2 = oss2.Bucket(auth, f'http://{OSS_HOST}', 'gameui-image2')
|
||||||
|
bucket_webp = oss2.Bucket(auth, f'http://{OSS_HOST}', 'gameui-webp')
|
11
models/resnet.py
Normal file
11
models/resnet.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import towhee
|
||||||
|
|
||||||
|
class Resnet50:
|
||||||
|
def resnet50_extract_feat(self, img_path):
|
||||||
|
feat = towhee.glob(img_path).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list()
|
||||||
|
print(feat[0])
|
||||||
|
return feat[0]
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print('This script is running as the main program.')
|
||||||
|
#resnet =
|
184
models/task.py
Normal file
184
models/task.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import datetime
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from configs.config import SQLITE3_PATH
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
|
||||||
|
__taskdb.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
# 创建任务表
|
||||||
|
__taskdb.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS tasks (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
description TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
# 任务表单模型
|
||||||
|
class TaskForm(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
class Task(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
user_id: int
|
||||||
|
description: str
|
||||||
|
created_at: str
|
||||||
|
updated_at: str
|
||||||
|
status: str = "pending"
|
||||||
|
progress: int = 0
|
||||||
|
websockets: list = []
|
||||||
|
|
||||||
|
|
||||||
|
class TaskManager:
|
||||||
|
__tasks = {}
|
||||||
|
|
||||||
|
# 使用pydantic模型的写法
|
||||||
|
def add_task(self, name: str, user_id: int, description: str):
|
||||||
|
id = str(uuid.uuid4())
|
||||||
|
date = str(datetime.datetime.now())
|
||||||
|
task = Task(id=id, name=name, user_id=user_id, description=description, created_at=date, updated_at=date)
|
||||||
|
self.__tasks[id] = task
|
||||||
|
return task
|
||||||
|
|
||||||
|
def get_task(self, task_id: str):
|
||||||
|
task = self.__tasks.get(task_id)
|
||||||
|
if task:
|
||||||
|
return task.dict()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 直接返回对象的写法
|
||||||
|
def get_tasks(self, user_id: int=None):
|
||||||
|
if user_id:
|
||||||
|
return [task for task in self.__tasks.values() if task.user_id == user_id]
|
||||||
|
else:
|
||||||
|
return [task for task in self.__tasks.values()]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_task(self, task_id: str):
|
||||||
|
if self.__tasks.get(task_id):
|
||||||
|
self.__tasks.pop(task_id)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 正确的写法(不等待async函数直接返回结果)
|
||||||
|
async def add_websocket(self, task_id: str, websocket):
|
||||||
|
print(self.__tasks)
|
||||||
|
# 每次调用连接数都是1, 这是错误的
|
||||||
|
self.__tasks[task_id].websockets.append(websocket)
|
||||||
|
print(f"当前连接数: {len(self.__tasks[task_id].websockets)}")
|
||||||
|
|
||||||
|
|
||||||
|
#if len(task.websockets) > 0 and task.progress == 0 and task.status == "pending":
|
||||||
|
# print("任务未被监听, 开启监听")
|
||||||
|
# await self.start_listen_task(task)
|
||||||
|
|
||||||
|
# 移除 websocket
|
||||||
|
def remove_websocket(self, task_id: str, websocket):
|
||||||
|
task = self.__tasks.get(task_id)
|
||||||
|
if task:
|
||||||
|
print("连接已断开, 移除 websocket")
|
||||||
|
task.websockets.remove(websocket)
|
||||||
|
print(f"当前剩余连接数: {len(task.websockets)}")
|
||||||
|
|
||||||
|
# 正确的写法
|
||||||
|
async def start_listen_task(self, task: Task):
|
||||||
|
while True:
|
||||||
|
print(f"start listen task: {len(task.websockets)}")
|
||||||
|
time.sleep(2.5)
|
||||||
|
if len(task.websockets) == 0:
|
||||||
|
break
|
||||||
|
print(f"get task status:{task.id}")
|
||||||
|
break
|
||||||
|
#data = requests.get(f"http://localhost:5001/api/task/{task.id}").json()
|
||||||
|
#print(data)
|
||||||
|
## 通过 http 请求获取任务的状态
|
||||||
|
## 如果任务状态发生变化, 则将任务状态发送给所有监听该任务的 websocket
|
||||||
|
#if task.status != data.status:
|
||||||
|
# task.status = data.status
|
||||||
|
# print("send task status to websocket")
|
||||||
|
# for websocket in task.websockets:
|
||||||
|
# await websocket.send_json(task)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 任务实体模型
|
||||||
|
#class TaskModel(BaseModel):
|
||||||
|
# id: str
|
||||||
|
# name: str
|
||||||
|
# user_id: int
|
||||||
|
# description: str
|
||||||
|
# created_at: str
|
||||||
|
# updated_at: str
|
||||||
|
# websockets: list=[]
|
||||||
|
#
|
||||||
|
## 向数据库中添加任务
|
||||||
|
#def add_task(name, user_id, description):
|
||||||
|
# print(name, user_id, description)
|
||||||
|
# date = datetime.datetime.now()
|
||||||
|
# cursor = __taskdb.cursor()
|
||||||
|
# cursor.execute("INSERT INTO tasks(name,user_id,description,created_at,updated_at) VALUES (?, ?, ?, ?, ?)", (name,user_id,description,date,date))
|
||||||
|
# id = cursor.lastrowid
|
||||||
|
# cursor.close()
|
||||||
|
# __taskdb.commit()
|
||||||
|
# return {"id": id,"name": name,"user_id": user_id,"description": description,"created_at": date,"updated_at": date}
|
||||||
|
#
|
||||||
|
#
|
||||||
|
## 从数据库中获取任务列表
|
||||||
|
#def get_tasks(user_id: int=None):
|
||||||
|
# cursor = __taskdb.cursor()
|
||||||
|
# if user_id:
|
||||||
|
# cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,))
|
||||||
|
# else:
|
||||||
|
# cursor.execute("SELECT * FROM tasks")
|
||||||
|
# tasks = cursor.fetchall()
|
||||||
|
# cursor.close()
|
||||||
|
# return [dict(task) for task in tasks]
|
||||||
|
#
|
||||||
|
#
|
||||||
|
## 从数据库中获取指定任务
|
||||||
|
#def get_task(task_id: int):
|
||||||
|
# cursor = __taskdb.cursor()
|
||||||
|
# cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
|
||||||
|
# task = cursor.fetchone()
|
||||||
|
# cursor.close()
|
||||||
|
# return task
|
||||||
|
#
|
||||||
|
#
|
||||||
|
## 更新数据库中的任务
|
||||||
|
#def update_task(task: TaskModel):
|
||||||
|
# cursor = __taskdb.cursor()
|
||||||
|
# cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", (
|
||||||
|
# task.name,
|
||||||
|
# task.description,
|
||||||
|
# datetime.datetime.now(),
|
||||||
|
# task.id
|
||||||
|
# ))
|
||||||
|
# cursor.close()
|
||||||
|
# __taskdb.commit()
|
||||||
|
# return task
|
||||||
|
#
|
||||||
|
#
|
||||||
|
## 从数据库中删除任务
|
||||||
|
#def delete_task(task_id: int):
|
||||||
|
# cursor = __taskdb.cursor()
|
||||||
|
# cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
|
||||||
|
# cursor.close()
|
||||||
|
# __taskdb.commit()
|
||||||
|
# return task_id
|
||||||
|
|
||||||
|
|
63
requirements.txt
Normal file
63
requirements.txt
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
aliyun-python-sdk-core==2.13.36
|
||||||
|
aliyun-python-sdk-kms==2.16.0
|
||||||
|
anyio==3.6.2
|
||||||
|
bleach==6.0.0
|
||||||
|
certifi==2022.12.7
|
||||||
|
cffi==1.15.1
|
||||||
|
charset-normalizer==3.0.1
|
||||||
|
click==8.1.3
|
||||||
|
crcmod==1.7
|
||||||
|
cryptography==39.0.1
|
||||||
|
docutils==0.19
|
||||||
|
fastapi==0.92.0
|
||||||
|
grpcio==1.37.1
|
||||||
|
grpcio-tools==1.37.1
|
||||||
|
h11==0.14.0
|
||||||
|
idna==3.4
|
||||||
|
importlib-metadata==6.0.0
|
||||||
|
jaraco.classes==3.2.3
|
||||||
|
jeepney==0.8.0
|
||||||
|
jmespath==0.10.0
|
||||||
|
keyring==23.13.1
|
||||||
|
markdown-it-py==2.1.0
|
||||||
|
mdurl==0.1.2
|
||||||
|
mmh3==3.0.0
|
||||||
|
more-itertools==9.0.0
|
||||||
|
numpy==1.24.2
|
||||||
|
opencv-python==4.6.0.66
|
||||||
|
oss2==2.16.0
|
||||||
|
pandas==1.5.3
|
||||||
|
pgzip==0.3.4
|
||||||
|
Pillow==9.4.0
|
||||||
|
pkginfo==1.9.6
|
||||||
|
protobuf==3.20.3
|
||||||
|
pyarrow==11.0.0
|
||||||
|
pycparser==2.21
|
||||||
|
pycryptodome==3.17
|
||||||
|
pydantic==1.10.5
|
||||||
|
pygit2==1.10.1
|
||||||
|
Pygments==2.14.0
|
||||||
|
pymilvus==2.0.2
|
||||||
|
PyMySQL==1.0.2
|
||||||
|
python-dateutil==2.8.2
|
||||||
|
python-multipart==0.0.5
|
||||||
|
pytz==2022.7.1
|
||||||
|
readme-renderer==37.3
|
||||||
|
requests==2.28.2
|
||||||
|
requests-toolbelt==0.10.1
|
||||||
|
rfc3986==2.0.0
|
||||||
|
rich==13.3.1
|
||||||
|
SecretStorage==3.3.3
|
||||||
|
six==1.16.0
|
||||||
|
sniffio==1.3.0
|
||||||
|
starlette==0.25.0
|
||||||
|
tabulate==0.9.0
|
||||||
|
towhee==0.9.0
|
||||||
|
tqdm==4.64.1
|
||||||
|
twine==4.0.2
|
||||||
|
typing_extensions==4.5.0
|
||||||
|
ujson==5.1.0
|
||||||
|
urllib3==1.26.14
|
||||||
|
uvicorn==0.20.0
|
||||||
|
webencodings==0.5.1
|
||||||
|
zipp==3.13.0
|
39
resize.sh
Executable file
39
resize.sh
Executable file
@@ -0,0 +1,39 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
# 定时整理 etcd 空间防止 etcd 服务挂掉
|
||||||
|
# 将此脚本加入定时任务(每日1次, 一次一片)
|
||||||
|
|
||||||
|
# crontab -e
|
||||||
|
# 15 * * * * /bin/sh /home/milvus/resize.sh
|
||||||
|
|
||||||
|
# 查看 Milvus 及其依賴的狀態(etcd 和 MinIO)
|
||||||
|
# sudo systemctl status milvus
|
||||||
|
# sudo systemctl status milvus-etcd
|
||||||
|
# sudo systemctl status milvus-minio
|
||||||
|
|
||||||
|
echo "使用API3:"
|
||||||
|
export ETCDCTL_API=3
|
||||||
|
|
||||||
|
echo "查看空间占用:"
|
||||||
|
etcdctl endpoint status --write-out table
|
||||||
|
|
||||||
|
echo "查看告警状态:"
|
||||||
|
etcdctl alarm list
|
||||||
|
|
||||||
|
echo "获取当前版本:"
|
||||||
|
rev=$(etcdctl --endpoints=http://127.0.0.1:2379 endpoint status --write-out="json" | egrep -o '"revision":[0-9]*' | egrep -o '[0-9].*')
|
||||||
|
|
||||||
|
echo "压缩掉所有旧版本:"
|
||||||
|
etcdctl --endpoints=http://127.0.0.1:2379 compact $rev
|
||||||
|
|
||||||
|
echo "整理多余的空间:"
|
||||||
|
etcdctl --endpoints=http://127.0.0.1:2379 defrag
|
||||||
|
|
||||||
|
echo "取消告警信息:"
|
||||||
|
etcdctl --endpoints=http://127.0.0.1:2379 alarm disarm
|
||||||
|
|
||||||
|
echo "再次查看空间占用:"
|
||||||
|
etcdctl endpoint status --write-out table
|
||||||
|
|
||||||
|
echo "再次查看告警状态:"
|
||||||
|
etcdctl alarm list
|
214
routers/img.py
Normal file
214
routers/img.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import base64
|
||||||
|
import psutil
|
||||||
|
import statistics
|
||||||
|
import _thread as thread
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Response
|
||||||
|
from urllib.parse import unquote
|
||||||
|
from configs.config import IMAGES_PATH
|
||||||
|
from models.mysql import get_cursor, conn
|
||||||
|
from utilities.download import download_image, generate_thumbnail
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# 预热图片(获取一次图片, 遍历图片表, 检查OSS中所有被预定的尺寸是否存在, 不存在则生成)
|
||||||
|
@router.get("/warm", summary="预热图片", description="预热图片")
|
||||||
|
def warm_image(op:int=0, end:int=10, version:str='0'):
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT * FROM `web_images` limit {op}, {end}")
|
||||||
|
for img in cursor.fetchall():
|
||||||
|
# 如果CPU使用率大于50%, 则等待, 直到CPU使用率小于50%
|
||||||
|
while statistics.mean(psutil.cpu_percent(interval=1, percpu=True)) > 50:
|
||||||
|
print(statistics.mean(psutil.cpu_percent(interval=1, percpu=True)), '等待CPU释放...')
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# 如果内存剩余小于1G, 则等待, 直到内存剩余大于1G
|
||||||
|
while psutil.virtual_memory().available < 1024 * 1024 * 1024:
|
||||||
|
print(psutil.virtual_memory().available, '等待内存释放...')
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# CPU使用率已降低, 开始处理图片
|
||||||
|
image = download_image(img['content']) # 从OSS下载原图
|
||||||
|
if not image:
|
||||||
|
print('跳过不存在的图片:', img['content'])
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建新线程处理图片
|
||||||
|
try:
|
||||||
|
print('开始处理图片:', img['content'])
|
||||||
|
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 3, 328, 'webp'))
|
||||||
|
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 2, 328, 'webp'))
|
||||||
|
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 1, 328, 'webp'))
|
||||||
|
except:
|
||||||
|
print('无法启动线程')
|
||||||
|
cursor.close()
|
||||||
|
return Response('预热成功', status_code=200, media_type='text/plain', headers={'Content-Type': 'text/plain; charset=utf-8'})
|
||||||
|
|
||||||
|
|
||||||
|
# 获取非标准类缩略图
|
||||||
|
@router.get("/{type}-{id}-{version}@{n}x{w}.{ext}", summary="获取非标准类缩略图", description="/img/article-233-version@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
|
||||||
|
def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:str):
|
||||||
|
img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}"
|
||||||
|
if os.path.exists(img_path):
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
if type == 'ad' or type == 'article' or type == 'article_attribute':
|
||||||
|
cursor = get_cursor()
|
||||||
|
count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
|
||||||
|
img = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
if img is None:
|
||||||
|
print('图片不存在:', count)
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
url = img['image']
|
||||||
|
elif type == 'url':
|
||||||
|
id = unquote(id, 'utf-8')
|
||||||
|
id = id.replace(' ','+')
|
||||||
|
url = unquote(base64.b64decode(id))
|
||||||
|
print(url)
|
||||||
|
elif type == 'avatar':
|
||||||
|
cursor = get_cursor()
|
||||||
|
count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
|
||||||
|
user = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
if user is None:
|
||||||
|
print('用户不存在:', count)
|
||||||
|
return Response('用户不存在', status_code=404)
|
||||||
|
url = user['avatar']
|
||||||
|
else:
|
||||||
|
print('图片类型不存在:', type)
|
||||||
|
return Response('图片类型不存在', status_code=404)
|
||||||
|
image = download_image(url)
|
||||||
|
if not image:
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
# 如果是 avatar, 则裁剪为正方形
|
||||||
|
if type == 'avatar':
|
||||||
|
px = image.size[0] if image.size[0] < image.size[1] else image.size[1]
|
||||||
|
image = image.crop((0, 0, px, px))
|
||||||
|
image.thumbnail((n*w, image.size[1]))
|
||||||
|
image.save(img_path, ext, save_all=True)
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
|
||||||
|
|
||||||
|
# 获取非标准类原尺寸图
|
||||||
|
@router.get("/{type}-{id}-{version}.{ext}", summary="获取文章缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
|
||||||
|
def get_image_type(type:str, id:str, version:str, ext:str):
|
||||||
|
img_path = f"{IMAGES_PATH}/{type}-{id}-{version}.{ext}"
|
||||||
|
if os.path.exists(img_path):
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
if type == 'ad' or type == 'article' or type == 'article_attribute':
|
||||||
|
cursor = get_cursor()
|
||||||
|
count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
|
||||||
|
img = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
if img is None:
|
||||||
|
print('图片不存在:', count)
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
url = img['image']
|
||||||
|
elif type == 'url':
|
||||||
|
id = unquote(id, 'utf-8')
|
||||||
|
id = id.replace(' ','+')
|
||||||
|
url = unquote(base64.b64decode(id))
|
||||||
|
print("url:", url)
|
||||||
|
elif type == 'avatar':
|
||||||
|
cursor = get_cursor()
|
||||||
|
count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
|
||||||
|
user = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
if user is None:
|
||||||
|
print('用户不存在:', count)
|
||||||
|
return Response('用户不存在', status_code=404)
|
||||||
|
url = user['avatar']
|
||||||
|
else:
|
||||||
|
print('图片类型不存在:', type)
|
||||||
|
return Response('图片类型不存在', status_code=404)
|
||||||
|
image = download_image(url)
|
||||||
|
image.save(img_path, ext, save_all=True)
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
|
||||||
|
|
||||||
|
# 通过url获取图片
|
||||||
|
@router.get("/url-{url}@{n}x{w}.{ext}", summary="通过url获取图片", description="/img/article-233.webp")
|
||||||
|
def get_image_url(url:str, n:int, w:int, ext:str):
|
||||||
|
img_path = f"{IMAGES_PATH}/{type}-{url}.{ext}"
|
||||||
|
if os.path.exists(img_path):
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
url = unquote(url, 'utf-8').replace(' ','+')
|
||||||
|
url = unquote(base64.b64decode(url))
|
||||||
|
image = download_image(url)
|
||||||
|
if not image:
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
image.thumbnail((n*w, image.size[1]))
|
||||||
|
image.save(img_path, ext, save_all=True)
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
|
||||||
|
|
||||||
|
# 获取标准缩略图(带版本号)
|
||||||
|
@router.get("/{id}-{version}@{n}x{w}.{ext}", summary="获取缩略图(带版本号)", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
|
||||||
|
def get_image_thumbnail(id:int, version:str, n:int, w:int, ext:str):
|
||||||
|
# 判断图片是否已经生成
|
||||||
|
img_path = f"{IMAGES_PATH}/{id}-{version}@{n}x{w}.{ext}"
|
||||||
|
if os.path.exists(img_path):
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
# 从数据库获取原图地址
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
|
||||||
|
img = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
if img is None:
|
||||||
|
print('图片不存在:', id)
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
image = download_image(img['content'])
|
||||||
|
if not image:
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
image.thumbnail((n*w, image.size[1]))
|
||||||
|
image.save(img_path, ext, save_all=True)
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
|
||||||
|
|
||||||
|
# 获取标准缩略图
|
||||||
|
@router.get("/{id}@{n}x{w}.{ext}", summary="获取缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
|
||||||
|
def get_image_thumbnail(id:int, n:int, w:int, ext:str):
|
||||||
|
# 判断图片是否已经生成
|
||||||
|
img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}"
|
||||||
|
if os.path.exists(img_path):
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
# 从数据库获取原图地址
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
|
||||||
|
img = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
if img is None:
|
||||||
|
print('图片不存在:', id)
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
image = download_image(img['content'])
|
||||||
|
if not image:
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
image.thumbnail((n*w, image.size[1]))
|
||||||
|
image.save(img_path, ext, save_all=True)
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
|
||||||
|
|
||||||
|
# 获取标准原尺寸图
|
||||||
|
@router.get("/{id}.{ext}", summary="获取标准原尺寸图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 无后缀获取原图")
|
||||||
|
def get_image(id: int = 824, ext: str = 'webp'):
|
||||||
|
# 判断图片是否已经生成
|
||||||
|
img_path = f"{IMAGES_PATH}/{id}.{ext}"
|
||||||
|
if os.path.exists(img_path):
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
||||||
|
# 从数据库获取原图地址
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
|
||||||
|
img = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
if img is None:
|
||||||
|
print('图片不存在:', id)
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
image = download_image(img['content'])
|
||||||
|
if not image:
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
image.save(img_path, ext, save_all=True)
|
||||||
|
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
|
231
routers/reverse.py
Normal file
231
routers/reverse.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import sqlite3
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import io
|
||||||
|
from PIL import Image
|
||||||
|
from towhee import pipe, ops
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
|
||||||
|
from models.milvus import get_collection, collection_name
|
||||||
|
from models.mysql import get_cursor, conn
|
||||||
|
#from models.resnet import Resnet50
|
||||||
|
|
||||||
|
from configs.config import UPLOAD_PATH
|
||||||
|
from utilities.download import download_image
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
#MODEL = Resnet50()
|
||||||
|
|
||||||
|
RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec'))
|
||||||
|
|
||||||
|
# 获取状态统计
|
||||||
|
@router.get('', summary='状态统计', description='通过表名获取状态统计')
|
||||||
|
def count_images():
|
||||||
|
collection = get_collection(collection_name)
|
||||||
|
return {'status': True, 'count': collection.num_entities}
|
||||||
|
|
||||||
|
|
||||||
|
# 手动重建索引
|
||||||
|
@router.get('/create_index', summary='重建索引', description='手动重建索引', include_in_schema=False)
|
||||||
|
async def create_index():
|
||||||
|
collection = get_collection(collection_name)
|
||||||
|
default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
|
||||||
|
collection.create_index(field_name="embedding", index_params=default_index)
|
||||||
|
collection.load()
|
||||||
|
return {'status': True, 'count': collection.num_entities}
|
||||||
|
|
||||||
|
'''
|
||||||
|
# 批量生成向量
|
||||||
|
@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False)
|
||||||
|
async def create_vector(count: int = 10):
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}")
|
||||||
|
images = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
for item in images:
|
||||||
|
print(item)
|
||||||
|
# 先查询 milvus 中是否存在
|
||||||
|
collection = get_collection(collection_name)
|
||||||
|
data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit
|
||||||
|
if len(data) > 0:
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
|
||||||
|
conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image']))
|
||||||
|
download_image(item['thumbnail_image']).save(img_path, 'png', save_all=True)
|
||||||
|
feat = MODEL.resnet50_extract_feat(img_path)
|
||||||
|
collection.insert([[item['id']], [feat], [item['article_id']]])
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
|
||||||
|
conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
print('END')
|
||||||
|
return images
|
||||||
|
'''
|
||||||
|
|
||||||
|
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
|
||||||
|
async def rewrite_image(image_id: int):
|
||||||
|
print('START', image_id, '重建向量')
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
|
||||||
|
img = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
if img is None:
|
||||||
|
print('mysql中原始图片不存在:', image_id)
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
|
||||||
|
print('img_path', img_path)
|
||||||
|
image = download_image(img['content'])
|
||||||
|
if image is None:
|
||||||
|
print('图片下载失败:', img['content'])
|
||||||
|
return Response('图片下载失败', status_code=404)
|
||||||
|
image.save(img_path, 'png', save_all=True)
|
||||||
|
|
||||||
|
with Image.open(img_path) as imgx:
|
||||||
|
feat = RESNET50(imgx).get()[0]
|
||||||
|
|
||||||
|
collection = get_collection(collection_name)
|
||||||
|
collection.delete(expr=f'id in [{image_id}]')
|
||||||
|
rest = collection.insert([[image_id], [feat], [img['article_id']]])
|
||||||
|
os.remove(img_path)
|
||||||
|
print('END', image_id, '重建向量', rest.primary_keys)
|
||||||
|
return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()}
|
||||||
|
|
||||||
|
|
||||||
|
# 获取相似(废弃)
|
||||||
|
@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片')
|
||||||
|
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
|
||||||
|
collection = get_collection(collection_name)
|
||||||
|
result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1)
|
||||||
|
# 如果没有结果, 则重新生成记录
|
||||||
|
if len(result) == 0:
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
|
||||||
|
img = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
if img is None:
|
||||||
|
print('mysql 中图片不存在:', image_id)
|
||||||
|
return Response('图片不存在', status_code=404)
|
||||||
|
img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
|
||||||
|
image = download_image(img['content'])
|
||||||
|
if image is None:
|
||||||
|
print('图片下载失败:', img['content'])
|
||||||
|
return Response('图片下载失败', status_code=404)
|
||||||
|
image.save(img_path, 'png', save_all=True)
|
||||||
|
with Image.open(img_path) as imgx:
|
||||||
|
feat = RESNET50(imgx).get()[0]
|
||||||
|
# 移除可能存在的旧记录, 换上新的
|
||||||
|
collection.delete(expr=f'id in [{image_id}]')
|
||||||
|
collection.insert([[image_id], [feat], [img['article_id']]])
|
||||||
|
os.remove(img_path)
|
||||||
|
print('生成')
|
||||||
|
else:
|
||||||
|
print('通过')
|
||||||
|
feat = result[0]['embedding']
|
||||||
|
res = collection.search([feat],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0]
|
||||||
|
# 翻页(截取有效范围, page * pageize)
|
||||||
|
ope = page*pageSize-pageSize
|
||||||
|
end = page*pageSize
|
||||||
|
next = False
|
||||||
|
# 为数据附加信息
|
||||||
|
ids = [i.id for i in res] # 获取所有ID
|
||||||
|
if len(res) <= end:
|
||||||
|
ids = ids[ope:]
|
||||||
|
next = False
|
||||||
|
else:
|
||||||
|
ids = ids[ope:end]
|
||||||
|
next = True
|
||||||
|
str_ids = str(ids).replace('[', '').replace(']', '')
|
||||||
|
if str_ids == '':
|
||||||
|
print('没有更多数据了')
|
||||||
|
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})")
|
||||||
|
imgs = cursor.fetchall()
|
||||||
|
if len(imgs) == 0:
|
||||||
|
return imgs
|
||||||
|
# 获取用户ID和文章ID
|
||||||
|
uids = list(set([x['user_id'] for x in imgs]))
|
||||||
|
tids = list(set([x['article_id'] for x in imgs]))
|
||||||
|
# 获取用户信息
|
||||||
|
cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
|
||||||
|
users = cursor.fetchall()
|
||||||
|
# 获取文章信息
|
||||||
|
cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
|
||||||
|
articles = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
# 合并信息
|
||||||
|
user, article = {}, {}
|
||||||
|
for x in users: user[x['id']] = x
|
||||||
|
for x in articles: article[x['id']] = x
|
||||||
|
for x in imgs:
|
||||||
|
x['article'] = article[x['article_id']]
|
||||||
|
x['user'] = user[x['user_id']]
|
||||||
|
x['distance'] = [i.distance for i in res if i.id == x['id']][0]
|
||||||
|
if x['praise_count'] == None: x['praise_count'] = 0
|
||||||
|
if x['collect_count'] == None: x['collect_count'] = 0
|
||||||
|
# 将字段名转换为驼峰
|
||||||
|
x['createTime'] = x.pop('create_time')
|
||||||
|
x['updateTime'] = x.pop('update_time')
|
||||||
|
# 对 imgs 重新排序(按照 distance 字段)
|
||||||
|
imgs = sorted(imgs, key=lambda x: x['distance'])
|
||||||
|
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': next, 'list': imgs}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(path='', summary='以图搜图', description='上传图片进行搜索')
|
||||||
|
async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: int = 20):
|
||||||
|
content = await image.read()
|
||||||
|
img = Image.open(image.file)
|
||||||
|
embeddig = RESNET50(img).get()[0]
|
||||||
|
collection = get_collection('default')
|
||||||
|
res = collection.search([embeddig],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=500)[0]
|
||||||
|
ope, end = (page - 1) * pageSize, page * pageSize
|
||||||
|
ids, nextx = [x.id for x in res][ope:end], len(res) > end
|
||||||
|
|
||||||
|
if not ids:
|
||||||
|
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
|
||||||
|
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
|
||||||
|
imgs = cursor.fetchall()
|
||||||
|
|
||||||
|
if not imgs:
|
||||||
|
return imgs
|
||||||
|
|
||||||
|
uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs))
|
||||||
|
cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
|
||||||
|
users = cursor.fetchall()
|
||||||
|
cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
|
||||||
|
articles = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
user, article = {x['id']: x for x in users}, {x['id']: x for x in articles}
|
||||||
|
for x in imgs:
|
||||||
|
x.update({
|
||||||
|
'article': article[x['article_id']],
|
||||||
|
'user': user[x['user_id']],
|
||||||
|
#'distance': next(i.distance for i in res if i.id == x['id'], 0),
|
||||||
|
'distance': [i.distance for i in res if i.id == x['id']][0],
|
||||||
|
'praise_count': x.get('praise_count', 0),
|
||||||
|
'collect_count': x.get('collect_count', 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
imgs.sort(key=lambda x: x['distance'])
|
||||||
|
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs}
|
||||||
|
|
||||||
|
@router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像")
|
||||||
|
async def delete_images(thread_id: str):
|
||||||
|
collection = get_collection(collection_name)
|
||||||
|
collection.delete(expr="article_id in ["+thread_id+"]")
|
||||||
|
collection.load()
|
||||||
|
default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
|
||||||
|
collection.create_index(field_name="embedding", index_params=default_index)
|
||||||
|
return {"status": True, 'msg': '删除完毕'}
|
94
routers/task.py
Normal file
94
routers/task.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Request, WebSocket
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from models.task import TaskForm, TaskManager
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
task_manager = TaskManager()
|
||||||
|
|
||||||
|
|
||||||
|
# 创建新任务
|
||||||
|
@router.post("", summary="创建新任务")
|
||||||
|
def create_task(form: TaskForm):
|
||||||
|
task = task_manager.add_task(name=form.name, user_id=123456, description=form.description)
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
# 获取任务列表
|
||||||
|
@router.get("", summary="获取任务列表", description="可使用user_id参数筛选指定用户的任务列表")
|
||||||
|
def get_task_list(user_id: int=None):
|
||||||
|
return task_manager.get_tasks(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
# websocket demo
|
||||||
|
@router.get("/demo", response_class=HTMLResponse)
|
||||||
|
async def websocket_demo(request: Request):
|
||||||
|
task_list = task_manager.get_tasks()
|
||||||
|
templates = Jinja2Templates(directory="templates")
|
||||||
|
return templates.TemplateResponse("websocket.html", {"request": request, "task_list": task_list})
|
||||||
|
|
||||||
|
|
||||||
|
# 监听指定任务的变化事件, 通知前端(不使用pydantic模型, 正确的写法)
|
||||||
|
@router.websocket("/{task_id}", name="监听任务变化")
|
||||||
|
async def websocket_endpoint(task_id: str, websocket: WebSocket):
|
||||||
|
await websocket.accept()
|
||||||
|
await task_manager.add_websocket(task_id, websocket)
|
||||||
|
async for data in websocket.iter_text():
|
||||||
|
await websocket.send_text(f"Message text was: {data}")
|
||||||
|
task_manager.remove_websocket(task_id, websocket)
|
||||||
|
print("websocket 连接已自动关闭")
|
||||||
|
|
||||||
|
#await websocket.send_json({"message": "Hello WebSocket!"})
|
||||||
|
#task = task_manager.get_task(task_id)
|
||||||
|
#print(task)
|
||||||
|
#if not task:
|
||||||
|
# print("task 不存在, 结束连接")
|
||||||
|
# return await websocket.close() # 任务不存在, 结束连接
|
||||||
|
#await websocket.send_json(task) # 将任务的状态发送给客户端
|
||||||
|
#await task_manager.add_websocket(task_id, websocket)
|
||||||
|
# 正确的写法, 使用 async for, 并且处理意外断开的情况
|
||||||
|
#try:
|
||||||
|
# async for data in websocket.iter_text():
|
||||||
|
# if data == "close":
|
||||||
|
# print("客户端主动关闭连接")
|
||||||
|
# task_manager.remove_websocket(task_id, websocket)
|
||||||
|
# await websocket.close()
|
||||||
|
# break
|
||||||
|
# else:
|
||||||
|
# print(f"接收到客户端消息: {data}")
|
||||||
|
# await websocket.send_text(f"Message text was: {data}")
|
||||||
|
#except Exception as e:
|
||||||
|
# print(f"客户端意外断开连接: {e}")
|
||||||
|
# task_manager.remove_websocket(task_id, websocket)
|
||||||
|
# #await websocket.close()
|
||||||
|
|
||||||
|
# 监听客户端的状态, 如果客户端主动关闭连接或意外断开连接, 都从任务的websocket列表中移除
|
||||||
|
#while True:
|
||||||
|
# try:
|
||||||
|
# data = await websocket.receive_text()
|
||||||
|
# if data == "close":
|
||||||
|
# print("客户端主动关闭连接")
|
||||||
|
# task_manager.remove_websocket(task_id, websocket)
|
||||||
|
# await websocket.close()
|
||||||
|
# break
|
||||||
|
# else:
|
||||||
|
# print(f"接收到客户端消息: {data}")
|
||||||
|
# await websocket.send_text(f"Message text was: {data}")
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"客户端意外断开连接: {e}")
|
||||||
|
# task_manager.remove_websocket(task_id, websocket)
|
||||||
|
# #await websocket.close()
|
||||||
|
# break
|
||||||
|
|
||||||
|
|
||||||
|
# 获取任务详情
|
||||||
|
@router.get("/{task_id}", summary="获取任务详情")
|
||||||
|
def get_task(task_id: int):
|
||||||
|
return get_task(task_id)
|
||||||
|
|
||||||
|
|
||||||
|
# 删除任务
|
||||||
|
@router.delete("/{task_id}", summary="删除指定任务")
|
||||||
|
def delete_task(task_id: int):
|
||||||
|
return delete_task(task_id)
|
19
routers/user.py
Normal file
19
routers/user.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from models.mysql import conn, get_cursor
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get('/{user_id}/collect', summary="用户收藏记录", description="获取指定用户收藏记录")
|
||||||
|
def get_user_collect(user_id:int):
|
||||||
|
# TODO: 需要验证权限
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'")
|
||||||
|
data = cursor.fetchall()
|
||||||
|
if not data:
|
||||||
|
return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
|
||||||
|
data = [str(item['content']) for item in data]
|
||||||
|
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
|
||||||
|
data = cursor.fetchall()
|
||||||
|
data = [str(item['id']) for item in data]
|
||||||
|
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }
|
||||||
|
|
45
routers/user_collect.py
Normal file
45
routers/user_collect.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, HTTPException, Header
|
||||||
|
from models.mysql import conn, get_cursor
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# 获取当前用户收藏记录
|
||||||
|
@router.get('', summary='自己的收藏记录', description='获取自己的收藏记录, 用于判断是否收藏(headers中必须附带token)')
|
||||||
|
def get_self_collect(token: Optional[str] = Header()):
|
||||||
|
print('token: ', token)
|
||||||
|
cursor = get_cursor()
|
||||||
|
# 查询用户ID
|
||||||
|
cursor.execute(f"SELECT user_id FROM web_auth WHERE token={token} limit 1")
|
||||||
|
data = cursor.fetchone()
|
||||||
|
print('auth: ', data)
|
||||||
|
if not data:
|
||||||
|
raise HTTPException(status_code=401, detail="用户未登录")
|
||||||
|
user_id = data['user_id']
|
||||||
|
# 查询收藏记录
|
||||||
|
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'")
|
||||||
|
data = cursor.fetchall() # 获取所有记录列表
|
||||||
|
data = [str(item['content']) for item in data] # 转换为数组
|
||||||
|
# 查询图片ID(对特殊字符安全转义)
|
||||||
|
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
|
||||||
|
data = cursor.fetchall()
|
||||||
|
data = [str(item['id']) for item in data]
|
||||||
|
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }
|
||||||
|
|
||||||
|
|
||||||
|
# 获取指定用户收藏记录
|
||||||
|
@router.get('/{user_id}', summary='指定用的户收藏记录', description='获取指定用户收藏记录(仅测试用)')
|
||||||
|
def get_user_collect(user_id: int):
|
||||||
|
cursor = get_cursor()
|
||||||
|
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type=1")
|
||||||
|
data = cursor.fetchall() # 获取所有记录列表
|
||||||
|
data = [str(item['content']) for item in data] # 转换为数组
|
||||||
|
if not data:
|
||||||
|
return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
|
||||||
|
# 查询图片ID(对特殊字符安全转义)
|
||||||
|
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
|
||||||
|
data = cursor.fetchall()
|
||||||
|
data = [str(item['id']) for item in data]
|
||||||
|
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data}
|
25
start.sh
Executable file
25
start.sh
Executable file
@@ -0,0 +1,25 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
# pip freeze > requirements.txt
|
||||||
|
#pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
#python3 -m pip install -U pyOpenSSL cryptography
|
||||||
|
|
||||||
|
# 使用 uvicorn 启动两个工作进程, 并支持自动重载
|
||||||
|
#uvicorn main:app --host 0.0.0.0 --port 5001 --workers 2 --reload >/dev/null 2>&1 &
|
||||||
|
#uvicorn main:app --host 0.0.0.0 --port 5002 >/dev/null 2>&1 &
|
||||||
|
|
||||||
|
python main.py >/dev/null 2>&1 &
|
||||||
|
|
||||||
|
# 为容器设定自动重启
|
||||||
|
#docker update --restart=always 51bb94aa2726
|
||||||
|
|
||||||
|
#cd ~/reverse_image_search
|
||||||
|
#pm2 start python3 --name reverse-1 -- main.py 5001
|
||||||
|
#pm2 start python3 --name reverse-2 -- main.py 5002
|
||||||
|
#pm2 start python3 --name reverse-3 -- main.py 5003
|
||||||
|
|
||||||
|
# 某些系统下,需要手动安装指定版本 opencv-python (4.7引用zlib错误)
|
||||||
|
# pip install opencv-python==4.6.0.66 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
|
||||||
|
# 如果网络不通, 从自定义服务器下载模型参数
|
||||||
|
#scp root@172.21.216.33:~/root/.cache/torch/hub/checkpoints/resnet50_a1_0-14fe96d1.pth resnet50_a1_0-14fe96d1.pth
|
44
templates/websocket.html
Normal file
44
templates/websocket.html
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>WebSocket Demo</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div>
|
||||||
|
<p>{{ task_list }}</p>
|
||||||
|
<button id="create_task">Create Task</button>
|
||||||
|
</div>
|
||||||
|
<h1>WebSocket Demo</h1>
|
||||||
|
<div id="message"></div>
|
||||||
|
<button id="create">Create WebSocket</button>
|
||||||
|
<button id="close">Close WebSocket</button>
|
||||||
|
<script>
|
||||||
|
function createWebSocket(id) {
|
||||||
|
var ws = new WebSocket("ws://localhost:5001/api/task/"+id);
|
||||||
|
ws.onmessage = function(event) {
|
||||||
|
var message = event.data;
|
||||||
|
document.getElementById("message").innerHTML += message + "<br>";
|
||||||
|
};
|
||||||
|
ws.onclose = function(event) {
|
||||||
|
document.getElementById("message").innerHTML += "WebSocket closed";
|
||||||
|
};
|
||||||
|
return ws;
|
||||||
|
}
|
||||||
|
document.getElementById("create_task").onclick = function() {
|
||||||
|
fetch("/api/task", {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
"name": "Task 1",
|
||||||
|
"description": "Task 1 description"
|
||||||
|
})
|
||||||
|
}).then(res => res.json()).then(data => {
|
||||||
|
console.log(data)
|
||||||
|
var ws = createWebSocket(data.id);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
4
update.sh
Executable file
4
update.sh
Executable file
@@ -0,0 +1,4 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
|
||||||
|
# 服务器跟随更新
|
||||||
|
ssh ai "cd ~/reverse_image_search_gpu; git pull;"
|
41
utilities/download.py
Normal file
41
utilities/download.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import io
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from PIL import Image, ImageFile
|
||||||
|
from models.oss import bucket_image2, bucket_webp
|
||||||
|
from configs.config import IMAGES_PATH
|
||||||
|
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
|
# 下载图片(使用OSS下载)
|
||||||
|
def download_image(url:str) -> Image:
|
||||||
|
if url.startswith('http://image.gameuiux.cn/') or url.startswith('https://image.gameuiux.cn/'):
|
||||||
|
try:
|
||||||
|
url = url.replace('http://image.gameuiux.cn/', '').replace('https://image.gameuiux.cn/', '')
|
||||||
|
obj = bucket_image2.get_object(url).read()
|
||||||
|
return Image.open(io.BytesIO(obj))
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
response = requests.get(url)
|
||||||
|
return Image.open(io.BytesIO(response.content))
|
||||||
|
except Exception:
|
||||||
|
print('图片下载失败:', url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 生成缩略图, 写入OSS
|
||||||
|
def generate_thumbnail(image:Image, id:int, version:str, n:int, w:int, ext:str):
|
||||||
|
path = f"{id}-{version}@{n}x{w}.{ext}"
|
||||||
|
if bucket_webp.object_exists(path):
|
||||||
|
print('缩略图已经存在:', path)
|
||||||
|
return
|
||||||
|
# 将 image 对象复制一份, 防止影响原图
|
||||||
|
image = image.copy()
|
||||||
|
image.thumbnail((n*w, image.size[1]))
|
||||||
|
image.save(f"{IMAGES_PATH}/{path}", ext, save_all=True)
|
||||||
|
bucket_webp.put_object_from_file(path, f"{IMAGES_PATH}/{path}")
|
||||||
|
os.remove(f"{IMAGES_PATH}/{path}")
|
||||||
|
print('缩略图写入 OSS:', path)
|
||||||
|
|
Reference in New Issue
Block a user