替换连接池
This commit is contained in:
		@@ -1,27 +1,47 @@
 | 
			
		||||
import pymysql
 | 
			
		||||
from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS
 | 
			
		||||
from dbutils.pooled_db import PooledDB
 | 
			
		||||
 | 
			
		||||
# 创建 MySQL 连接
 | 
			
		||||
def create_connection():
 | 
			
		||||
    return pymysql.connect(
 | 
			
		||||
 | 
			
		||||
# 创建数据库连接池
 | 
			
		||||
pool = PooledDB(
 | 
			
		||||
    creator=pymysql,        # 使用 pymysql 作为数据库驱动
 | 
			
		||||
    maxconnections=20,      # 最大连接数
 | 
			
		||||
    mincached=2,            # 初始化时,连接池中至少创建的空闲连接
 | 
			
		||||
    maxcached=5,            # 连接池中最多空闲连接数
 | 
			
		||||
    blocking=True,          # 超过最大连接数时,是否阻塞
 | 
			
		||||
    maxusage=None,          # 单个连接的最大复用次数
 | 
			
		||||
    ping=0,                 # 设置连接是否检查
 | 
			
		||||
    host=MYSQL_HOST,
 | 
			
		||||
    port=MYSQL_PORT,
 | 
			
		||||
    user=MYSQL_USER,
 | 
			
		||||
        port=MYSQL_PORT,  # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST
 | 
			
		||||
    password=MYSQL_PASS,
 | 
			
		||||
    database=MYSQL_NAME,
 | 
			
		||||
        local_infile=True,
 | 
			
		||||
        cursorclass=pymysql.cursors.DictCursor
 | 
			
		||||
    charset='utf8mb4'
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# 连接 MySQL (开启 MySQL 服务)
 | 
			
		||||
conn = create_connection()
 | 
			
		||||
 | 
			
		||||
# 获取 MySQL 连接
 | 
			
		||||
def get_cursor():
 | 
			
		||||
    global conn
 | 
			
		||||
    try:
 | 
			
		||||
        conn.ping()
 | 
			
		||||
        return conn.cursor()
 | 
			
		||||
    except Exception:
 | 
			
		||||
        conn = create_connection()
 | 
			
		||||
        return conn.cursor()
 | 
			
		||||
## 创建 MySQL 连接
 | 
			
		||||
#def create_connection():
 | 
			
		||||
#    return pymysql.connect(
 | 
			
		||||
#        host=MYSQL_HOST,
 | 
			
		||||
#        user=MYSQL_USER,
 | 
			
		||||
#        port=MYSQL_PORT,  # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST
 | 
			
		||||
#        password=MYSQL_PASS,
 | 
			
		||||
#        database=MYSQL_NAME,
 | 
			
		||||
#        local_infile=True,
 | 
			
		||||
#        cursorclass=pymysql.cursors.DictCursor
 | 
			
		||||
#    )
 | 
			
		||||
#
 | 
			
		||||
## 连接 MySQL (开启 MySQL 服务)
 | 
			
		||||
#conn = create_connection()
 | 
			
		||||
#
 | 
			
		||||
## 获取 MySQL 连接
 | 
			
		||||
#def get_cursor():
 | 
			
		||||
#    global conn
 | 
			
		||||
#    try:
 | 
			
		||||
#        conn.ping()
 | 
			
		||||
#        return conn.cursor()
 | 
			
		||||
#    except Exception:
 | 
			
		||||
#        conn = create_connection()
 | 
			
		||||
#        return conn.cursor()
 | 
			
		||||
#
 | 
			
		||||
@@ -8,7 +8,7 @@ import _thread as thread
 | 
			
		||||
from fastapi import APIRouter, HTTPException, Response
 | 
			
		||||
from urllib.parse import unquote
 | 
			
		||||
from configs.config import IMAGES_PATH
 | 
			
		||||
from models.mysql import get_cursor, conn
 | 
			
		||||
from models.mysql import pool
 | 
			
		||||
from utilities.download import download_image, generate_thumbnail
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -18,25 +18,23 @@ router = APIRouter()
 | 
			
		||||
# 预热图片(获取一次图片, 遍历图片表, 检查OSS中所有被预定的尺寸是否存在, 不存在则生成)
 | 
			
		||||
@router.get("/warm", summary="预热图片", description="预热图片")
 | 
			
		||||
def warm_image(op:int=0, end:int=10, version:str='0'):
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            cursor.execute(f"SELECT * FROM `web_images` limit {op}, {end}")
 | 
			
		||||
            for img in cursor.fetchall():
 | 
			
		||||
                # 如果CPU使用率大于50%, 则等待, 直到CPU使用率小于50%
 | 
			
		||||
                while statistics.mean(psutil.cpu_percent(interval=1, percpu=True)) > 50:
 | 
			
		||||
                    print(statistics.mean(psutil.cpu_percent(interval=1, percpu=True)), '等待CPU释放...')
 | 
			
		||||
                    time.sleep(2)
 | 
			
		||||
        
 | 
			
		||||
                # 如果内存剩余小于1G, 则等待, 直到内存剩余大于1G
 | 
			
		||||
                while psutil.virtual_memory().available < 1024 * 1024 * 1024:
 | 
			
		||||
                    print(psutil.virtual_memory().available, '等待内存释放...')
 | 
			
		||||
                    time.sleep(2)
 | 
			
		||||
 | 
			
		||||
                # CPU使用率已降低, 开始处理图片
 | 
			
		||||
                image = download_image(img['content']) # 从OSS下载原图
 | 
			
		||||
                if not image:
 | 
			
		||||
                    print('跳过不存在的图片:', img['content'])
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                # 创建新线程处理图片
 | 
			
		||||
                try:
 | 
			
		||||
                    print('开始处理图片:', img['content'])
 | 
			
		||||
@@ -45,21 +43,20 @@ def warm_image(op:int=0, end:int=10, version:str='0'):
 | 
			
		||||
                    thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 1, 328, 'webp'))
 | 
			
		||||
                except:
 | 
			
		||||
                    print('无法启动线程')
 | 
			
		||||
    cursor.close()
 | 
			
		||||
            return Response('预热成功', status_code=200, media_type='text/plain', headers={'Content-Type': 'text/plain; charset=utf-8'})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取非标准类缩略图
 | 
			
		||||
@router.get("/{type}-{id}-{version}@{n}x{w}.{ext}", summary="获取非标准类缩略图", description="/img/article-233-version@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
 | 
			
		||||
def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:str):
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}"
 | 
			
		||||
            if os.path.exists(img_path):
 | 
			
		||||
                return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
            if type == 'ad' or type == 'article' or type == 'article_attribute':
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
                count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
 | 
			
		||||
                img = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
                if img is None:
 | 
			
		||||
                    print('图片不存在:', count)
 | 
			
		||||
                    return Response('图片不存在', status_code=404)
 | 
			
		||||
@@ -70,10 +67,8 @@ def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:st
 | 
			
		||||
                url = unquote(base64.b64decode(id))
 | 
			
		||||
                print(url)
 | 
			
		||||
            elif type == 'avatar':
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
                count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
 | 
			
		||||
                user = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
                if user is None:
 | 
			
		||||
                    print('用户不存在:', count)
 | 
			
		||||
                    return Response('用户不存在', status_code=404)
 | 
			
		||||
@@ -96,14 +91,14 @@ def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:st
 | 
			
		||||
# 获取非标准类原尺寸图
 | 
			
		||||
@router.get("/{type}-{id}-{version}.{ext}", summary="获取文章缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
 | 
			
		||||
def get_image_type(type:str, id:str, version:str, ext:str):
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            img_path = f"{IMAGES_PATH}/{type}-{id}-{version}.{ext}"
 | 
			
		||||
            if os.path.exists(img_path):
 | 
			
		||||
                return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
            if type == 'ad' or type == 'article' or type == 'article_attribute':
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
                count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
 | 
			
		||||
                img = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
                if img is None:
 | 
			
		||||
                    print('图片不存在:', count)
 | 
			
		||||
                    return Response('图片不存在', status_code=404)
 | 
			
		||||
@@ -114,10 +109,8 @@ def get_image_type(type:str, id:str, version:str, ext:str):
 | 
			
		||||
                url = unquote(base64.b64decode(id))
 | 
			
		||||
                print("url:", url)
 | 
			
		||||
            elif type == 'avatar':
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
                count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
 | 
			
		||||
                user = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
                if user is None:
 | 
			
		||||
                    print('用户不存在:', count)
 | 
			
		||||
                    return Response('用户不存在', status_code=404)
 | 
			
		||||
@@ -149,15 +142,15 @@ def get_image_url(url:str, n:int, w:int, ext:str):
 | 
			
		||||
# 获取标准缩略图(带版本号)
 | 
			
		||||
@router.get("/{id}-{version}@{n}x{w}.{ext}", summary="获取缩略图(带版本号)", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
 | 
			
		||||
def get_image_thumbnail(id:int, version:str, n:int, w:int, ext:str):
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            # 判断图片是否已经生成
 | 
			
		||||
            img_path = f"{IMAGES_PATH}/{id}-{version}@{n}x{w}.{ext}"
 | 
			
		||||
            if os.path.exists(img_path):
 | 
			
		||||
                return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
            # 从数据库获取原图地址
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
            cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
 | 
			
		||||
            img = cursor.fetchone()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
            if img is None:
 | 
			
		||||
                print('图片不存在:', id)
 | 
			
		||||
                return Response('图片不存在', status_code=404)
 | 
			
		||||
@@ -172,15 +165,15 @@ def get_image_thumbnail(id:int, version:str, n:int, w:int, ext:str):
 | 
			
		||||
# 获取标准缩略图
 | 
			
		||||
@router.get("/{id}@{n}x{w}.{ext}", summary="获取缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
 | 
			
		||||
def get_image_thumbnail(id:int, n:int, w:int, ext:str):
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            # 判断图片是否已经生成
 | 
			
		||||
            img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}"
 | 
			
		||||
            if os.path.exists(img_path):
 | 
			
		||||
                return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
            # 从数据库获取原图地址
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
            cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
 | 
			
		||||
            img = cursor.fetchone()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
            if img is None:
 | 
			
		||||
                print('图片不存在:', id)
 | 
			
		||||
                return Response('图片不存在', status_code=404)
 | 
			
		||||
@@ -195,15 +188,15 @@ def get_image_thumbnail(id:int, n:int, w:int, ext:str):
 | 
			
		||||
# 获取标准原尺寸图
 | 
			
		||||
@router.get("/{id}.{ext}", summary="获取标准原尺寸图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 无后缀获取原图")
 | 
			
		||||
def get_image(id: int = 824, ext: str = 'webp'):
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            # 判断图片是否已经生成
 | 
			
		||||
            img_path = f"{IMAGES_PATH}/{id}.{ext}"
 | 
			
		||||
            if os.path.exists(img_path):
 | 
			
		||||
                return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
            # 从数据库获取原图地址
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
            cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
 | 
			
		||||
            img = cursor.fetchone()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
            if img is None:
 | 
			
		||||
                print('图片不存在:', id)
 | 
			
		||||
                return Response('图片不存在', status_code=404)
 | 
			
		||||
 
 | 
			
		||||
@@ -10,7 +10,7 @@ from towhee import pipe, ops
 | 
			
		||||
 | 
			
		||||
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
 | 
			
		||||
from models.milvus import get_collection, collection_name
 | 
			
		||||
from models.mysql import get_cursor, conn
 | 
			
		||||
from models.mysql import pool
 | 
			
		||||
 | 
			
		||||
from configs.config import UPLOAD_PATH
 | 
			
		||||
from utilities.download import download_image
 | 
			
		||||
@@ -39,11 +39,11 @@ async def create_index():
 | 
			
		||||
# 重建指定图像的向量
 | 
			
		||||
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
 | 
			
		||||
async def rewrite_image(image_id: int):
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            print('START', image_id, '重建向量')
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
            cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
 | 
			
		||||
            img = cursor.fetchone()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
            if img is None:
 | 
			
		||||
                print('mysql中原始图片不存在:', image_id)
 | 
			
		||||
                return Response('图片不存在', status_code=404)
 | 
			
		||||
@@ -54,10 +54,8 @@ async def rewrite_image(image_id: int):
 | 
			
		||||
                print('图片下载失败:', img['content'])
 | 
			
		||||
                return Response('图片下载失败', status_code=404)
 | 
			
		||||
            image.save(img_path, 'png', save_all=True)
 | 
			
		||||
 | 
			
		||||
            with Image.open(img_path) as imgx:
 | 
			
		||||
                feat = RESNET50(imgx).get()[0]
 | 
			
		||||
 | 
			
		||||
            collection = get_collection(collection_name)
 | 
			
		||||
            collection.delete(expr=f'id in [{image_id}]')
 | 
			
		||||
            rest = collection.insert([[image_id], [feat], [img['article_id']]])
 | 
			
		||||
@@ -69,14 +67,14 @@ async def rewrite_image(image_id: int):
 | 
			
		||||
# 获取相似(废弃)
 | 
			
		||||
@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片')
 | 
			
		||||
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        collection = get_collection(collection_name)
 | 
			
		||||
        result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1)
 | 
			
		||||
        # 如果没有结果, 则重新生成记录
 | 
			
		||||
        if len(result) == 0:
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
            with conn.cursor() as cursor:
 | 
			
		||||
                cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
 | 
			
		||||
                img = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
                if img is None:
 | 
			
		||||
                    print('mysql 中图片不存在:', image_id)
 | 
			
		||||
                    return Response('图片不存在', status_code=404)
 | 
			
		||||
@@ -113,7 +111,7 @@ async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
 | 
			
		||||
        if str_ids == '':
 | 
			
		||||
            print('没有更多数据了')
 | 
			
		||||
            return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})")
 | 
			
		||||
            imgs = cursor.fetchall()
 | 
			
		||||
            if len(imgs) == 0:
 | 
			
		||||
@@ -127,7 +125,6 @@ async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
 | 
			
		||||
            # 获取文章信息
 | 
			
		||||
            cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
 | 
			
		||||
            articles = cursor.fetchall()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
            # 合并信息
 | 
			
		||||
            user, article = {}, {}
 | 
			
		||||
            for x in users: user[x['id']] = x
 | 
			
		||||
@@ -159,31 +156,26 @@ async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize:
 | 
			
		||||
    if not ids:
 | 
			
		||||
        return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
 | 
			
		||||
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
 | 
			
		||||
            imgs = cursor.fetchall()
 | 
			
		||||
 | 
			
		||||
            if not imgs:
 | 
			
		||||
                return imgs
 | 
			
		||||
 | 
			
		||||
            uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs))
 | 
			
		||||
            cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
 | 
			
		||||
            users = cursor.fetchall()
 | 
			
		||||
            cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
 | 
			
		||||
            articles = cursor.fetchall()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
 | 
			
		||||
            user, article = {x['id']: x for x in users}, {x['id']: x for x in articles}
 | 
			
		||||
            for x in imgs:
 | 
			
		||||
                x.update({
 | 
			
		||||
                    'article': article[x['article_id']],
 | 
			
		||||
                    'user': user[x['user_id']],
 | 
			
		||||
            #'distance': next(i.distance for i in res if i.id == x['id'], 0),
 | 
			
		||||
                    'distance': [i.distance for i in res if i.id == x['id']][0],
 | 
			
		||||
                    'praise_count': x.get('praise_count', 0),
 | 
			
		||||
                    'collect_count': x.get('collect_count', 0)
 | 
			
		||||
                })
 | 
			
		||||
 | 
			
		||||
            imgs.sort(key=lambda x: x['distance'])
 | 
			
		||||
            return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,12 +1,12 @@
 | 
			
		||||
from fastapi import APIRouter, HTTPException
 | 
			
		||||
from models.mysql import conn, get_cursor
 | 
			
		||||
from models.mysql import pool
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
@router.get('/{user_id}/collect', summary="用户收藏记录", description="获取指定用户收藏记录")
 | 
			
		||||
def get_user_collect(user_id:int):
 | 
			
		||||
    # TODO: 需要验证权限
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            cursor.execute(f"SELECT content FROM web_collect  WHERE user_id={user_id} AND type='1'")
 | 
			
		||||
            data = cursor.fetchall()
 | 
			
		||||
            if not data:
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from fastapi import APIRouter, HTTPException, Header
 | 
			
		||||
from models.mysql import conn, get_cursor
 | 
			
		||||
from models.mysql import pool
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
@@ -9,8 +9,8 @@ router = APIRouter()
 | 
			
		||||
# 获取当前用户收藏记录
 | 
			
		||||
@router.get('', summary='自己的收藏记录', description='获取自己的收藏记录, 用于判断是否收藏(headers中必须附带token)')
 | 
			
		||||
def get_self_collect(token: Optional[str] = Header()):
 | 
			
		||||
    print('token: ', token)
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            # 查询用户ID
 | 
			
		||||
            cursor.execute(f"SELECT user_id FROM web_auth WHERE token={token} limit 1")
 | 
			
		||||
            data = cursor.fetchone()
 | 
			
		||||
@@ -32,7 +32,8 @@ def get_self_collect(token: Optional[str] = Header()):
 | 
			
		||||
# 获取指定用户收藏记录
 | 
			
		||||
@router.get('/{user_id}', summary='指定用的户收藏记录', description='获取指定用户收藏记录(仅测试用)')
 | 
			
		||||
def get_user_collect(user_id: int):
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    with pool.connection() as conn:
 | 
			
		||||
        with conn.cursor() as cursor:
 | 
			
		||||
            cursor.execute(f"SELECT content FROM web_collect  WHERE user_id={user_id} AND type=1")
 | 
			
		||||
            data = cursor.fetchall()                       # 获取所有记录列表
 | 
			
		||||
            data = [str(item['content']) for item in data] # 转换为数组
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user