This commit is contained in:
2024-11-04 05:20:42 +08:00
parent e990473dcd
commit 07de4d5fd5
24 changed files with 1385 additions and 2 deletions

214
routers/img.py Normal file
View File

@@ -0,0 +1,214 @@
import os
import time
import base64
import psutil
import statistics
import _thread as thread
from fastapi import APIRouter, HTTPException, Response
from urllib.parse import unquote
from configs.config import IMAGES_PATH
from models.mysql import get_cursor, conn
from utilities.download import download_image, generate_thumbnail
router = APIRouter()
# 预热图片(获取一次图片, 遍历图片表, 检查OSS中所有被预定的尺寸是否存在, 不存在则生成)
@router.get("/warm", summary="预热图片", description="预热图片")
def warm_image(op:int=0, end:int=10, version:str='0'):
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` limit {op}, {end}")
for img in cursor.fetchall():
# 如果CPU使用率大于50%, 则等待, 直到CPU使用率小于50%
while statistics.mean(psutil.cpu_percent(interval=1, percpu=True)) > 50:
print(statistics.mean(psutil.cpu_percent(interval=1, percpu=True)), '等待CPU释放...')
time.sleep(2)
# 如果内存剩余小于1G, 则等待, 直到内存剩余大于1G
while psutil.virtual_memory().available < 1024 * 1024 * 1024:
print(psutil.virtual_memory().available, '等待内存释放...')
time.sleep(2)
# CPU使用率已降低, 开始处理图片
image = download_image(img['content']) # 从OSS下载原图
if not image:
print('跳过不存在的图片:', img['content'])
continue
# 创建新线程处理图片
try:
print('开始处理图片:', img['content'])
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 3, 328, 'webp'))
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 2, 328, 'webp'))
thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 1, 328, 'webp'))
except:
print('无法启动线程')
cursor.close()
return Response('预热成功', status_code=200, media_type='text/plain', headers={'Content-Type': 'text/plain; charset=utf-8'})
# 获取非标准类缩略图
@router.get("/{type}-{id}-{version}@{n}x{w}.{ext}", summary="获取非标准类缩略图", description="/img/article-233-version@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:str):
img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
if type == 'ad' or type == 'article' or type == 'article_attribute':
cursor = get_cursor()
count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', count)
return Response('图片不存在', status_code=404)
url = img['image']
elif type == 'url':
id = unquote(id, 'utf-8')
id = id.replace(' ','+')
url = unquote(base64.b64decode(id))
print(url)
elif type == 'avatar':
cursor = get_cursor()
count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
user = cursor.fetchone()
cursor.close()
if user is None:
print('用户不存在:', count)
return Response('用户不存在', status_code=404)
url = user['avatar']
else:
print('图片类型不存在:', type)
return Response('图片类型不存在', status_code=404)
image = download_image(url)
if not image:
return Response('图片不存在', status_code=404)
# 如果是 avatar, 则裁剪为正方形
if type == 'avatar':
px = image.size[0] if image.size[0] < image.size[1] else image.size[1]
image = image.crop((0, 0, px, px))
image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取非标准类原尺寸图
@router.get("/{type}-{id}-{version}.{ext}", summary="获取文章缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_type(type:str, id:str, version:str, ext:str):
img_path = f"{IMAGES_PATH}/{type}-{id}-{version}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
if type == 'ad' or type == 'article' or type == 'article_attribute':
cursor = get_cursor()
count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', count)
return Response('图片不存在', status_code=404)
url = img['image']
elif type == 'url':
id = unquote(id, 'utf-8')
id = id.replace(' ','+')
url = unquote(base64.b64decode(id))
print("url:", url)
elif type == 'avatar':
cursor = get_cursor()
count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
user = cursor.fetchone()
cursor.close()
if user is None:
print('用户不存在:', count)
return Response('用户不存在', status_code=404)
url = user['avatar']
else:
print('图片类型不存在:', type)
return Response('图片类型不存在', status_code=404)
image = download_image(url)
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 通过url获取图片
@router.get("/url-{url}@{n}x{w}.{ext}", summary="通过url获取图片", description="/img/article-233.webp")
def get_image_url(url:str, n:int, w:int, ext:str):
img_path = f"{IMAGES_PATH}/{type}-{url}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
url = unquote(url, 'utf-8').replace(' ','+')
url = unquote(base64.b64decode(url))
image = download_image(url)
if not image:
return Response('图片不存在', status_code=404)
image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取标准缩略图(带版本号)
@router.get("/{id}-{version}@{n}x{w}.{ext}", summary="获取缩略图(带版本号)", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_thumbnail(id:int, version:str, n:int, w:int, ext:str):
# 判断图片是否已经生成
img_path = f"{IMAGES_PATH}/{id}-{version}@{n}x{w}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 从数据库获取原图地址
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', id)
return Response('图片不存在', status_code=404)
image = download_image(img['content'])
if not image:
return Response('图片不存在', status_code=404)
image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取标准缩略图
@router.get("/{id}@{n}x{w}.{ext}", summary="获取缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
def get_image_thumbnail(id:int, n:int, w:int, ext:str):
# 判断图片是否已经生成
img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 从数据库获取原图地址
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', id)
return Response('图片不存在', status_code=404)
image = download_image(img['content'])
if not image:
return Response('图片不存在', status_code=404)
image.thumbnail((n*w, image.size[1]))
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 获取标准原尺寸图
@router.get("/{id}.{ext}", summary="获取标准原尺寸图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 无后缀获取原图")
def get_image(id: int = 824, ext: str = 'webp'):
# 判断图片是否已经生成
img_path = f"{IMAGES_PATH}/{id}.{ext}"
if os.path.exists(img_path):
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
# 从数据库获取原图地址
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('图片不存在:', id)
return Response('图片不存在', status_code=404)
image = download_image(img['content'])
if not image:
return Response('图片不存在', status_code=404)
image.save(img_path, ext, save_all=True)
return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")

231
routers/reverse.py Normal file
View File

@@ -0,0 +1,231 @@
import os
import random
import string
import sqlite3
import numpy as np
import time
import io
from PIL import Image
from towhee import pipe, ops
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
from models.milvus import get_collection, collection_name
from models.mysql import get_cursor, conn
#from models.resnet import Resnet50
from configs.config import UPLOAD_PATH
from utilities.download import download_image
router = APIRouter()
#MODEL = Resnet50()
RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec'))
# 获取状态统计
@router.get('', summary='状态统计', description='通过表名获取状态统计')
def count_images():
collection = get_collection(collection_name)
return {'status': True, 'count': collection.num_entities}
# 手动重建索引
@router.get('/create_index', summary='重建索引', description='手动重建索引', include_in_schema=False)
async def create_index():
collection = get_collection(collection_name)
default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
collection.create_index(field_name="embedding", index_params=default_index)
collection.load()
return {'status': True, 'count': collection.num_entities}
'''
# 批量生成向量
@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False)
async def create_vector(count: int = 10):
cursor = get_cursor()
cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}")
images = cursor.fetchall()
cursor.close()
for item in images:
print(item)
# 先查询 milvus 中是否存在
collection = get_collection(collection_name)
data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit
if len(data) > 0:
cursor = get_cursor()
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
conn.commit()
cursor.close()
continue
try:
img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image']))
download_image(item['thumbnail_image']).save(img_path, 'png', save_all=True)
feat = MODEL.resnet50_extract_feat(img_path)
collection.insert([[item['id']], [feat], [item['article_id']]])
cursor = get_cursor()
cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
conn.commit()
cursor.close()
except Exception as e:
print(e)
print('END')
return images
'''
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
async def rewrite_image(image_id: int):
print('START', image_id, '重建向量')
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('mysql中原始图片不存在:', image_id)
return Response('图片不存在', status_code=404)
img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
print('img_path', img_path)
image = download_image(img['content'])
if image is None:
print('图片下载失败:', img['content'])
return Response('图片下载失败', status_code=404)
image.save(img_path, 'png', save_all=True)
with Image.open(img_path) as imgx:
feat = RESNET50(imgx).get()[0]
collection = get_collection(collection_name)
collection.delete(expr=f'id in [{image_id}]')
rest = collection.insert([[image_id], [feat], [img['article_id']]])
os.remove(img_path)
print('END', image_id, '重建向量', rest.primary_keys)
return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()}
# 获取相似(废弃)
@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片')
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
collection = get_collection(collection_name)
result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1)
# 如果没有结果, 则重新生成记录
if len(result) == 0:
cursor = get_cursor()
cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
img = cursor.fetchone()
cursor.close()
if img is None:
print('mysql 中图片不存在:', image_id)
return Response('图片不存在', status_code=404)
img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
image = download_image(img['content'])
if image is None:
print('图片下载失败:', img['content'])
return Response('图片下载失败', status_code=404)
image.save(img_path, 'png', save_all=True)
with Image.open(img_path) as imgx:
feat = RESNET50(imgx).get()[0]
# 移除可能存在的旧记录, 换上新的
collection.delete(expr=f'id in [{image_id}]')
collection.insert([[image_id], [feat], [img['article_id']]])
os.remove(img_path)
print('生成')
else:
print('通过')
feat = result[0]['embedding']
res = collection.search([feat],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0]
# 翻页(截取有效范围, page * pageize)
ope = page*pageSize-pageSize
end = page*pageSize
next = False
# 为数据附加信息
ids = [i.id for i in res] # 获取所有ID
if len(res) <= end:
ids = ids[ope:]
next = False
else:
ids = ids[ope:end]
next = True
str_ids = str(ids).replace('[', '').replace(']', '')
if str_ids == '':
print('没有更多数据了')
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
cursor = get_cursor()
cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})")
imgs = cursor.fetchall()
if len(imgs) == 0:
return imgs
# 获取用户ID和文章ID
uids = list(set([x['user_id'] for x in imgs]))
tids = list(set([x['article_id'] for x in imgs]))
# 获取用户信息
cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
users = cursor.fetchall()
# 获取文章信息
cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
articles = cursor.fetchall()
cursor.close()
# 合并信息
user, article = {}, {}
for x in users: user[x['id']] = x
for x in articles: article[x['id']] = x
for x in imgs:
x['article'] = article[x['article_id']]
x['user'] = user[x['user_id']]
x['distance'] = [i.distance for i in res if i.id == x['id']][0]
if x['praise_count'] == None: x['praise_count'] = 0
if x['collect_count'] == None: x['collect_count'] = 0
# 将字段名转换为驼峰
x['createTime'] = x.pop('create_time')
x['updateTime'] = x.pop('update_time')
# 对 imgs 重新排序(按照 distance 字段)
imgs = sorted(imgs, key=lambda x: x['distance'])
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': next, 'list': imgs}
@router.post(path='', summary='以图搜图', description='上传图片进行搜索')
async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: int = 20):
content = await image.read()
img = Image.open(image.file)
embeddig = RESNET50(img).get()[0]
collection = get_collection('default')
res = collection.search([embeddig],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=500)[0]
ope, end = (page - 1) * pageSize, page * pageSize
ids, nextx = [x.id for x in res][ope:end], len(res) > end
if not ids:
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
cursor = get_cursor()
cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
imgs = cursor.fetchall()
if not imgs:
return imgs
uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs))
cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
users = cursor.fetchall()
cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
articles = cursor.fetchall()
cursor.close()
user, article = {x['id']: x for x in users}, {x['id']: x for x in articles}
for x in imgs:
x.update({
'article': article[x['article_id']],
'user': user[x['user_id']],
#'distance': next(i.distance for i in res if i.id == x['id'], 0),
'distance': [i.distance for i in res if i.id == x['id']][0],
'praise_count': x.get('praise_count', 0),
'collect_count': x.get('collect_count', 0)
})
imgs.sort(key=lambda x: x['distance'])
return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs}
@router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像")
async def delete_images(thread_id: str):
collection = get_collection(collection_name)
collection.delete(expr="article_id in ["+thread_id+"]")
collection.load()
default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
collection.create_index(field_name="embedding", index_params=default_index)
return {"status": True, 'msg': '删除完毕'}

94
routers/task.py Normal file
View File

@@ -0,0 +1,94 @@
from fastapi import APIRouter, HTTPException, Request, WebSocket
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from models.task import TaskForm, TaskManager
router = APIRouter()
task_manager = TaskManager()
# 创建新任务
@router.post("", summary="创建新任务")
def create_task(form: TaskForm):
task = task_manager.add_task(name=form.name, user_id=123456, description=form.description)
return task
# 获取任务列表
@router.get("", summary="获取任务列表", description="可使用user_id参数筛选指定用户的任务列表")
def get_task_list(user_id: int=None):
return task_manager.get_tasks(user_id)
# websocket demo
@router.get("/demo", response_class=HTMLResponse)
async def websocket_demo(request: Request):
task_list = task_manager.get_tasks()
templates = Jinja2Templates(directory="templates")
return templates.TemplateResponse("websocket.html", {"request": request, "task_list": task_list})
# 监听指定任务的变化事件, 通知前端(不使用pydantic模型, 正确的写法)
@router.websocket("/{task_id}", name="监听任务变化")
async def websocket_endpoint(task_id: str, websocket: WebSocket):
await websocket.accept()
await task_manager.add_websocket(task_id, websocket)
async for data in websocket.iter_text():
await websocket.send_text(f"Message text was: {data}")
task_manager.remove_websocket(task_id, websocket)
print("websocket 连接已自动关闭")
#await websocket.send_json({"message": "Hello WebSocket!"})
#task = task_manager.get_task(task_id)
#print(task)
#if not task:
# print("task 不存在, 结束连接")
# return await websocket.close() # 任务不存在, 结束连接
#await websocket.send_json(task) # 将任务的状态发送给客户端
#await task_manager.add_websocket(task_id, websocket)
# 正确的写法, 使用 async for, 并且处理意外断开的情况
#try:
# async for data in websocket.iter_text():
# if data == "close":
# print("客户端主动关闭连接")
# task_manager.remove_websocket(task_id, websocket)
# await websocket.close()
# break
# else:
# print(f"接收到客户端消息: {data}")
# await websocket.send_text(f"Message text was: {data}")
#except Exception as e:
# print(f"客户端意外断开连接: {e}")
# task_manager.remove_websocket(task_id, websocket)
# #await websocket.close()
# 监听客户端的状态, 如果客户端主动关闭连接或意外断开连接, 都从任务的websocket列表中移除
#while True:
# try:
# data = await websocket.receive_text()
# if data == "close":
# print("客户端主动关闭连接")
# task_manager.remove_websocket(task_id, websocket)
# await websocket.close()
# break
# else:
# print(f"接收到客户端消息: {data}")
# await websocket.send_text(f"Message text was: {data}")
# except Exception as e:
# print(f"客户端意外断开连接: {e}")
# task_manager.remove_websocket(task_id, websocket)
# #await websocket.close()
# break
# 获取任务详情
@router.get("/{task_id}", summary="获取任务详情")
def get_task(task_id: int):
return get_task(task_id)
# 删除任务
@router.delete("/{task_id}", summary="删除指定任务")
def delete_task(task_id: int):
return delete_task(task_id)

19
routers/user.py Normal file
View File

@@ -0,0 +1,19 @@
from fastapi import APIRouter, HTTPException
from models.mysql import conn, get_cursor
router = APIRouter()
@router.get('/{user_id}/collect', summary="用户收藏记录", description="获取指定用户收藏记录")
def get_user_collect(user_id:int):
# TODO: 需要验证权限
cursor = get_cursor()
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'")
data = cursor.fetchall()
if not data:
return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
data = [str(item['content']) for item in data]
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
data = cursor.fetchall()
data = [str(item['id']) for item in data]
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }

45
routers/user_collect.py Normal file
View File

@@ -0,0 +1,45 @@
from typing import Optional
from fastapi import APIRouter, HTTPException, Header
from models.mysql import conn, get_cursor
router = APIRouter()
# 获取当前用户收藏记录
@router.get('', summary='自己的收藏记录', description='获取自己的收藏记录, 用于判断是否收藏(headers中必须附带token)')
def get_self_collect(token: Optional[str] = Header()):
print('token: ', token)
cursor = get_cursor()
# 查询用户ID
cursor.execute(f"SELECT user_id FROM web_auth WHERE token={token} limit 1")
data = cursor.fetchone()
print('auth: ', data)
if not data:
raise HTTPException(status_code=401, detail="用户未登录")
user_id = data['user_id']
# 查询收藏记录
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type='1'")
data = cursor.fetchall() # 获取所有记录列表
data = [str(item['content']) for item in data] # 转换为数组
# 查询图片ID(对特殊字符安全转义)
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
data = cursor.fetchall()
data = [str(item['id']) for item in data]
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }
# 获取指定用户收藏记录
@router.get('/{user_id}', summary='指定用的户收藏记录', description='获取指定用户收藏记录(仅测试用)')
def get_user_collect(user_id: int):
cursor = get_cursor()
cursor.execute(f"SELECT content FROM web_collect WHERE user_id={user_id} AND type=1")
data = cursor.fetchall() # 获取所有记录列表
data = [str(item['content']) for item in data] # 转换为数组
if not data:
return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
# 查询图片ID(对特殊字符安全转义)
cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
data = cursor.fetchall()
data = [str(item['id']) for item in data]
return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data}