From bdd08285e73564c86be905ac77ec5129673b756f Mon Sep 17 00:00:00 2001 From: satori Date: Sun, 10 Nov 2024 22:02:45 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=BF=E6=8D=A2=E8=BF=9E=E6=8E=A5=E6=B1=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/mysql.py | 64 ++++++--- routers/img.py | 301 ++++++++++++++++++++-------------------- routers/reverse.py | 246 ++++++++++++++++---------------- routers/user.py | 24 ++-- routers/user_collect.py | 61 ++++---- 5 files changed, 351 insertions(+), 345 deletions(-) diff --git a/models/mysql.py b/models/mysql.py index 4c81aa5..da84b3b 100644 --- a/models/mysql.py +++ b/models/mysql.py @@ -1,27 +1,47 @@ import pymysql from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS +from dbutils.pooled_db import PooledDB -# 创建 MySQL 连接 -def create_connection(): - return pymysql.connect( - host=MYSQL_HOST, - user=MYSQL_USER, - port=MYSQL_PORT, # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST - password=MYSQL_PASS, - database=MYSQL_NAME, - local_infile=True, - cursorclass=pymysql.cursors.DictCursor - ) -# 连接 MySQL (开启 MySQL 服务) -conn = create_connection() +# 创建数据库连接池 +pool = PooledDB( + creator=pymysql, # 使用 pymysql 作为数据库驱动 + maxconnections=20, # 最大连接数 + mincached=2, # 初始化时,连接池中至少创建的空闲连接 + maxcached=5, # 连接池中最多空闲连接数 + blocking=True, # 超过最大连接数时,是否阻塞 + maxusage=None, # 单个连接的最大复用次数 + ping=0, # 设置连接是否检查 + host=MYSQL_HOST, + port=MYSQL_PORT, + user=MYSQL_USER, + password=MYSQL_PASS, + database=MYSQL_NAME, + charset='utf8mb4' +) -# 获取 MySQL 连接 -def get_cursor(): - global conn - try: - conn.ping() - return conn.cursor() - except Exception: - conn = create_connection() - return conn.cursor() +## 创建 MySQL 连接 +#def create_connection(): +# return pymysql.connect( +# host=MYSQL_HOST, +# user=MYSQL_USER, +# port=MYSQL_PORT, # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST +# password=MYSQL_PASS, +# database=MYSQL_NAME, +# local_infile=True, +# cursorclass=pymysql.cursors.DictCursor +# ) +# +## 连接 MySQL (开启 MySQL 服务) +#conn = create_connection() +# +## 获取 MySQL 连接 +#def get_cursor(): +# global conn +# try: +# conn.ping() +# return conn.cursor() +# except Exception: +# conn = create_connection() +# return conn.cursor() +# \ No newline at end of file diff --git a/routers/img.py b/routers/img.py index 1af387a..1effd2c 100644 --- a/routers/img.py +++ b/routers/img.py @@ -8,7 +8,7 @@ import _thread as thread from fastapi import APIRouter, HTTPException, Response from urllib.parse import unquote from configs.config import IMAGES_PATH -from models.mysql import get_cursor, conn +from models.mysql import pool from utilities.download import download_image, generate_thumbnail @@ -18,116 +18,109 @@ router = APIRouter() # 预热图片(获取一次图片, 遍历图片表, 检查OSS中所有被预定的尺寸是否存在, 不存在则生成) @router.get("/warm", summary="预热图片", description="预热图片") def warm_image(op:int=0, end:int=10, version:str='0'): - cursor = get_cursor() - cursor.execute(f"SELECT * FROM `web_images` limit {op}, {end}") - for img in cursor.fetchall(): - # 如果CPU使用率大于50%, 则等待, 直到CPU使用率小于50% - while statistics.mean(psutil.cpu_percent(interval=1, percpu=True)) > 50: - print(statistics.mean(psutil.cpu_percent(interval=1, percpu=True)), '等待CPU释放...') - time.sleep(2) - - # 如果内存剩余小于1G, 则等待, 直到内存剩余大于1G - while psutil.virtual_memory().available < 1024 * 1024 * 1024: - print(psutil.virtual_memory().available, '等待内存释放...') - time.sleep(2) - - # CPU使用率已降低, 开始处理图片 - image = download_image(img['content']) # 从OSS下载原图 - if not image: - print('跳过不存在的图片:', img['content']) - continue - - # 创建新线程处理图片 - try: - print('开始处理图片:', img['content']) - thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 3, 328, 'webp')) - thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 2, 328, 'webp')) - thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 1, 328, 'webp')) - except: - print('无法启动线程') - cursor.close() - return Response('预热成功', status_code=200, media_type='text/plain', headers={'Content-Type': 'text/plain; charset=utf-8'}) + with pool.connection() as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT * FROM `web_images` limit {op}, {end}") + for img in cursor.fetchall(): + # 如果CPU使用率大于50%, 则等待, 直到CPU使用率小于50% + while statistics.mean(psutil.cpu_percent(interval=1, percpu=True)) > 50: + print(statistics.mean(psutil.cpu_percent(interval=1, percpu=True)), '等待CPU释放...') + time.sleep(2) + # 如果内存剩余小于1G, 则等待, 直到内存剩余大于1G + while psutil.virtual_memory().available < 1024 * 1024 * 1024: + print(psutil.virtual_memory().available, '等待内存释放...') + time.sleep(2) + # CPU使用率已降低, 开始处理图片 + image = download_image(img['content']) # 从OSS下载原图 + if not image: + print('跳过不存在的图片:', img['content']) + continue + # 创建新线程处理图片 + try: + print('开始处理图片:', img['content']) + thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 3, 328, 'webp')) + thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 2, 328, 'webp')) + thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 1, 328, 'webp')) + except: + print('无法启动线程') + return Response('预热成功', status_code=200, media_type='text/plain', headers={'Content-Type': 'text/plain; charset=utf-8'}) # 获取非标准类缩略图 @router.get("/{type}-{id}-{version}@{n}x{w}.{ext}", summary="获取非标准类缩略图", description="/img/article-233-version@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图") def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:str): - img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}" - if os.path.exists(img_path): - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") - if type == 'ad' or type == 'article' or type == 'article_attribute': - cursor = get_cursor() - count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}") - img = cursor.fetchone() - cursor.close() - if img is None: - print('图片不存在:', count) - return Response('图片不存在', status_code=404) - url = img['image'] - elif type == 'url': - id = unquote(id, 'utf-8') - id = id.replace(' ','+') - url = unquote(base64.b64decode(id)) - print(url) - elif type == 'avatar': - cursor = get_cursor() - count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}") - user = cursor.fetchone() - cursor.close() - if user is None: - print('用户不存在:', count) - return Response('用户不存在', status_code=404) - url = user['avatar'] - else: - print('图片类型不存在:', type) - return Response('图片类型不存在', status_code=404) - image = download_image(url) - if not image: - return Response('图片不存在', status_code=404) - # 如果是 avatar, 则裁剪为正方形 - if type == 'avatar': - px = image.size[0] if image.size[0] < image.size[1] else image.size[1] - image = image.crop((0, 0, px, px)) - image.thumbnail((n*w, image.size[1])) - image.save(img_path, ext, save_all=True) - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + with pool.connection() as conn: + with conn.cursor() as cursor: + img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}" + if os.path.exists(img_path): + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + if type == 'ad' or type == 'article' or type == 'article_attribute': + count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}") + img = cursor.fetchone() + if img is None: + print('图片不存在:', count) + return Response('图片不存在', status_code=404) + url = img['image'] + elif type == 'url': + id = unquote(id, 'utf-8') + id = id.replace(' ','+') + url = unquote(base64.b64decode(id)) + print(url) + elif type == 'avatar': + count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}") + user = cursor.fetchone() + if user is None: + print('用户不存在:', count) + return Response('用户不存在', status_code=404) + url = user['avatar'] + else: + print('图片类型不存在:', type) + return Response('图片类型不存在', status_code=404) + image = download_image(url) + if not image: + return Response('图片不存在', status_code=404) + # 如果是 avatar, 则裁剪为正方形 + if type == 'avatar': + px = image.size[0] if image.size[0] < image.size[1] else image.size[1] + image = image.crop((0, 0, px, px)) + image.thumbnail((n*w, image.size[1])) + image.save(img_path, ext, save_all=True) + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") # 获取非标准类原尺寸图 @router.get("/{type}-{id}-{version}.{ext}", summary="获取文章缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图") def get_image_type(type:str, id:str, version:str, ext:str): - img_path = f"{IMAGES_PATH}/{type}-{id}-{version}.{ext}" - if os.path.exists(img_path): - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") - if type == 'ad' or type == 'article' or type == 'article_attribute': - cursor = get_cursor() - count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}") - img = cursor.fetchone() - cursor.close() - if img is None: - print('图片不存在:', count) - return Response('图片不存在', status_code=404) - url = img['image'] - elif type == 'url': - id = unquote(id, 'utf-8') - id = id.replace(' ','+') - url = unquote(base64.b64decode(id)) - print("url:", url) - elif type == 'avatar': - cursor = get_cursor() - count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}") - user = cursor.fetchone() - cursor.close() - if user is None: - print('用户不存在:', count) - return Response('用户不存在', status_code=404) - url = user['avatar'] - else: - print('图片类型不存在:', type) - return Response('图片类型不存在', status_code=404) - image = download_image(url) - image.save(img_path, ext, save_all=True) - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + with pool.connection() as conn: + with conn.cursor() as cursor: + img_path = f"{IMAGES_PATH}/{type}-{id}-{version}.{ext}" + if os.path.exists(img_path): + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + if type == 'ad' or type == 'article' or type == 'article_attribute': + count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}") + img = cursor.fetchone() + if img is None: + print('图片不存在:', count) + return Response('图片不存在', status_code=404) + url = img['image'] + elif type == 'url': + id = unquote(id, 'utf-8') + id = id.replace(' ','+') + url = unquote(base64.b64decode(id)) + print("url:", url) + elif type == 'avatar': + count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}") + user = cursor.fetchone() + if user is None: + print('用户不存在:', count) + return Response('用户不存在', status_code=404) + url = user['avatar'] + else: + print('图片类型不存在:', type) + return Response('图片类型不存在', status_code=404) + image = download_image(url) + image.save(img_path, ext, save_all=True) + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") # 通过url获取图片 @@ -149,66 +142,66 @@ def get_image_url(url:str, n:int, w:int, ext:str): # 获取标准缩略图(带版本号) @router.get("/{id}-{version}@{n}x{w}.{ext}", summary="获取缩略图(带版本号)", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图") def get_image_thumbnail(id:int, version:str, n:int, w:int, ext:str): - # 判断图片是否已经生成 - img_path = f"{IMAGES_PATH}/{id}-{version}@{n}x{w}.{ext}" - if os.path.exists(img_path): - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") - # 从数据库获取原图地址 - cursor = get_cursor() - cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") - img = cursor.fetchone() - cursor.close() - if img is None: - print('图片不存在:', id) - return Response('图片不存在', status_code=404) - image = download_image(img['content']) - if not image: - return Response('图片不存在', status_code=404) - image.thumbnail((n*w, image.size[1])) - image.save(img_path, ext, save_all=True) - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + with pool.connection() as conn: + with conn.cursor() as cursor: + # 判断图片是否已经生成 + img_path = f"{IMAGES_PATH}/{id}-{version}@{n}x{w}.{ext}" + if os.path.exists(img_path): + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + # 从数据库获取原图地址 + cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") + img = cursor.fetchone() + if img is None: + print('图片不存在:', id) + return Response('图片不存在', status_code=404) + image = download_image(img['content']) + if not image: + return Response('图片不存在', status_code=404) + image.thumbnail((n*w, image.size[1])) + image.save(img_path, ext, save_all=True) + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") # 获取标准缩略图 @router.get("/{id}@{n}x{w}.{ext}", summary="获取缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图") def get_image_thumbnail(id:int, n:int, w:int, ext:str): - # 判断图片是否已经生成 - img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}" - if os.path.exists(img_path): - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") - # 从数据库获取原图地址 - cursor = get_cursor() - cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") - img = cursor.fetchone() - cursor.close() - if img is None: - print('图片不存在:', id) - return Response('图片不存在', status_code=404) - image = download_image(img['content']) - if not image: - return Response('图片不存在', status_code=404) - image.thumbnail((n*w, image.size[1])) - image.save(img_path, ext, save_all=True) - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + with pool.connection() as conn: + with conn.cursor() as cursor: + # 判断图片是否已经生成 + img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}" + if os.path.exists(img_path): + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + # 从数据库获取原图地址 + cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") + img = cursor.fetchone() + if img is None: + print('图片不存在:', id) + return Response('图片不存在', status_code=404) + image = download_image(img['content']) + if not image: + return Response('图片不存在', status_code=404) + image.thumbnail((n*w, image.size[1])) + image.save(img_path, ext, save_all=True) + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") # 获取标准原尺寸图 @router.get("/{id}.{ext}", summary="获取标准原尺寸图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 无后缀获取原图") def get_image(id: int = 824, ext: str = 'webp'): - # 判断图片是否已经生成 - img_path = f"{IMAGES_PATH}/{id}.{ext}" - if os.path.exists(img_path): - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") - # 从数据库获取原图地址 - cursor = get_cursor() - cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") - img = cursor.fetchone() - cursor.close() - if img is None: - print('图片不存在:', id) - return Response('图片不存在', status_code=404) - image = download_image(img['content']) - if not image: - return Response('图片不存在', status_code=404) - image.save(img_path, ext, save_all=True) - return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + with pool.connection() as conn: + with conn.cursor() as cursor: + # 判断图片是否已经生成 + img_path = f"{IMAGES_PATH}/{id}.{ext}" + if os.path.exists(img_path): + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") + # 从数据库获取原图地址 + cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}") + img = cursor.fetchone() + if img is None: + print('图片不存在:', id) + return Response('图片不存在', status_code=404) + image = download_image(img['content']) + if not image: + return Response('图片不存在', status_code=404) + image.save(img_path, ext, save_all=True) + return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}") diff --git a/routers/reverse.py b/routers/reverse.py index d986c65..7759b4c 100644 --- a/routers/reverse.py +++ b/routers/reverse.py @@ -10,7 +10,7 @@ from towhee import pipe, ops from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header from models.milvus import get_collection, collection_name -from models.mysql import get_cursor, conn +from models.mysql import pool from configs.config import UPLOAD_PATH from utilities.download import download_image @@ -39,111 +39,108 @@ async def create_index(): # 重建指定图像的向量 @router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量') async def rewrite_image(image_id: int): - print('START', image_id, '重建向量') - cursor = get_cursor() - cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}") - img = cursor.fetchone() - cursor.close() - if img is None: - print('mysql中原始图片不存在:', image_id) - return Response('图片不存在', status_code=404) - img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content'])) - print('img_path', img_path) - image = download_image(img['content']) - if image is None: - print('图片下载失败:', img['content']) - return Response('图片下载失败', status_code=404) - image.save(img_path, 'png', save_all=True) - - with Image.open(img_path) as imgx: - feat = RESNET50(imgx).get()[0] - - collection = get_collection(collection_name) - collection.delete(expr=f'id in [{image_id}]') - rest = collection.insert([[image_id], [feat], [img['article_id']]]) - os.remove(img_path) - print('END', image_id, '重建向量', rest.primary_keys) - return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()} + with pool.connection() as conn: + with conn.cursor() as cursor: + print('START', image_id, '重建向量') + cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}") + img = cursor.fetchone() + if img is None: + print('mysql中原始图片不存在:', image_id) + return Response('图片不存在', status_code=404) + img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content'])) + print('img_path', img_path) + image = download_image(img['content']) + if image is None: + print('图片下载失败:', img['content']) + return Response('图片下载失败', status_code=404) + image.save(img_path, 'png', save_all=True) + with Image.open(img_path) as imgx: + feat = RESNET50(imgx).get()[0] + collection = get_collection(collection_name) + collection.delete(expr=f'id in [{image_id}]') + rest = collection.insert([[image_id], [feat], [img['article_id']]]) + os.remove(img_path) + print('END', image_id, '重建向量', rest.primary_keys) + return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()} # 获取相似(废弃) @router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片') async def similar_images(image_id: int, page: int = 1, pageSize: int = 20): - collection = get_collection(collection_name) - result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1) - # 如果没有结果, 则重新生成记录 - if len(result) == 0: - cursor = get_cursor() - cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}") - img = cursor.fetchone() - cursor.close() - if img is None: - print('mysql 中图片不存在:', image_id) - return Response('图片不存在', status_code=404) - img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content'])) - image = download_image(img['content']) - if image is None: - print('图片下载失败:', img['content']) - return Response('图片下载失败', status_code=404) - image.save(img_path, 'png', save_all=True) - with Image.open(img_path) as imgx: - feat = RESNET50(imgx).get()[0] - # 移除可能存在的旧记录, 换上新的 - collection.delete(expr=f'id in [{image_id}]') - collection.insert([[image_id], [feat], [img['article_id']]]) - os.remove(img_path) - print('生成') - else: - print('通过') - feat = result[0]['embedding'] - res = collection.search([feat],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0] - # 翻页(截取有效范围, page * pageize) - ope = page*pageSize-pageSize - end = page*pageSize - next = False - # 为数据附加信息 - ids = [i.id for i in res] # 获取所有ID - if len(res) <= end: - ids = ids[ope:] + with pool.connection() as conn: + collection = get_collection(collection_name) + result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1) + # 如果没有结果, 则重新生成记录 + if len(result) == 0: + with conn.cursor() as cursor: + cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}") + img = cursor.fetchone() + if img is None: + print('mysql 中图片不存在:', image_id) + return Response('图片不存在', status_code=404) + img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content'])) + image = download_image(img['content']) + if image is None: + print('图片下载失败:', img['content']) + return Response('图片下载失败', status_code=404) + image.save(img_path, 'png', save_all=True) + with Image.open(img_path) as imgx: + feat = RESNET50(imgx).get()[0] + # 移除可能存在的旧记录, 换上新的 + collection.delete(expr=f'id in [{image_id}]') + collection.insert([[image_id], [feat], [img['article_id']]]) + os.remove(img_path) + print('生成') + else: + print('通过') + feat = result[0]['embedding'] + res = collection.search([feat],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0] + # 翻页(截取有效范围, page * pageize) + ope = page*pageSize-pageSize + end = page*pageSize next = False - else: - ids = ids[ope:end] - next = True - str_ids = str(ids).replace('[', '').replace(']', '') - if str_ids == '': - print('没有更多数据了') - return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []} - cursor = get_cursor() - cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})") - imgs = cursor.fetchall() - if len(imgs) == 0: - return imgs - # 获取用户ID和文章ID - uids = list(set([x['user_id'] for x in imgs])) - tids = list(set([x['article_id'] for x in imgs])) - # 获取用户信息 - cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})") - users = cursor.fetchall() - # 获取文章信息 - cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})") - articles = cursor.fetchall() - cursor.close() - # 合并信息 - user, article = {}, {} - for x in users: user[x['id']] = x - for x in articles: article[x['id']] = x - for x in imgs: - x['article'] = article[x['article_id']] - x['user'] = user[x['user_id']] - x['distance'] = [i.distance for i in res if i.id == x['id']][0] - if x['praise_count'] == None: x['praise_count'] = 0 - if x['collect_count'] == None: x['collect_count'] = 0 - # 将字段名转换为驼峰 - x['createTime'] = x.pop('create_time') - x['updateTime'] = x.pop('update_time') - # 对 imgs 重新排序(按照 distance 字段) - imgs = sorted(imgs, key=lambda x: x['distance']) - return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': next, 'list': imgs} + # 为数据附加信息 + ids = [i.id for i in res] # 获取所有ID + if len(res) <= end: + ids = ids[ope:] + next = False + else: + ids = ids[ope:end] + next = True + str_ids = str(ids).replace('[', '').replace(']', '') + if str_ids == '': + print('没有更多数据了') + return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []} + with conn.cursor() as cursor: + cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})") + imgs = cursor.fetchall() + if len(imgs) == 0: + return imgs + # 获取用户ID和文章ID + uids = list(set([x['user_id'] for x in imgs])) + tids = list(set([x['article_id'] for x in imgs])) + # 获取用户信息 + cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})") + users = cursor.fetchall() + # 获取文章信息 + cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})") + articles = cursor.fetchall() + # 合并信息 + user, article = {}, {} + for x in users: user[x['id']] = x + for x in articles: article[x['id']] = x + for x in imgs: + x['article'] = article[x['article_id']] + x['user'] = user[x['user_id']] + x['distance'] = [i.distance for i in res if i.id == x['id']][0] + if x['praise_count'] == None: x['praise_count'] = 0 + if x['collect_count'] == None: x['collect_count'] = 0 + # 将字段名转换为驼峰 + x['createTime'] = x.pop('create_time') + x['updateTime'] = x.pop('update_time') + # 对 imgs 重新排序(按照 distance 字段) + imgs = sorted(imgs, key=lambda x: x['distance']) + return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': next, 'list': imgs} @router.post(path='', summary='以图搜图', description='上传图片进行搜索') @@ -159,33 +156,28 @@ async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: if not ids: return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []} - cursor = get_cursor() - cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})") - imgs = cursor.fetchall() - - if not imgs: - return imgs - - uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs)) - cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})") - users = cursor.fetchall() - cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})") - articles = cursor.fetchall() - cursor.close() - - user, article = {x['id']: x for x in users}, {x['id']: x for x in articles} - for x in imgs: - x.update({ - 'article': article[x['article_id']], - 'user': user[x['user_id']], - #'distance': next(i.distance for i in res if i.id == x['id'], 0), - 'distance': [i.distance for i in res if i.id == x['id']][0], - 'praise_count': x.get('praise_count', 0), - 'collect_count': x.get('collect_count', 0) - }) - - imgs.sort(key=lambda x: x['distance']) - return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs} + with pool.connection() as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})") + imgs = cursor.fetchall() + if not imgs: + return imgs + uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs)) + cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})") + users = cursor.fetchall() + cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})") + articles = cursor.fetchall() + user, article = {x['id']: x for x in users}, {x['id']: x for x in articles} + for x in imgs: + x.update({ + 'article': article[x['article_id']], + 'user': user[x['user_id']], + 'distance': [i.distance for i in res if i.id == x['id']][0], + 'praise_count': x.get('praise_count', 0), + 'collect_count': x.get('collect_count', 0) + }) + imgs.sort(key=lambda x: x['distance']) + return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs} @router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像") async def delete_images(thread_id: str): diff --git a/routers/user.py b/routers/user.py index eb16016..d6d94bf 100644 --- a/routers/user.py +++ b/routers/user.py @@ -1,19 +1,19 @@ from fastapi import APIRouter, HTTPException -from models.mysql import conn, get_cursor +from models.mysql import pool router = APIRouter() @router.get('/{user_id}/collect', summary="用户收藏记录", description="获取指定用户收藏记录") def get_user_collect(user_id:int): - # TODO: 需要验证权限 - cursor = get_cursor() - cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'") - data = cursor.fetchall() - if not data: - return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []} - data = [str(item['content']) for item in data] - cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data) - data = cursor.fetchall() - data = [str(item['id']) for item in data] - return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data } + with pool.connection() as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'") + data = cursor.fetchall() + if not data: + return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []} + data = [str(item['content']) for item in data] + cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data) + data = cursor.fetchall() + data = [str(item['id']) for item in data] + return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data } diff --git a/routers/user_collect.py b/routers/user_collect.py index e8d912a..68c9d21 100644 --- a/routers/user_collect.py +++ b/routers/user_collect.py @@ -1,6 +1,6 @@ from typing import Optional from fastapi import APIRouter, HTTPException, Header -from models.mysql import conn, get_cursor +from models.mysql import pool router = APIRouter() @@ -9,37 +9,38 @@ router = APIRouter() # 获取当前用户收藏记录 @router.get('', summary='自己的收藏记录', description='获取自己的收藏记录, 用于判断是否收藏(headers中必须附带token)') def get_self_collect(token: Optional[str] = Header()): - print('token: ', token) - cursor = get_cursor() - # 查询用户ID - cursor.execute(f"SELECT user_id FROM web_auth WHERE token={token} limit 1") - data = cursor.fetchone() - print('auth: ', data) - if not data: - raise HTTPException(status_code=401, detail="用户未登录") - user_id = data['user_id'] - # 查询收藏记录 - cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'") - data = cursor.fetchall() # 获取所有记录列表 - data = [str(item['content']) for item in data] # 转换为数组 - # 查询图片ID(对特殊字符安全转义) - cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data) - data = cursor.fetchall() - data = [str(item['id']) for item in data] - return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data } + with pool.connection() as conn: + with conn.cursor() as cursor: + # 查询用户ID + cursor.execute(f"SELECT user_id FROM web_auth WHERE token={token} limit 1") + data = cursor.fetchone() + print('auth: ', data) + if not data: + raise HTTPException(status_code=401, detail="用户未登录") + user_id = data['user_id'] + # 查询收藏记录 + cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'") + data = cursor.fetchall() # 获取所有记录列表 + data = [str(item['content']) for item in data] # 转换为数组 + # 查询图片ID(对特殊字符安全转义) + cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data) + data = cursor.fetchall() + data = [str(item['id']) for item in data] + return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data } # 获取指定用户收藏记录 @router.get('/{user_id}', summary='指定用的户收藏记录', description='获取指定用户收藏记录(仅测试用)') def get_user_collect(user_id: int): - cursor = get_cursor() - cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type=1") - data = cursor.fetchall() # 获取所有记录列表 - data = [str(item['content']) for item in data] # 转换为数组 - if not data: - return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []} - # 查询图片ID(对特殊字符安全转义) - cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data) - data = cursor.fetchall() - data = [str(item['id']) for item in data] - return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data} + with pool.connection() as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type=1") + data = cursor.fetchall() # 获取所有记录列表 + data = [str(item['content']) for item in data] # 转换为数组 + if not data: + return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []} + # 查询图片ID(对特殊字符安全转义) + cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data) + data = cursor.fetchall() + data = [str(item['id']) for item in data] + return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data}