转移
This commit is contained in:
		
							
								
								
									
										214
									
								
								routers/img.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										214
									
								
								routers/img.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,214 @@
 | 
			
		||||
import os
 | 
			
		||||
import time
 | 
			
		||||
import base64
 | 
			
		||||
import psutil
 | 
			
		||||
import statistics
 | 
			
		||||
import _thread as thread
 | 
			
		||||
 | 
			
		||||
from fastapi import APIRouter, HTTPException, Response
 | 
			
		||||
from urllib.parse import unquote
 | 
			
		||||
from configs.config import IMAGES_PATH
 | 
			
		||||
from models.mysql import get_cursor, conn
 | 
			
		||||
from utilities.download import download_image, generate_thumbnail
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 预热图片(获取一次图片, 遍历图片表, 检查OSS中所有被预定的尺寸是否存在, 不存在则生成)
 | 
			
		||||
@router.get("/warm", summary="预热图片", description="预热图片")
 | 
			
		||||
def warm_image(op:int=0, end:int=10, version:str='0'):
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT * FROM `web_images` limit {op}, {end}")
 | 
			
		||||
    for img in cursor.fetchall():
 | 
			
		||||
        # 如果CPU使用率大于50%, 则等待, 直到CPU使用率小于50%
 | 
			
		||||
        while statistics.mean(psutil.cpu_percent(interval=1, percpu=True)) > 50:
 | 
			
		||||
            print(statistics.mean(psutil.cpu_percent(interval=1, percpu=True)), '等待CPU释放...')
 | 
			
		||||
            time.sleep(2)
 | 
			
		||||
        
 | 
			
		||||
        # 如果内存剩余小于1G, 则等待, 直到内存剩余大于1G
 | 
			
		||||
        while psutil.virtual_memory().available < 1024 * 1024 * 1024:
 | 
			
		||||
            print(psutil.virtual_memory().available, '等待内存释放...')
 | 
			
		||||
            time.sleep(2)
 | 
			
		||||
 | 
			
		||||
        # CPU使用率已降低, 开始处理图片
 | 
			
		||||
        image = download_image(img['content']) # 从OSS下载原图
 | 
			
		||||
        if not image:
 | 
			
		||||
            print('跳过不存在的图片:', img['content'])
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        # 创建新线程处理图片
 | 
			
		||||
        try:
 | 
			
		||||
            print('开始处理图片:', img['content'])
 | 
			
		||||
            thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 3, 328, 'webp'))
 | 
			
		||||
            thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 2, 328, 'webp'))
 | 
			
		||||
            thread.start_new_thread(generate_thumbnail, (image, img['id'], version, 1, 328, 'webp'))
 | 
			
		||||
        except:
 | 
			
		||||
            print('无法启动线程')
 | 
			
		||||
    cursor.close()
 | 
			
		||||
    return Response('预热成功', status_code=200, media_type='text/plain', headers={'Content-Type': 'text/plain; charset=utf-8'})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取非标准类缩略图
 | 
			
		||||
@router.get("/{type}-{id}-{version}@{n}x{w}.{ext}", summary="获取非标准类缩略图", description="/img/article-233-version@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
 | 
			
		||||
def get_image_type_thumbnail(type:str, id:str, version:str, n:int, w:int, ext:str):
 | 
			
		||||
    img_path = f"{IMAGES_PATH}/{type}-{id}-{version}@{n}x{w}.{ext}"
 | 
			
		||||
    if os.path.exists(img_path):
 | 
			
		||||
        return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
    if type == 'ad' or type == 'article' or type == 'article_attribute':
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
        count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
 | 
			
		||||
        img = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
        if img is None:
 | 
			
		||||
            print('图片不存在:', count)
 | 
			
		||||
            return Response('图片不存在', status_code=404)
 | 
			
		||||
        url = img['image']
 | 
			
		||||
    elif type == 'url':
 | 
			
		||||
        id = unquote(id, 'utf-8')
 | 
			
		||||
        id = id.replace(' ','+')
 | 
			
		||||
        url = unquote(base64.b64decode(id))
 | 
			
		||||
        print(url)
 | 
			
		||||
    elif type == 'avatar':
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
        count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
 | 
			
		||||
        user = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
        if user is None:
 | 
			
		||||
            print('用户不存在:', count)
 | 
			
		||||
            return Response('用户不存在', status_code=404)
 | 
			
		||||
        url = user['avatar']
 | 
			
		||||
    else:
 | 
			
		||||
        print('图片类型不存在:', type)
 | 
			
		||||
        return Response('图片类型不存在', status_code=404)
 | 
			
		||||
    image = download_image(url)
 | 
			
		||||
    if not image:
 | 
			
		||||
        return Response('图片不存在', status_code=404)
 | 
			
		||||
    # 如果是 avatar, 则裁剪为正方形
 | 
			
		||||
    if type == 'avatar':
 | 
			
		||||
        px = image.size[0] if image.size[0] < image.size[1] else image.size[1]
 | 
			
		||||
        image = image.crop((0, 0, px, px))
 | 
			
		||||
    image.thumbnail((n*w, image.size[1]))
 | 
			
		||||
    image.save(img_path, ext, save_all=True)
 | 
			
		||||
    return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取非标准类原尺寸图
 | 
			
		||||
@router.get("/{type}-{id}-{version}.{ext}", summary="获取文章缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
 | 
			
		||||
def get_image_type(type:str, id:str, version:str, ext:str):
 | 
			
		||||
    img_path = f"{IMAGES_PATH}/{type}-{id}-{version}.{ext}"
 | 
			
		||||
    if os.path.exists(img_path):
 | 
			
		||||
        return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
    if type == 'ad' or type == 'article' or type == 'article_attribute':
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
        count = cursor.execute(f"SELECT * FROM `web_{type}` WHERE `id`={id}")
 | 
			
		||||
        img = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
        if img is None:
 | 
			
		||||
            print('图片不存在:', count)
 | 
			
		||||
            return Response('图片不存在', status_code=404)
 | 
			
		||||
        url = img['image']
 | 
			
		||||
    elif type == 'url':
 | 
			
		||||
        id = unquote(id, 'utf-8')
 | 
			
		||||
        id = id.replace(' ','+')
 | 
			
		||||
        url = unquote(base64.b64decode(id))
 | 
			
		||||
        print("url:", url)
 | 
			
		||||
    elif type == 'avatar':
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
        count = cursor.execute(f"SELECT avatar FROM `web_member` WHERE `id`={id}")
 | 
			
		||||
        user = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
        if user is None:
 | 
			
		||||
            print('用户不存在:', count)
 | 
			
		||||
            return Response('用户不存在', status_code=404)
 | 
			
		||||
        url = user['avatar']
 | 
			
		||||
    else:
 | 
			
		||||
       print('图片类型不存在:', type)
 | 
			
		||||
       return Response('图片类型不存在', status_code=404)
 | 
			
		||||
    image = download_image(url)
 | 
			
		||||
    image.save(img_path, ext, save_all=True)
 | 
			
		||||
    return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 通过url获取图片
 | 
			
		||||
@router.get("/url-{url}@{n}x{w}.{ext}", summary="通过url获取图片", description="/img/article-233.webp")
 | 
			
		||||
def get_image_url(url:str, n:int, w:int, ext:str):
 | 
			
		||||
    img_path = f"{IMAGES_PATH}/{type}-{url}.{ext}"
 | 
			
		||||
    if os.path.exists(img_path):
 | 
			
		||||
        return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
    url = unquote(url, 'utf-8').replace(' ','+')
 | 
			
		||||
    url = unquote(base64.b64decode(url))
 | 
			
		||||
    image = download_image(url)
 | 
			
		||||
    if not image:
 | 
			
		||||
        return Response('图片不存在', status_code=404)
 | 
			
		||||
    image.thumbnail((n*w, image.size[1]))
 | 
			
		||||
    image.save(img_path, ext, save_all=True)
 | 
			
		||||
    return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取标准缩略图(带版本号)
 | 
			
		||||
@router.get("/{id}-{version}@{n}x{w}.{ext}", summary="获取缩略图(带版本号)", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
 | 
			
		||||
def get_image_thumbnail(id:int, version:str, n:int, w:int, ext:str):
 | 
			
		||||
    # 判断图片是否已经生成
 | 
			
		||||
    img_path = f"{IMAGES_PATH}/{id}-{version}@{n}x{w}.{ext}"
 | 
			
		||||
    if os.path.exists(img_path):
 | 
			
		||||
        return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
    # 从数据库获取原图地址
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
 | 
			
		||||
    img = cursor.fetchone()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
    if img is None:
 | 
			
		||||
        print('图片不存在:', id)
 | 
			
		||||
        return Response('图片不存在', status_code=404)
 | 
			
		||||
    image = download_image(img['content'])
 | 
			
		||||
    if not image:
 | 
			
		||||
        return Response('图片不存在', status_code=404)
 | 
			
		||||
    image.thumbnail((n*w, image.size[1]))
 | 
			
		||||
    image.save(img_path, ext, save_all=True)
 | 
			
		||||
    return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取标准缩略图
 | 
			
		||||
@router.get("/{id}@{n}x{w}.{ext}", summary="获取缩略图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 通过@1x320后缀获取1x缩略图, 通过@2x320后缀获取2x缩略图, 通过@3x320后缀获取3x缩略图")
 | 
			
		||||
def get_image_thumbnail(id:int, n:int, w:int, ext:str):
 | 
			
		||||
    # 判断图片是否已经生成
 | 
			
		||||
    img_path = f"{IMAGES_PATH}/{id}@{n}x{w}.{ext}"
 | 
			
		||||
    if os.path.exists(img_path):
 | 
			
		||||
        return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
    # 从数据库获取原图地址
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
 | 
			
		||||
    img = cursor.fetchone()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
    if img is None:
 | 
			
		||||
        print('图片不存在:', id)
 | 
			
		||||
        return Response('图片不存在', status_code=404)
 | 
			
		||||
    image = download_image(img['content'])
 | 
			
		||||
    if not image:
 | 
			
		||||
        return Response('图片不存在', status_code=404)
 | 
			
		||||
    image.thumbnail((n*w, image.size[1]))
 | 
			
		||||
    image.save(img_path, ext, save_all=True)
 | 
			
		||||
    return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取标准原尺寸图
 | 
			
		||||
@router.get("/{id}.{ext}", summary="获取标准原尺寸图", description="/img/233@1x320.webp 通过webp后缀获取webp格式图片, 无后缀获取原图")
 | 
			
		||||
def get_image(id: int = 824, ext: str = 'webp'):
 | 
			
		||||
    # 判断图片是否已经生成
 | 
			
		||||
    img_path = f"{IMAGES_PATH}/{id}.{ext}"
 | 
			
		||||
    if os.path.exists(img_path):
 | 
			
		||||
        return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
    # 从数据库获取原图地址
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT * FROM `web_images` WHERE `id`={id}")
 | 
			
		||||
    img = cursor.fetchone()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
    if img is None:
 | 
			
		||||
        print('图片不存在:', id)
 | 
			
		||||
        return Response('图片不存在', status_code=404)
 | 
			
		||||
    image = download_image(img['content'])
 | 
			
		||||
    if not image:
 | 
			
		||||
        return Response('图片不存在', status_code=404)
 | 
			
		||||
    image.save(img_path, ext, save_all=True)
 | 
			
		||||
    return Response(content=open(img_path, 'rb').read(), media_type=f"image/{ext}")
 | 
			
		||||
							
								
								
									
										231
									
								
								routers/reverse.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										231
									
								
								routers/reverse.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,231 @@
 | 
			
		||||
import os
 | 
			
		||||
import random
 | 
			
		||||
import string
 | 
			
		||||
import sqlite3
 | 
			
		||||
import numpy as np
 | 
			
		||||
import time
 | 
			
		||||
import io
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from towhee import pipe, ops
 | 
			
		||||
 | 
			
		||||
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
 | 
			
		||||
from models.milvus import get_collection, collection_name
 | 
			
		||||
from models.mysql import get_cursor, conn
 | 
			
		||||
#from models.resnet import Resnet50
 | 
			
		||||
 | 
			
		||||
from configs.config import UPLOAD_PATH
 | 
			
		||||
from utilities.download import download_image
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
#MODEL = Resnet50()
 | 
			
		||||
 | 
			
		||||
RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec'))
 | 
			
		||||
 | 
			
		||||
# 获取状态统计
 | 
			
		||||
@router.get('', summary='状态统计', description='通过表名获取状态统计')
 | 
			
		||||
def count_images():
 | 
			
		||||
    collection = get_collection(collection_name)
 | 
			
		||||
    return {'status': True, 'count': collection.num_entities}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 手动重建索引
 | 
			
		||||
@router.get('/create_index', summary='重建索引', description='手动重建索引', include_in_schema=False)
 | 
			
		||||
async def create_index():
 | 
			
		||||
    collection = get_collection(collection_name)
 | 
			
		||||
    default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
 | 
			
		||||
    collection.create_index(field_name="embedding", index_params=default_index)
 | 
			
		||||
    collection.load()
 | 
			
		||||
    return {'status': True, 'count': collection.num_entities}
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
# 批量生成向量
 | 
			
		||||
@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False)
 | 
			
		||||
async def create_vector(count: int = 10):
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}")
 | 
			
		||||
    images = cursor.fetchall()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
    for item in images:
 | 
			
		||||
        print(item)
 | 
			
		||||
        # 先查询 milvus 中是否存在
 | 
			
		||||
        collection = get_collection(collection_name)
 | 
			
		||||
        data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit
 | 
			
		||||
        if len(data) > 0:
 | 
			
		||||
            cursor = get_cursor()
 | 
			
		||||
            cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
 | 
			
		||||
            conn.commit()
 | 
			
		||||
            cursor.close()
 | 
			
		||||
            continue
 | 
			
		||||
        try:
 | 
			
		||||
            img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image']))
 | 
			
		||||
            download_image(item['thumbnail_image']).save(img_path, 'png', save_all=True)
 | 
			
		||||
            feat = MODEL.resnet50_extract_feat(img_path)
 | 
			
		||||
            collection.insert([[item['id']], [feat], [item['article_id']]])
 | 
			
		||||
            cursor = get_cursor()
 | 
			
		||||
            cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
 | 
			
		||||
            conn.commit()
 | 
			
		||||
            cursor.close()
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            print(e)
 | 
			
		||||
    print('END')
 | 
			
		||||
    return images
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
 | 
			
		||||
async def rewrite_image(image_id: int):
 | 
			
		||||
    print('START', image_id, '重建向量')
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
 | 
			
		||||
    img = cursor.fetchone()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
    if img is None:
 | 
			
		||||
        print('mysql中原始图片不存在:', image_id)
 | 
			
		||||
        return Response('图片不存在', status_code=404)
 | 
			
		||||
    img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
 | 
			
		||||
    print('img_path', img_path)
 | 
			
		||||
    image = download_image(img['content'])
 | 
			
		||||
    if image is None:
 | 
			
		||||
        print('图片下载失败:', img['content'])
 | 
			
		||||
        return Response('图片下载失败', status_code=404)
 | 
			
		||||
    image.save(img_path, 'png', save_all=True)
 | 
			
		||||
 | 
			
		||||
    with Image.open(img_path) as imgx:
 | 
			
		||||
        feat = RESNET50(imgx).get()[0]
 | 
			
		||||
 | 
			
		||||
    collection = get_collection(collection_name)
 | 
			
		||||
    collection.delete(expr=f'id in [{image_id}]')
 | 
			
		||||
    rest = collection.insert([[image_id], [feat], [img['article_id']]])
 | 
			
		||||
    os.remove(img_path)
 | 
			
		||||
    print('END', image_id, '重建向量', rest.primary_keys)
 | 
			
		||||
    return {"code": 0, "status": True, "message": "重建成功", "feature": feat.tolist()}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取相似(废弃)
 | 
			
		||||
@router.get('/{image_id}', summary='获取相似', description='通过图片ID获取相似图片')
 | 
			
		||||
async def similar_images(image_id: int, page: int = 1, pageSize: int = 20):
 | 
			
		||||
    collection = get_collection(collection_name)
 | 
			
		||||
    result = collection.query(expr=f'id in [{image_id}]', output_fields = ['id', 'article_id', 'embedding'], top_k=1)
 | 
			
		||||
    # 如果没有结果, 则重新生成记录
 | 
			
		||||
    if len(result) == 0:
 | 
			
		||||
        cursor = get_cursor()
 | 
			
		||||
        cursor.execute(f"SELECT * FROM `web_images` WHERE id={image_id}")
 | 
			
		||||
        img = cursor.fetchone()
 | 
			
		||||
        cursor.close()
 | 
			
		||||
        if img is None:
 | 
			
		||||
            print('mysql 中图片不存在:', image_id)
 | 
			
		||||
            return Response('图片不存在', status_code=404)
 | 
			
		||||
        img_path = os.path.join(UPLOAD_PATH, os.path.basename(img['content']))
 | 
			
		||||
        image = download_image(img['content'])
 | 
			
		||||
        if image is None:
 | 
			
		||||
            print('图片下载失败:', img['content'])
 | 
			
		||||
            return Response('图片下载失败', status_code=404)
 | 
			
		||||
        image.save(img_path, 'png', save_all=True)
 | 
			
		||||
        with Image.open(img_path) as imgx:
 | 
			
		||||
            feat = RESNET50(imgx).get()[0]
 | 
			
		||||
            # 移除可能存在的旧记录, 换上新的
 | 
			
		||||
            collection.delete(expr=f'id in [{image_id}]')
 | 
			
		||||
            collection.insert([[image_id], [feat], [img['article_id']]])
 | 
			
		||||
        os.remove(img_path)
 | 
			
		||||
        print('生成')
 | 
			
		||||
    else:
 | 
			
		||||
        print('通过')
 | 
			
		||||
        feat = result[0]['embedding']
 | 
			
		||||
    res = collection.search([feat],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=200)[0]
 | 
			
		||||
    # 翻页(截取有效范围, page * pageize)
 | 
			
		||||
    ope = page*pageSize-pageSize
 | 
			
		||||
    end = page*pageSize
 | 
			
		||||
    next = False
 | 
			
		||||
    # 为数据附加信息
 | 
			
		||||
    ids = [i.id for i in res] # 获取所有ID
 | 
			
		||||
    if len(res) <= end:
 | 
			
		||||
        ids = ids[ope:]
 | 
			
		||||
        next = False
 | 
			
		||||
    else:
 | 
			
		||||
        ids = ids[ope:end]
 | 
			
		||||
        next = True
 | 
			
		||||
    str_ids = str(ids).replace('[', '').replace(']', '')
 | 
			
		||||
    if str_ids == '':
 | 
			
		||||
        print('没有更多数据了')
 | 
			
		||||
        return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({str_ids})")
 | 
			
		||||
    imgs = cursor.fetchall()
 | 
			
		||||
    if len(imgs) == 0:
 | 
			
		||||
        return imgs
 | 
			
		||||
    # 获取用户ID和文章ID
 | 
			
		||||
    uids = list(set([x['user_id'] for x in imgs]))
 | 
			
		||||
    tids = list(set([x['article_id'] for x in imgs]))
 | 
			
		||||
    # 获取用户信息
 | 
			
		||||
    cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
 | 
			
		||||
    users = cursor.fetchall()
 | 
			
		||||
    # 获取文章信息
 | 
			
		||||
    cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
 | 
			
		||||
    articles = cursor.fetchall()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
    # 合并信息
 | 
			
		||||
    user, article = {}, {}
 | 
			
		||||
    for x in users: user[x['id']] = x
 | 
			
		||||
    for x in articles: article[x['id']] = x
 | 
			
		||||
    for x in imgs:
 | 
			
		||||
        x['article'] = article[x['article_id']]
 | 
			
		||||
        x['user'] = user[x['user_id']]
 | 
			
		||||
        x['distance'] = [i.distance for i in res if i.id == x['id']][0]
 | 
			
		||||
        if x['praise_count'] == None: x['praise_count'] = 0
 | 
			
		||||
        if x['collect_count'] == None: x['collect_count'] = 0
 | 
			
		||||
        # 将字段名转换为驼峰
 | 
			
		||||
        x['createTime'] = x.pop('create_time')
 | 
			
		||||
        x['updateTime'] = x.pop('update_time')
 | 
			
		||||
    # 对 imgs 重新排序(按照 distance 字段)
 | 
			
		||||
    imgs = sorted(imgs, key=lambda x: x['distance'])
 | 
			
		||||
    return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': next, 'list': imgs}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post(path='', summary='以图搜图', description='上传图片进行搜索')
 | 
			
		||||
async def search_imagex(image: UploadFile = File(...), page: int = 1, pageSize: int = 20):
 | 
			
		||||
    content = await image.read()
 | 
			
		||||
    img = Image.open(image.file)
 | 
			
		||||
    embeddig = RESNET50(img).get()[0]
 | 
			
		||||
    collection = get_collection('default')
 | 
			
		||||
    res = collection.search([embeddig],anns_field="embedding",param={"metric_type": 'L2', "params": {"nprobe": 16}}, output_fields=["id", "article_id"], limit=500)[0]
 | 
			
		||||
    ope, end = (page - 1) * pageSize, page * pageSize
 | 
			
		||||
    ids, nextx = [x.id for x in res][ope:end], len(res) > end
 | 
			
		||||
 | 
			
		||||
    if not ids:
 | 
			
		||||
        return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': False, 'list': []}
 | 
			
		||||
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT id,user_id,article_id,width,height,content,larger_image,thumbnail_image,article_category_top_id,praise_count,collect_count,create_time,update_time FROM `web_images` WHERE id IN ({','.join(map(str, ids))})")
 | 
			
		||||
    imgs = cursor.fetchall()
 | 
			
		||||
 | 
			
		||||
    if not imgs:
 | 
			
		||||
        return imgs
 | 
			
		||||
 | 
			
		||||
    uids, tids = list(set(x['user_id'] for x in imgs)), list(set(x['article_id'] for x in imgs))
 | 
			
		||||
    cursor.execute(f"SELECT id,user_name,avatar FROM `web_member` WHERE id IN ({str(uids).replace('[', '').replace(']', '')})")
 | 
			
		||||
    users = cursor.fetchall()
 | 
			
		||||
    cursor.execute(f"SELECT id,title,tags FROM `web_article` WHERE id IN ({str(tids).replace('[', '').replace(']', '')})")
 | 
			
		||||
    articles = cursor.fetchall()
 | 
			
		||||
    cursor.close()
 | 
			
		||||
 | 
			
		||||
    user, article = {x['id']: x for x in users}, {x['id']: x for x in articles}
 | 
			
		||||
    for x in imgs:
 | 
			
		||||
        x.update({
 | 
			
		||||
            'article': article[x['article_id']],
 | 
			
		||||
            'user': user[x['user_id']],
 | 
			
		||||
            #'distance': next(i.distance for i in res if i.id == x['id'], 0),
 | 
			
		||||
            'distance': [i.distance for i in res if i.id == x['id']][0],
 | 
			
		||||
            'praise_count': x.get('praise_count', 0),
 | 
			
		||||
            'collect_count': x.get('collect_count', 0)
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
    imgs.sort(key=lambda x: x['distance'])
 | 
			
		||||
    return {'code': 0, 'pageSize': pageSize, 'page': page, 'next': nextx, 'list': imgs}
 | 
			
		||||
 | 
			
		||||
@router.delete('/{thread_id}', summary="删除主题", description="删除指定主题下的所有图像")
 | 
			
		||||
async def delete_images(thread_id: str):
 | 
			
		||||
    collection = get_collection(collection_name)
 | 
			
		||||
    collection.delete(expr="article_id in ["+thread_id+"]")
 | 
			
		||||
    collection.load()
 | 
			
		||||
    default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
 | 
			
		||||
    collection.create_index(field_name="embedding", index_params=default_index)
 | 
			
		||||
    return {"status": True, 'msg': '删除完毕'}
 | 
			
		||||
							
								
								
									
										94
									
								
								routers/task.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								routers/task.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,94 @@
 | 
			
		||||
from fastapi import APIRouter, HTTPException, Request, WebSocket
 | 
			
		||||
from fastapi.responses import HTMLResponse
 | 
			
		||||
from fastapi.templating import Jinja2Templates
 | 
			
		||||
from models.task import TaskForm, TaskManager
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
task_manager = TaskManager()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 创建新任务
 | 
			
		||||
@router.post("", summary="创建新任务")
 | 
			
		||||
def create_task(form: TaskForm):
 | 
			
		||||
    task = task_manager.add_task(name=form.name, user_id=123456, description=form.description)
 | 
			
		||||
    return task
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取任务列表
 | 
			
		||||
@router.get("", summary="获取任务列表", description="可使用user_id参数筛选指定用户的任务列表")
 | 
			
		||||
def get_task_list(user_id: int=None):
 | 
			
		||||
    return task_manager.get_tasks(user_id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# websocket demo
 | 
			
		||||
@router.get("/demo", response_class=HTMLResponse)
 | 
			
		||||
async def websocket_demo(request: Request):
 | 
			
		||||
    task_list = task_manager.get_tasks()
 | 
			
		||||
    templates = Jinja2Templates(directory="templates")
 | 
			
		||||
    return templates.TemplateResponse("websocket.html", {"request": request, "task_list": task_list})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 监听指定任务的变化事件, 通知前端(不使用pydantic模型, 正确的写法)
 | 
			
		||||
@router.websocket("/{task_id}", name="监听任务变化")
 | 
			
		||||
async def websocket_endpoint(task_id: str, websocket: WebSocket):
 | 
			
		||||
    await websocket.accept()
 | 
			
		||||
    await task_manager.add_websocket(task_id, websocket)
 | 
			
		||||
    async for data in websocket.iter_text():
 | 
			
		||||
        await websocket.send_text(f"Message text was: {data}")
 | 
			
		||||
    task_manager.remove_websocket(task_id, websocket)
 | 
			
		||||
    print("websocket 连接已自动关闭")
 | 
			
		||||
 | 
			
		||||
    #await websocket.send_json({"message": "Hello WebSocket!"})
 | 
			
		||||
    #task = task_manager.get_task(task_id)
 | 
			
		||||
    #print(task)
 | 
			
		||||
    #if not task:
 | 
			
		||||
    #    print("task 不存在, 结束连接")
 | 
			
		||||
    #    return await websocket.close()             # 任务不存在, 结束连接
 | 
			
		||||
    #await websocket.send_json(task)                # 将任务的状态发送给客户端
 | 
			
		||||
    #await task_manager.add_websocket(task_id, websocket)
 | 
			
		||||
    # 正确的写法, 使用 async for, 并且处理意外断开的情况
 | 
			
		||||
    #try:
 | 
			
		||||
    #    async for data in websocket.iter_text():
 | 
			
		||||
    #        if data == "close":
 | 
			
		||||
    #            print("客户端主动关闭连接")
 | 
			
		||||
    #            task_manager.remove_websocket(task_id, websocket)
 | 
			
		||||
    #            await websocket.close()
 | 
			
		||||
    #            break
 | 
			
		||||
    #        else:
 | 
			
		||||
    #            print(f"接收到客户端消息: {data}")
 | 
			
		||||
    #            await websocket.send_text(f"Message text was: {data}")
 | 
			
		||||
    #except Exception as e:
 | 
			
		||||
    #    print(f"客户端意外断开连接: {e}")
 | 
			
		||||
    #    task_manager.remove_websocket(task_id, websocket)
 | 
			
		||||
    #    #await websocket.close()
 | 
			
		||||
 | 
			
		||||
    # 监听客户端的状态, 如果客户端主动关闭连接或意外断开连接, 都从任务的websocket列表中移除
 | 
			
		||||
    #while True:
 | 
			
		||||
    #    try:
 | 
			
		||||
    #        data = await websocket.receive_text()
 | 
			
		||||
    #        if data == "close":
 | 
			
		||||
    #            print("客户端主动关闭连接")
 | 
			
		||||
    #            task_manager.remove_websocket(task_id, websocket)
 | 
			
		||||
    #            await websocket.close()
 | 
			
		||||
    #            break
 | 
			
		||||
    #        else:
 | 
			
		||||
    #            print(f"接收到客户端消息: {data}")
 | 
			
		||||
    #            await websocket.send_text(f"Message text was: {data}")
 | 
			
		||||
    #    except Exception as e:
 | 
			
		||||
    #        print(f"客户端意外断开连接: {e}")
 | 
			
		||||
    #        task_manager.remove_websocket(task_id, websocket)
 | 
			
		||||
    #        #await websocket.close()
 | 
			
		||||
    #        break
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取任务详情
 | 
			
		||||
@router.get("/{task_id}", summary="获取任务详情")
 | 
			
		||||
def get_task(task_id: int):
 | 
			
		||||
    return get_task(task_id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 删除任务
 | 
			
		||||
@router.delete("/{task_id}", summary="删除指定任务")
 | 
			
		||||
def delete_task(task_id: int):
 | 
			
		||||
    return delete_task(task_id)
 | 
			
		||||
							
								
								
									
										19
									
								
								routers/user.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								routers/user.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
from fastapi import APIRouter, HTTPException
 | 
			
		||||
from models.mysql import conn, get_cursor
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
@router.get('/{user_id}/collect', summary="用户收藏记录", description="获取指定用户收藏记录")
 | 
			
		||||
def get_user_collect(user_id:int):
 | 
			
		||||
    # TODO: 需要验证权限
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT content FROM web_collect  WHERE user_id={user_id} AND type='1'")
 | 
			
		||||
    data = cursor.fetchall()
 | 
			
		||||
    if not data:
 | 
			
		||||
        return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
 | 
			
		||||
    data = [str(item['content']) for item in data]
 | 
			
		||||
    cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
 | 
			
		||||
    data = cursor.fetchall()
 | 
			
		||||
    data = [str(item['id']) for item in data]
 | 
			
		||||
    return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										45
									
								
								routers/user_collect.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								routers/user_collect.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,45 @@
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from fastapi import APIRouter, HTTPException, Header
 | 
			
		||||
from models.mysql import conn, get_cursor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取当前用户收藏记录
 | 
			
		||||
@router.get('', summary='自己的收藏记录', description='获取自己的收藏记录, 用于判断是否收藏(headers中必须附带token)')
 | 
			
		||||
def get_self_collect(token: Optional[str] = Header()):
 | 
			
		||||
    print('token: ', token)
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    # 查询用户ID
 | 
			
		||||
    cursor.execute(f"SELECT user_id FROM web_auth WHERE token={token} limit 1")
 | 
			
		||||
    data = cursor.fetchone()
 | 
			
		||||
    print('auth: ', data)
 | 
			
		||||
    if not data:
 | 
			
		||||
        raise HTTPException(status_code=401, detail="用户未登录")
 | 
			
		||||
    user_id = data['user_id']
 | 
			
		||||
    # 查询收藏记录
 | 
			
		||||
    cursor.execute(f"SELECT content FROM web_collect  WHERE user_id={user_id} AND type='1'")
 | 
			
		||||
    data = cursor.fetchall()                       # 获取所有记录列表
 | 
			
		||||
    data = [str(item['content']) for item in data] # 转换为数组
 | 
			
		||||
    # 查询图片ID(对特殊字符安全转义)
 | 
			
		||||
    cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
 | 
			
		||||
    data = cursor.fetchall()
 | 
			
		||||
    data = [str(item['id']) for item in data]
 | 
			
		||||
    return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取指定用户收藏记录
 | 
			
		||||
@router.get('/{user_id}', summary='指定用的户收藏记录', description='获取指定用户收藏记录(仅测试用)')
 | 
			
		||||
def get_user_collect(user_id: int):
 | 
			
		||||
    cursor = get_cursor()
 | 
			
		||||
    cursor.execute(f"SELECT content FROM web_collect  WHERE user_id={user_id} AND type=1")
 | 
			
		||||
    data = cursor.fetchall()                       # 获取所有记录列表
 | 
			
		||||
    data = [str(item['content']) for item in data] # 转换为数组
 | 
			
		||||
    if not data:
 | 
			
		||||
        return {'code': 0, 'user_id': user_id, 'total': 0, 'data': []}
 | 
			
		||||
    # 查询图片ID(对特殊字符安全转义)
 | 
			
		||||
    cursor.execute(f"SELECT id FROM web_images WHERE content IN ({','.join(['%s'] * len(data))})", data)
 | 
			
		||||
    data = cursor.fetchall()
 | 
			
		||||
    data = [str(item['id']) for item in data]
 | 
			
		||||
    return {'code': 0, 'user_id': user_id, 'total': len(data), 'data': data}
 | 
			
		||||
		Reference in New Issue
	
	Block a user