From ee88acf9a80dc0bc42a030a92c54cd5206b284aa Mon Sep 17 00:00:00 2001 From: satori Date: Mon, 4 Nov 2024 05:58:03 +0800 Subject: [PATCH] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/milvus.py | 4 +-- models/mysql.py | 27 +++++++++------- models/task.py | 79 ---------------------------------------------- routers/reverse.py | 36 +-------------------- 4 files changed, 18 insertions(+), 128 deletions(-) diff --git a/models/milvus.py b/models/milvus.py index 7234664..147014d 100644 --- a/models/milvus.py +++ b/models/milvus.py @@ -1,6 +1,6 @@ import pymilvus -from configs.config import MILVUS_HOST +from configs.config import MILVUS_HOST, MILVUS_PORT # 连接 Milvus (开启 Milvus 服务) collection_name = 'default' @@ -8,7 +8,7 @@ collection_name = 'default' # 获取 Milvus 连接 def get_collection(collection_name): - pymilvus.connections.connect(host=MILVUS_HOST, port='19530') + pymilvus.connections.connect(host=MILVUS_HOST, port=MILVUS_PORT) if not pymilvus.utility.has_collection(collection_name): field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True) field2 = pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, dim=2048) diff --git a/models/mysql.py b/models/mysql.py index 607d194..4c81aa5 100644 --- a/models/mysql.py +++ b/models/mysql.py @@ -1,24 +1,27 @@ import pymysql from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS +# 创建 MySQL 连接 +def create_connection(): + return pymysql.connect( + host=MYSQL_HOST, + user=MYSQL_USER, + port=MYSQL_PORT, # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST + password=MYSQL_PASS, + database=MYSQL_NAME, + local_infile=True, + cursorclass=pymysql.cursors.DictCursor + ) # 连接 MySQL (开启 MySQL 服务) -conn = pymysql.connect( - host=MYSQL_HOST, - user=MYSQL_USER, - port=MYSQL_HOST, - password=MYSQL_PASS, - database=MYSQL_NAME, - local_infile=True, - cursorclass=pymysql.cursors.DictCursor -) - +conn = create_connection() # 获取 MySQL 连接 def get_cursor(): + global conn try: - dx = conn.ping() + conn.ping() return conn.cursor() except Exception: - conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor) + conn = create_connection() return conn.cursor() diff --git a/models/task.py b/models/task.py index 4baa885..2f379c3 100644 --- a/models/task.py +++ b/models/task.py @@ -5,8 +5,6 @@ from pydantic import BaseModel from configs.config import SQLITE3_PATH import uuid import time -import requests - __taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False) __taskdb.row_factory = sqlite3.Row @@ -104,81 +102,4 @@ class TaskManager: break print(f"get task status:{task.id}") break - #data = requests.get(f"http://localhost:5001/api/task/{task.id}").json() - #print(data) - ## 通过 http 请求获取任务的状态 - ## 如果任务状态发生变化, 则将任务状态发送给所有监听该任务的 websocket - #if task.status != data.status: - # task.status = data.status - # print("send task status to websocket") - # for websocket in task.websockets: - # await websocket.send_json(task) - - - -## 任务实体模型 -#class TaskModel(BaseModel): -# id: str -# name: str -# user_id: int -# description: str -# created_at: str -# updated_at: str -# websockets: list=[] -# -## 向数据库中添加任务 -#def add_task(name, user_id, description): -# print(name, user_id, description) -# date = datetime.datetime.now() -# cursor = __taskdb.cursor() -# cursor.execute("INSERT INTO tasks(name,user_id,description,created_at,updated_at) VALUES (?, ?, ?, ?, ?)", (name,user_id,description,date,date)) -# id = cursor.lastrowid -# cursor.close() -# __taskdb.commit() -# return {"id": id,"name": name,"user_id": user_id,"description": description,"created_at": date,"updated_at": date} -# -# -## 从数据库中获取任务列表 -#def get_tasks(user_id: int=None): -# cursor = __taskdb.cursor() -# if user_id: -# cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,)) -# else: -# cursor.execute("SELECT * FROM tasks") -# tasks = cursor.fetchall() -# cursor.close() -# return [dict(task) for task in tasks] -# -# -## 从数据库中获取指定任务 -#def get_task(task_id: int): -# cursor = __taskdb.cursor() -# cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)) -# task = cursor.fetchone() -# cursor.close() -# return task -# -# -## 更新数据库中的任务 -#def update_task(task: TaskModel): -# cursor = __taskdb.cursor() -# cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", ( -# task.name, -# task.description, -# datetime.datetime.now(), -# task.id -# )) -# cursor.close() -# __taskdb.commit() -# return task -# -# -## 从数据库中删除任务 -#def delete_task(task_id: int): -# cursor = __taskdb.cursor() -# cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,)) -# cursor.close() -# __taskdb.commit() -# return task_id - diff --git a/routers/reverse.py b/routers/reverse.py index 51a43e6..d986c65 100644 --- a/routers/reverse.py +++ b/routers/reverse.py @@ -11,13 +11,11 @@ from towhee import pipe, ops from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header from models.milvus import get_collection, collection_name from models.mysql import get_cursor, conn -#from models.resnet import Resnet50 from configs.config import UPLOAD_PATH from utilities.download import download_image router = APIRouter() -#MODEL = Resnet50() RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec')) @@ -37,40 +35,8 @@ async def create_index(): collection.load() return {'status': True, 'count': collection.num_entities} -''' -# 批量生成向量 -@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False) -async def create_vector(count: int = 10): - cursor = get_cursor() - cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}") - images = cursor.fetchall() - cursor.close() - for item in images: - print(item) - # 先查询 milvus 中是否存在 - collection = get_collection(collection_name) - data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit - if len(data) > 0: - cursor = get_cursor() - cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}") - conn.commit() - cursor.close() - continue - try: - img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image'])) - download_image(item['thumbnail_image']).save(img_path, 'png', save_all=True) - feat = MODEL.resnet50_extract_feat(img_path) - collection.insert([[item['id']], [feat], [item['article_id']]]) - cursor = get_cursor() - cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}") - conn.commit() - cursor.close() - except Exception as e: - print(e) - print('END') - return images -''' +# 重建指定图像的向量 @router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量') async def rewrite_image(image_id: int): print('START', image_id, '重建向量')