替换连接池

This commit is contained in:
2024-11-10 22:02:45 +08:00
parent d2777882c5
commit bdd08285e7
5 changed files with 351 additions and 345 deletions

View File

@@ -1,27 +1,47 @@
import pymysql import pymysql
from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS
from dbutils.pooled_db import PooledDB
# 创建 MySQL 连接
def create_connection():
return pymysql.connect(
host=MYSQL_HOST,
user=MYSQL_USER,
port=MYSQL_PORT, # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST
password=MYSQL_PASS,
database=MYSQL_NAME,
local_infile=True,
cursorclass=pymysql.cursors.DictCursor
)
# 连接 MySQL (开启 MySQL 服务) # 创建数据库连接池
conn = create_connection() pool = PooledDB(
creator=pymysql, # 使用 pymysql 作为数据库驱动
maxconnections=20, # 最大连接数
mincached=2, # 初始化时,连接池中至少创建的空闲连接
maxcached=5, # 连接池中最多空闲连接数
blocking=True, # 超过最大连接数时,是否阻塞
maxusage=None, # 单个连接的最大复用次数
ping=0, # 设置连接是否检查
host=MYSQL_HOST,
port=MYSQL_PORT,
user=MYSQL_USER,
password=MYSQL_PASS,
database=MYSQL_NAME,
charset='utf8mb4'
)
# 获取 MySQL 连接 ## 创建 MySQL 连接
def get_cursor(): #def create_connection():
global conn # return pymysql.connect(
try: # host=MYSQL_HOST,
conn.ping() # user=MYSQL_USER,
return conn.cursor() # port=MYSQL_PORT, # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST
except Exception: # password=MYSQL_PASS,
conn = create_connection() # database=MYSQL_NAME,
return conn.cursor() # local_infile=True,
# cursorclass=pymysql.cursors.DictCursor
# )
#
## 连接 MySQL (开启 MySQL 服务)
#conn = create_connection()
#
## 获取 MySQL 连接
#def get_cursor():
# global conn
# try:
# conn.ping()
# return conn.cursor()
# except Exception:
# conn = create_connection()
# return conn.cursor()
#

View File

@@ -8,7 +8,7 @@ import _thread as thread
from fastapi import APIRouter, HTTPException, Response from fastapi import APIRouter, HTTPException, Response
from urllib.parse import unquote from urllib.parse import unquote
from configs.config import IMAGES_PATH from configs.config import IMAGES_PATH
from models.mysql import get_cursor, conn from models.mysql import pool
from utilities.download import download_image, generate_thumbnail from utilities.download import download_image, generate_thumbnail
@@ -18,116 +18,109 @@ router = APIRouter()
# 预热图片(获取一次图片, 遍历图片表, 检查OSS中所有被预定的尺寸是否存在, 不存在则生成) # 预热图片(获取一次图片, 遍历图片表, 检查OSS中所有被预定的尺寸是否存在, 不存在则生成)
@router.get("/warm", summary="预热图片", description="预热图片") @router.get("/warm", summary="预热图片", description="预热图片")
def warm_image(op:int=0, end:int=10, version:str='0'): def warm_image(op:int=0, end:int=10, version:str='0'):
cursor = get_cursor() with pool.connection() as conn:
cursor.execute(f"SELECT * FROM `web_images` limit {op}, {end}") with conn.cursor() as cursor:
for img in cursor.fetchall(): cursor.execute(f"SELECT * FROM `web_images` limit {op}, {end}")
# 如果CPU使用率大于50%, 则等待, 直到CPU使用率小于50% for img in cursor.fetchall():
while statistics.mean(psutil.cpu_percent(interval=1, percpu=True)) > 50: # 如果CPU使用率大于50%, 则等待, 直到CPU使用率小于50%
print(statistics.mean(psutil.cpu_percent(interval=1, percpu=True)), '等待CPU释放...') while statistics.mean(psutil.cpu_percent(interval=1, percpu=True)) > 50:
time.sleep(2) print(statistics.mean(psutil.cpu_percent(interval=1, percpu=True)), '等待CPU释放...')
time.sleep(2)
# 如果内存剩余小于1G, 则等待, 直到内存剩余大于1G # 如果内存剩余小于1G, 则等待, 直到内存剩余大于1G
while psutil.virtual_memory().available < 1024 * 1024 * 1024: while psutil.virtual_memory().available < 1024 * 1024 * 1024:
print(psutil.virtual_memory().available, '等待内存释放...') print(psutil.virtual_memory().available, '等待内存释放...')
time.sleep(2) time.sleep(2)
# CPU使用率已降低, 开始处理图片
# CPU使用率已降低, 开始处理图片 image = download_image(img['content']) # 从OSS下载原图
image = download_image(img['content']) # 从OSS下载原图 if not image:
if not image: print('跳过不存在的图片:', img['content'])
print('跳过不存在的图片:', img['content']) continue
continue # 创建新线程处理图片
try:
# 创建新线程处理图片 print('开始处理图片:', img['content'])
try: thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 3, 328, 'webp'))
print('开始处理图片:', img['content']) thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 2, 328, 'webp'))
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 3, 328, 'webp')) thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 1, 328, 'webp'))
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 2, 328, 'webp')) except:
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 1, 328, 'webp')) print('无法启动线程')
except: return Response('预热成功', status_code=200, media_type='text/plain', headers={'Content-Type': 'text/plain; charset=utf-8'})
print('无法启动线程')
cursor.close()
return Response('预热成功', status_code=200, media_type='text/plain', headers={'Content-Type': 'text/plain; charset=utf-8'})
# 获取非标准类缩略图 # 获取非标准类缩略图
@router.get("/{type}-{id}-{version}@{n}x{w}.{ext}", summary="获取非标准类缩略图", description="/img/article-233-version@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图") @router.get("/{type}-{id}-{version}@{n}x{w}.{ext}", summary="获取非标准类缩略图", description="/img/article-233-version@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:str): def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:str):
img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}" with pool.connection() as conn:
if os.path.exists(img_path): with conn.cursor() as cursor:
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}"
if type == 'ad' or type == 'article' or type == 'article_attribute': if os.path.exists(img_path):
cursor = get_cursor() return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}") if type == 'ad' or type == 'article' or type == 'article_attribute':
img = cursor.fetchone() count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
cursor.close() img = cursor.fetchone()
if img is None: if img is None:
print('图片不存在:', count) print('图片不存在:', count)
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
url = img['image'] url = img['image']
elif type == 'url': elif type == 'url':
id = unquote(id, 'utf-8') id = unquote(id, 'utf-8')
id = id.replace(' ','+') id = id.replace(' ','+')
url = unquote(base64.b64decode(id)) url = unquote(base64.b64decode(id))
print(url) print(url)
elif type == 'avatar': elif type == 'avatar':
cursor = get_cursor() count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}") user = cursor.fetchone()
user = cursor.fetchone() if user is None:
cursor.close() print('用户不存在:', count)
if user is None: return Response('用户不存在', status_code=404)
print('用户不存在:', count) url = user['avatar']
return Response('用户不存在', status_code=404) else:
url = user['avatar'] print('图片类型不存在:', type)
else: return Response('图片类型不存在', status_code=404)
print('图片类型不存在:', type) image = download_image(url)
return Response('图片类型不存在', status_code=404) if not image:
image = download_image(url) return Response('图片不存在', status_code=404)
if not image: # 如果是 avatar, 则裁剪为正方形
return Response('图片不存在', status_code=404) if type == 'avatar':
# 如果是 avatar, 则裁剪为正方形 px = image.size[0] if image.size[0] < image.size[1] else image.size[1]
if type == 'avatar': image = image.crop((0, 0, px, px))
px = image.size[0] if image.size[0] < image.size[1] else image.size[1] image.thumbnail((n*w, image.size[1]))
image = image.crop((0, 0, px, px)) image.save(img_path, ext, save_all=True)
image.thumbnail((n*w, image.size[1])) return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取非标准类原尺寸图 # 获取非标准类原尺寸图
@router.get("/{type}-{id}-{version}.{ext}", summary="获取文章缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图") @router.get("/{type}-{id}-{version}.{ext}", summary="获取文章缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_type(type:str, id:str, version:str, ext:str): def get_image_type(type:str, id:str, version:str, ext:str):
img_path = f"{IMAGES_PATH}/{type}-{id}-{version}.{ext}" with pool.connection() as conn:
if os.path.exists(img_path): with conn.cursor() as cursor:
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") img_path = f"{IMAGES_PATH}/{type}-{id}-{version}.{ext}"
if type == 'ad' or type == 'article' or type == 'article_attribute': if os.path.exists(img_path):
cursor = get_cursor() return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}") if type == 'ad' or type == 'article' or type == 'article_attribute':
img = cursor.fetchone() count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
cursor.close() img = cursor.fetchone()
if img is None: if img is None:
print('图片不存在:', count) print('图片不存在:', count)
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
url = img['image'] url = img['image']
elif type == 'url': elif type == 'url':
id = unquote(id, 'utf-8') id = unquote(id, 'utf-8')
id = id.replace(' ','+') id = id.replace(' ','+')
url = unquote(base64.b64decode(id)) url = unquote(base64.b64decode(id))
print("url:", url) print("url:", url)
elif type == 'avatar': elif type == 'avatar':
cursor = get_cursor() count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}") user = cursor.fetchone()
user = cursor.fetchone() if user is None:
cursor.close() print('用户不存在:', count)
if user is None: return Response('用户不存在', status_code=404)
print('用户不存在:', count) url = user['avatar']
return Response('用户不存在', status_code=404) else:
url = user['avatar'] print('图片类型不存在:', type)
else: return Response('图片类型不存在', status_code=404)
print('图片类型不存在:', type) image = download_image(url)
return Response('图片类型不存在', status_code=404) image.save(img_path, ext, save_all=True)
image = download_image(url) return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 通过url获取图片 # 通过url获取图片
@@ -149,66 +142,66 @@ def get_image_url(url:str, n:int, w:int, ext:str):
# 获取标准缩略图(带版本号) # 获取标准缩略图(带版本号)
@router.get("/{id}-{version}@{n}x{w}.{ext}", summary="获取缩略图(带版本号)", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图") @router.get("/{id}-{version}@{n}x{w}.{ext}", summary="获取缩略图(带版本号)", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_thumbnail(id:int, version:str, n:int, w:int, ext:str): def get_image_thumbnail(id:int, version:str, n:int, w:int, ext:str):
# 判断图片是否已经生成 with pool.connection() as conn:
img_path = f"{IMAGES_PATH}/{id}-{version}@{n}x{w}.{ext}" with conn.cursor() as cursor:
if os.path.exists(img_path): # 判断图片是否已经生成
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") img_path = f"{IMAGES_PATH}/{id}-{version}@{n}x{w}.{ext}"
# 从数据库获取原图地址 if os.path.exists(img_path):
cursor = get_cursor() return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") # 从数据库获取原图地址
img = cursor.fetchone() cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
cursor.close() img = cursor.fetchone()
if img is None: if img is None:
print('图片不存在:', id) print('图片不存在:', id)
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
image = download_image(img['content']) image = download_image(img['content'])
if not image: if not image:
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
image.thumbnail((n*w, image.size[1])) image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True) image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取标准缩略图 # 获取标准缩略图
@router.get("/{id}@{n}x{w}.{ext}", summary="获取缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图") @router.get("/{id}@{n}x{w}.{ext}", summary="获取缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_thumbnail(id:int, n:int, w:int, ext:str): def get_image_thumbnail(id:int, n:int, w:int, ext:str):
# 判断图片是否已经生成 with pool.connection() as conn:
img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}" with conn.cursor() as cursor:
if os.path.exists(img_path): # 判断图片是否已经生成
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}"
# 从数据库获取原图地址 if os.path.exists(img_path):
cursor = get_cursor() return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") # 从数据库获取原图地址
img = cursor.fetchone() cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
cursor.close() img = cursor.fetchone()
if img is None: if img is None:
print('图片不存在:', id) print('图片不存在:', id)
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
image = download_image(img['content']) image = download_image(img['content'])
if not image: if not image:
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
image.thumbnail((n*w, image.size[1])) image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True) image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取标准原尺寸图 # 获取标准原尺寸图
@router.get("/{id}.{ext}", summary="获取标准原尺寸图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 无后缀获取原图") @router.get("/{id}.{ext}", summary="获取标准原尺寸图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 无后缀获取原图")
def get_image(id: int = 824, ext: str = 'webp'): def get_image(id: int = 824, ext: str = 'webp'):
# 判断图片是否已经生成 with pool.connection() as conn:
img_path = f"{IMAGES_PATH}/{id}.{ext}" with conn.cursor() as cursor:
if os.path.exists(img_path): # 判断图片是否已经生成
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") img_path = f"{IMAGES_PATH}/{id}.{ext}"
# 从数据库获取原图地址 if os.path.exists(img_path):
cursor = get_cursor() return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") # 从数据库获取原图地址
img = cursor.fetchone() cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
cursor.close() img = cursor.fetchone()
if img is None: if img is None:
print('图片不存在:', id) print('图片不存在:', id)
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
image = download_image(img['content']) image = download_image(img['content'])
if not image: if not image:
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
image.save(img_path, ext, save_all=True) image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")

View File

@@ -10,7 +10,7 @@ from towhee import pipe, ops
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
from models.milvus import get_collection, collection_name from models.milvus import get_collection, collection_name
from models.mysql import get_cursor, conn from models.mysql import pool
from configs.config import UPLOAD_PATH from configs.config import UPLOAD_PATH
from utilities.download import download_image from utilities.download import download_image
@@ -39,111 +39,108 @@ async def create_index():
# 重建指定图像的向量 # 重建指定图像的向量
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量') @router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
async def rewrite_image(image_id: int): async def rewrite_image(image_id: int):
print('START', image_id, '重建向量') with pool.connection() as conn:
cursor = get_cursor() with conn.cursor() as cursor:
cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}") print('START', image_id, '重建向量')
img = cursor.fetchone() cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
cursor.close() img = cursor.fetchone()
if img is None: if img is None:
print('mysql中原始图片不存在:', image_id) print('mysql中原始图片不存在:', image_id)
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content'])) img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
print('img_path', img_path) print('img_path', img_path)
image = download_image(img['content']) image = download_image(img['content'])
if image is None: if image is None:
print('图片下载失败:', img['content']) print('图片下载失败:', img['content'])
return Response('图片下载失败', status_code=404) return Response('图片下载失败', status_code=404)
image.save(img_path, 'png', save_all=True) image.save(img_path, 'png', save_all=True)
with Image.open(img_path) as imgx:
with Image.open(img_path) as imgx: feat = RESNET50(imgx).get()[0]
feat = RESNET50(imgx).get()[0] collection = get_collection(collection_name)
collection.delete(expr=f'id in [{image_id}]')
collection = get_collection(collection_name) rest = collection.insert([[image_id], [feat], [img['article_id']]])
collection.delete(expr=f'id in [{image_id}]') os.remove(img_path)
rest = collection.insert([[image_id], [feat], [img['article_id']]]) print('END', image_id, '重建向量', rest.primary_keys)
os.remove(img_path) return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()}
print('END', image_id, '重建向量', rest.primary_keys)
return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()}
# 获取相似(废弃) # 获取相似(废弃)
@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片') @router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片')
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20): async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
collection = get_collection(collection_name) with pool.connection() as conn:
result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1) collection = get_collection(collection_name)
# 如果没有结果, 则重新生成记录 result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1)
if len(result) == 0: # 如果没有结果, 则重新生成记录
cursor = get_cursor() if len(result) == 0:
cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}") with conn.cursor() as cursor:
img = cursor.fetchone() cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
cursor.close() img = cursor.fetchone()
if img is None: if img is None:
print('mysql 中图片不存在:', image_id) print('mysql 中图片不存在:', image_id)
return Response('图片不存在', status_code=404) return Response('图片不存在', status_code=404)
img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content'])) img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
image = download_image(img['content']) image = download_image(img['content'])
if image is None: if image is None:
print('图片下载失败:', img['content']) print('图片下载失败:', img['content'])
return Response('图片下载失败', status_code=404) return Response('图片下载失败', status_code=404)
image.save(img_path, 'png', save_all=True) image.save(img_path, 'png', save_all=True)
with Image.open(img_path) as imgx: with Image.open(img_path) as imgx:
feat = RESNET50(imgx).get()[0] feat = RESNET50(imgx).get()[0]
# 移除可能存在的旧记录, 换上新的 # 移除可能存在的旧记录, 换上新的
collection.delete(expr=f'id in [{image_id}]') collection.delete(expr=f'id in [{image_id}]')
collection.insert([[image_id], [feat], [img['article_id']]]) collection.insert([[image_id], [feat], [img['article_id']]])
os.remove(img_path) os.remove(img_path)
print('生成') print('生成')
else: else:
print('通过') print('通过')
feat = result[0]['embedding'] feat = result[0]['embedding']
res = collection.search([feat],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0] res = collection.search([feat],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0]
# 翻页(截取有效范围, page * pageize) # 翻页(截取有效范围, page * pageize)
ope = page*pageSize-pageSize ope = page*pageSize-pageSize
end = page*pageSize end = page*pageSize
next = False
# 为数据附加信息
ids = [i.id for i in res] # 获取所有ID
if len(res) <= end:
ids = ids[ope:]
next = False next = False
else: # 为数据附加信息
ids = ids[ope:end] ids = [i.id for i in res] # 获取所有ID
next = True if len(res) <= end:
str_ids = str(ids).replace('[', '').replace(']', '') ids = ids[ope:]
if str_ids == '': next = False
print('没有更多数据了') else:
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []} ids = ids[ope:end]
cursor = get_cursor() next = True
cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})") str_ids = str(ids).replace('[', '').replace(']', '')
imgs = cursor.fetchall() if str_ids == '':
if len(imgs) == 0: print('没有更多数据了')
return imgs return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
# 获取用户ID和文章ID with conn.cursor() as cursor:
uids = list(set([x['user_id'] for x in imgs])) cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})")
tids = list(set([x['article_id'] for x in imgs])) imgs = cursor.fetchall()
# 获取用户信息 if len(imgs) == 0:
cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})") return imgs
users = cursor.fetchall() # 获取用户ID和文章ID
# 获取文章信息 uids = list(set([x['user_id'] for x in imgs]))
cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})") tids = list(set([x['article_id'] for x in imgs]))
articles = cursor.fetchall() # 获取用户信息
cursor.close() cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
# 合并信息 users = cursor.fetchall()
user, article = {}, {} # 获取文章信息
for x in users: user[x['id']] = x cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
for x in articles: article[x['id']] = x articles = cursor.fetchall()
for x in imgs: # 合并信息
x['article'] = article[x['article_id']] user, article = {}, {}
x['user'] = user[x['user_id']] for x in users: user[x['id']] = x
x['distance'] = [i.distance for i in res if i.id == x['id']][0] for x in articles: article[x['id']] = x
if x['praise_count'] == None: x['praise_count'] = 0 for x in imgs:
if x['collect_count'] == None: x['collect_count'] = 0 x['article'] = article[x['article_id']]
# 将字段名转换为驼峰 x['user'] = user[x['user_id']]
x['createTime'] = x.pop('create_time') x['distance'] = [i.distance for i in res if i.id == x['id']][0]
x['updateTime'] = x.pop('update_time') if x['praise_count'] == None: x['praise_count'] = 0
# 对 imgs 重新排序(按照 distance 字段) if x['collect_count'] == None: x['collect_count'] = 0
imgs = sorted(imgs, key=lambda x: x['distance']) # 将字段名转换为驼峰
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': next, 'list': imgs} x['createTime'] = x.pop('create_time')
x['updateTime'] = x.pop('update_time')
# 对 imgs 重新排序(按照 distance 字段)
imgs = sorted(imgs, key=lambda x: x['distance'])
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': next, 'list': imgs}
@router.post(path='', summary='以图搜图', description='上传图片进行搜索') @router.post(path='', summary='以图搜图', description='上传图片进行搜索')
@@ -159,33 +156,28 @@ async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize:
if not ids: if not ids:
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []} return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
cursor = get_cursor() with pool.connection() as conn:
cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})") with conn.cursor() as cursor:
imgs = cursor.fetchall() cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
imgs = cursor.fetchall()
if not imgs: if not imgs:
return imgs return imgs
uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs))
uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs)) cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})") users = cursor.fetchall()
users = cursor.fetchall() cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})") articles = cursor.fetchall()
articles = cursor.fetchall() user, article = {x['id']: x for x in users}, {x['id']: x for x in articles}
cursor.close() for x in imgs:
x.update({
user, article = {x['id']: x for x in users}, {x['id']: x for x in articles} 'article': article[x['article_id']],
for x in imgs: 'user': user[x['user_id']],
x.update({ 'distance': [i.distance for i in res if i.id == x['id']][0],
'article': article[x['article_id']], 'praise_count': x.get('praise_count', 0),
'user': user[x['user_id']], 'collect_count': x.get('collect_count', 0)
#'distance': next(i.distance for i in res if i.id == x['id'], 0), })
'distance': [i.distance for i in res if i.id == x['id']][0], imgs.sort(key=lambda x: x['distance'])
'praise_count': x.get('praise_count', 0), return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs}
'collect_count': x.get('collect_count', 0)
})
imgs.sort(key=lambda x: x['distance'])
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs}
@router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像") @router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像")
async def delete_images(thread_id: str): async def delete_images(thread_id: str):

View File

@@ -1,19 +1,19 @@
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from models.mysql import conn, get_cursor from models.mysql import pool
router = APIRouter() router = APIRouter()
@router.get('/{user_id}/collect', summary="用户收藏记录", description="获取指定用户收藏记录") @router.get('/{user_id}/collect', summary="用户收藏记录", description="获取指定用户收藏记录")
def get_user_collect(user_id:int): def get_user_collect(user_id:int):
# TODO: 需要验证权限 with pool.connection() as conn:
cursor = get_cursor() with conn.cursor() as cursor:
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'") cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'")
data = cursor.fetchall() data = cursor.fetchall()
if not data: if not data:
return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []} return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
data = [str(item['content']) for item in data] data = [str(item['content']) for item in data]
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data) cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
data = cursor.fetchall() data = cursor.fetchall()
data = [str(item['id']) for item in data] data = [str(item['id']) for item in data]
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data } return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from fastapi import APIRouter, HTTPException, Header from fastapi import APIRouter, HTTPException, Header
from models.mysql import conn, get_cursor from models.mysql import pool
router = APIRouter() router = APIRouter()
@@ -9,37 +9,38 @@ router = APIRouter()
# 获取当前用户收藏记录 # 获取当前用户收藏记录
@router.get('', summary='自己的收藏记录', description='获取自己的收藏记录, 用于判断是否收藏(headers中必须附带token)') @router.get('', summary='自己的收藏记录', description='获取自己的收藏记录, 用于判断是否收藏(headers中必须附带token)')
def get_self_collect(token: Optional[str] = Header()): def get_self_collect(token: Optional[str] = Header()):
print('token: ', token) with pool.connection() as conn:
cursor = get_cursor() with conn.cursor() as cursor:
# 查询用户ID # 查询用户ID
cursor.execute(f"SELECT user_id FROM web_auth WHERE token={token} limit 1") cursor.execute(f"SELECT user_id FROM web_auth WHERE token={token} limit 1")
data = cursor.fetchone() data = cursor.fetchone()
print('auth: ', data) print('auth: ', data)
if not data: if not data:
raise HTTPException(status_code=401, detail="用户未登录") raise HTTPException(status_code=401, detail="用户未登录")
user_id = data['user_id'] user_id = data['user_id']
# 查询收藏记录 # 查询收藏记录
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'") cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'")
data = cursor.fetchall() # 获取所有记录列表 data = cursor.fetchall() # 获取所有记录列表
data = [str(item['content']) for item in data] # 转换为数组 data = [str(item['content']) for item in data] # 转换为数组
# 查询图片ID(对特殊字符安全转义) # 查询图片ID(对特殊字符安全转义)
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data) cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
data = cursor.fetchall() data = cursor.fetchall()
data = [str(item['id']) for item in data] data = [str(item['id']) for item in data]
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data } return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }
# 获取指定用户收藏记录 # 获取指定用户收藏记录
@router.get('/{user_id}', summary='指定用的户收藏记录', description='获取指定用户收藏记录(仅测试用)') @router.get('/{user_id}', summary='指定用的户收藏记录', description='获取指定用户收藏记录(仅测试用)')
def get_user_collect(user_id: int): def get_user_collect(user_id: int):
cursor = get_cursor() with pool.connection() as conn:
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type=1") with conn.cursor() as cursor:
data = cursor.fetchall() # 获取所有记录列表 cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type=1")
data = [str(item['content']) for item in data] # 转换为数组 data = cursor.fetchall() # 获取所有记录列表
if not data: data = [str(item['content']) for item in data] # 转换为数组
return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []} if not data:
# 查询图片ID(对特殊字符安全转义) return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data) # 查询图片ID(对特殊字符安全转义)
data = cursor.fetchall() cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
data = [str(item['id']) for item in data] data = cursor.fetchall()
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data} data = [str(item['id']) for item in data]
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data}