移除配置
This commit is contained in:
		@@ -1,6 +1,6 @@
 | 
				
			|||||||
import pymilvus
 | 
					import pymilvus
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from configs.config import MILVUS_HOST
 | 
					from configs.config import MILVUS_HOST, MILVUS_PORT
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 连接 Milvus (开启 Milvus 服务)
 | 
					# 连接 Milvus (开启 Milvus 服务)
 | 
				
			||||||
collection_name = 'default'
 | 
					collection_name = 'default'
 | 
				
			||||||
@@ -8,7 +8,7 @@ collection_name = 'default'
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# 获取 Milvus 连接
 | 
					# 获取 Milvus 连接
 | 
				
			||||||
def get_collection(collection_name):
 | 
					def get_collection(collection_name):
 | 
				
			||||||
    pymilvus.connections.connect(host=MILVUS_HOST, port='19530')
 | 
					    pymilvus.connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
 | 
				
			||||||
    if not pymilvus.utility.has_collection(collection_name):
 | 
					    if not pymilvus.utility.has_collection(collection_name):
 | 
				
			||||||
        field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True)
 | 
					        field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True)
 | 
				
			||||||
        field2 = pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, dim=2048)
 | 
					        field2 = pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, dim=2048)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,24 +1,27 @@
 | 
				
			|||||||
import pymysql
 | 
					import pymysql
 | 
				
			||||||
from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS
 | 
					from configs.config import MYSQL_HOST, MYSQL_PORT, MYSQL_NAME, MYSQL_USER, MYSQL_PASS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 创建 MySQL 连接
 | 
				
			||||||
 | 
					def create_connection():
 | 
				
			||||||
 | 
					    return pymysql.connect(
 | 
				
			||||||
 | 
					        host=MYSQL_HOST,
 | 
				
			||||||
 | 
					        user=MYSQL_USER,
 | 
				
			||||||
 | 
					        port=MYSQL_PORT,  # 应该使用 MYSQL_PORT 而不是 MYSQL_HOST
 | 
				
			||||||
 | 
					        password=MYSQL_PASS,
 | 
				
			||||||
 | 
					        database=MYSQL_NAME,
 | 
				
			||||||
 | 
					        local_infile=True,
 | 
				
			||||||
 | 
					        cursorclass=pymysql.cursors.DictCursor
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 连接 MySQL (开启 MySQL 服务)
 | 
					# 连接 MySQL (开启 MySQL 服务)
 | 
				
			||||||
conn = pymysql.connect(
 | 
					conn = create_connection()
 | 
				
			||||||
    host=MYSQL_HOST,
 | 
					 | 
				
			||||||
    user=MYSQL_USER,
 | 
					 | 
				
			||||||
    port=MYSQL_HOST,
 | 
					 | 
				
			||||||
    password=MYSQL_PASS,
 | 
					 | 
				
			||||||
    database=MYSQL_NAME,
 | 
					 | 
				
			||||||
    local_infile=True,
 | 
					 | 
				
			||||||
    cursorclass=pymysql.cursors.DictCursor
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 获取 MySQL 连接
 | 
					# 获取 MySQL 连接
 | 
				
			||||||
def get_cursor():
 | 
					def get_cursor():
 | 
				
			||||||
 | 
					    global conn
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        dx = conn.ping()
 | 
					        conn.ping()
 | 
				
			||||||
        return conn.cursor()
 | 
					        return conn.cursor()
 | 
				
			||||||
    except Exception:
 | 
					    except Exception:
 | 
				
			||||||
        conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
 | 
					        conn = create_connection()
 | 
				
			||||||
        return conn.cursor()
 | 
					        return conn.cursor()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,8 +5,6 @@ from pydantic import BaseModel
 | 
				
			|||||||
from configs.config import SQLITE3_PATH
 | 
					from configs.config import SQLITE3_PATH
 | 
				
			||||||
import uuid
 | 
					import uuid
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import requests
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
 | 
					__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
 | 
				
			||||||
__taskdb.row_factory = sqlite3.Row
 | 
					__taskdb.row_factory = sqlite3.Row
 | 
				
			||||||
@@ -104,81 +102,4 @@ class TaskManager:
 | 
				
			|||||||
                break
 | 
					                break
 | 
				
			||||||
            print(f"get task status:{task.id}")
 | 
					            print(f"get task status:{task.id}")
 | 
				
			||||||
            break
 | 
					            break
 | 
				
			||||||
            #data = requests.get(f"http://localhost:5001/api/task/{task.id}").json()
 | 
					 | 
				
			||||||
            #print(data)
 | 
					 | 
				
			||||||
            ## 通过 http 请求获取任务的状态
 | 
					 | 
				
			||||||
            ## 如果任务状态发生变化, 则将任务状态发送给所有监听该任务的 websocket
 | 
					 | 
				
			||||||
            #if task.status != data.status:
 | 
					 | 
				
			||||||
            #    task.status = data.status
 | 
					 | 
				
			||||||
            #    print("send task status to websocket")
 | 
					 | 
				
			||||||
            #    for websocket in task.websockets:
 | 
					 | 
				
			||||||
            #        await websocket.send_json(task)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
## 任务实体模型
 | 
					 | 
				
			||||||
#class TaskModel(BaseModel):
 | 
					 | 
				
			||||||
#    id: str
 | 
					 | 
				
			||||||
#    name: str
 | 
					 | 
				
			||||||
#    user_id: int
 | 
					 | 
				
			||||||
#    description: str
 | 
					 | 
				
			||||||
#    created_at: str
 | 
					 | 
				
			||||||
#    updated_at: str
 | 
					 | 
				
			||||||
#    websockets: list=[]
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
## 向数据库中添加任务
 | 
					 | 
				
			||||||
#def add_task(name, user_id, description):
 | 
					 | 
				
			||||||
#    print(name, user_id, description)
 | 
					 | 
				
			||||||
#    date = datetime.datetime.now()
 | 
					 | 
				
			||||||
#    cursor = __taskdb.cursor()
 | 
					 | 
				
			||||||
#    cursor.execute("INSERT INTO tasks(name,user_id,description,created_at,updated_at) VALUES (?, ?, ?, ?, ?)", (name,user_id,description,date,date))
 | 
					 | 
				
			||||||
#    id = cursor.lastrowid
 | 
					 | 
				
			||||||
#    cursor.close()
 | 
					 | 
				
			||||||
#    __taskdb.commit()
 | 
					 | 
				
			||||||
#    return {"id": id,"name": name,"user_id": user_id,"description": description,"created_at": date,"updated_at": date}
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
## 从数据库中获取任务列表
 | 
					 | 
				
			||||||
#def get_tasks(user_id: int=None):
 | 
					 | 
				
			||||||
#    cursor = __taskdb.cursor()
 | 
					 | 
				
			||||||
#    if user_id:
 | 
					 | 
				
			||||||
#        cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,))
 | 
					 | 
				
			||||||
#    else:
 | 
					 | 
				
			||||||
#        cursor.execute("SELECT * FROM tasks")
 | 
					 | 
				
			||||||
#    tasks = cursor.fetchall()
 | 
					 | 
				
			||||||
#    cursor.close()
 | 
					 | 
				
			||||||
#    return [dict(task) for task in tasks]
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
## 从数据库中获取指定任务
 | 
					 | 
				
			||||||
#def get_task(task_id: int):
 | 
					 | 
				
			||||||
#    cursor = __taskdb.cursor()
 | 
					 | 
				
			||||||
#    cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
 | 
					 | 
				
			||||||
#    task = cursor.fetchone()
 | 
					 | 
				
			||||||
#    cursor.close()
 | 
					 | 
				
			||||||
#    return task
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
## 更新数据库中的任务
 | 
					 | 
				
			||||||
#def update_task(task: TaskModel):
 | 
					 | 
				
			||||||
#    cursor = __taskdb.cursor()
 | 
					 | 
				
			||||||
#    cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", (
 | 
					 | 
				
			||||||
#        task.name,
 | 
					 | 
				
			||||||
#        task.description,
 | 
					 | 
				
			||||||
#        datetime.datetime.now(),
 | 
					 | 
				
			||||||
#        task.id
 | 
					 | 
				
			||||||
#    ))
 | 
					 | 
				
			||||||
#    cursor.close()
 | 
					 | 
				
			||||||
#    __taskdb.commit()
 | 
					 | 
				
			||||||
#    return task
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
#
 | 
					 | 
				
			||||||
## 从数据库中删除任务
 | 
					 | 
				
			||||||
#def delete_task(task_id: int):
 | 
					 | 
				
			||||||
#    cursor = __taskdb.cursor()
 | 
					 | 
				
			||||||
#    cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
 | 
					 | 
				
			||||||
#    cursor.close()
 | 
					 | 
				
			||||||
#    __taskdb.commit()
 | 
					 | 
				
			||||||
#    return task_id
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,13 +11,11 @@ from towhee import pipe, ops
 | 
				
			|||||||
from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
 | 
					from fastapi import APIRouter, HTTPException, File, UploadFile, Response, status, Header
 | 
				
			||||||
from models.milvus import get_collection, collection_name
 | 
					from models.milvus import get_collection, collection_name
 | 
				
			||||||
from models.mysql import get_cursor, conn
 | 
					from models.mysql import get_cursor, conn
 | 
				
			||||||
#from models.resnet import Resnet50
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from configs.config import UPLOAD_PATH
 | 
					from configs.config import UPLOAD_PATH
 | 
				
			||||||
from utilities.download import download_image
 | 
					from utilities.download import download_image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
router = APIRouter()
 | 
					router = APIRouter()
 | 
				
			||||||
#MODEL = Resnet50()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec'))
 | 
					RESNET50 = (pipe.input('img').map('img', 'vec', ops.image_embedding.timm(model_name='resnet50')).output('vec'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -37,40 +35,8 @@ async def create_index():
 | 
				
			|||||||
    collection.load()
 | 
					    collection.load()
 | 
				
			||||||
    return {'status': True, 'count': collection.num_entities}
 | 
					    return {'status': True, 'count': collection.num_entities}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
'''
 | 
					 | 
				
			||||||
# 批量生成向量
 | 
					 | 
				
			||||||
@router.get('/create_vector', summary='生成向量', description='手动生成向量', include_in_schema=False)
 | 
					 | 
				
			||||||
async def create_vector(count: int = 10):
 | 
					 | 
				
			||||||
    cursor = get_cursor()
 | 
					 | 
				
			||||||
    cursor.execute(f"SELECT id,thumbnail_image,article_id,milvus_id FROM `web_images` WHERE thumbnail_image IS NOT NULL AND article_id IS NOT NULL AND milvus_id != 2048 AND width IS NOT NULL LIMIT 0,{count}")
 | 
					 | 
				
			||||||
    images = cursor.fetchall()
 | 
					 | 
				
			||||||
    cursor.close()
 | 
					 | 
				
			||||||
    for item in images:
 | 
					 | 
				
			||||||
        print(item)
 | 
					 | 
				
			||||||
        # 先查询 milvus 中是否存在
 | 
					 | 
				
			||||||
        collection = get_collection(collection_name)
 | 
					 | 
				
			||||||
        data = collection.query(expr=f'id in [{item["id"]}]', output_fields=None, partition_names=None, timeout=None) # offset, limit
 | 
					 | 
				
			||||||
        if len(data) > 0:
 | 
					 | 
				
			||||||
            cursor = get_cursor()
 | 
					 | 
				
			||||||
            cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
 | 
					 | 
				
			||||||
            conn.commit()
 | 
					 | 
				
			||||||
            cursor.close()
 | 
					 | 
				
			||||||
            continue
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            img_path = os.path.join(UPLOAD_PATH, os.path.basename(item['thumbnail_image']))
 | 
					 | 
				
			||||||
            download_image(item['thumbnail_image']).save(img_path, 'png', save_all=True)
 | 
					 | 
				
			||||||
            feat = MODEL.resnet50_extract_feat(img_path)
 | 
					 | 
				
			||||||
            collection.insert([[item['id']], [feat], [item['article_id']]])
 | 
					 | 
				
			||||||
            cursor = get_cursor()
 | 
					 | 
				
			||||||
            cursor.execute(f"UPDATE web_images SET milvus_id=2048 WHERE id={item['id']}")
 | 
					 | 
				
			||||||
            conn.commit()
 | 
					 | 
				
			||||||
            cursor.close()
 | 
					 | 
				
			||||||
        except Exception as e:
 | 
					 | 
				
			||||||
            print(e)
 | 
					 | 
				
			||||||
    print('END')
 | 
					 | 
				
			||||||
    return images
 | 
					 | 
				
			||||||
'''
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 重建指定图像的向量
 | 
				
			||||||
@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
 | 
					@router.put('/{image_id}', summary='覆写向量', description='重建指定图像的向量')
 | 
				
			||||||
async def rewrite_image(image_id: int):
 | 
					async def rewrite_image(image_id: int):
 | 
				
			||||||
    print('START', image_id, '重建向量')
 | 
					    print('START', image_id, '重建向量')
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user