转移
This commit is contained in:
		
							
								
								
									
										26
									
								
								models/milvus.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								models/milvus.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
			
		||||
import pymilvus
 | 
			
		||||
 | 
			
		||||
from configs.config import MILVUS_HOST
 | 
			
		||||
 | 
			
		||||
# 连接 Milvus (开启 Milvus 服务)
 | 
			
		||||
collection_name = 'default'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取 Milvus 连接
 | 
			
		||||
def get_collection(collection_name):
 | 
			
		||||
    pymilvus.connections.connect(host=MILVUS_HOST, port='19530')
 | 
			
		||||
    if not pymilvus.utility.has_collection(collection_name):
 | 
			
		||||
        field1 = pymilvus.FieldSchema(name="id", dtype=pymilvus.DataType.INT64, is_primary=True)
 | 
			
		||||
        field2 = pymilvus.FieldSchema(name="embedding", dtype=pymilvus.DataType.FLOAT_VECTOR, dim=2048)
 | 
			
		||||
        field3 = pymilvus.FieldSchema(name="article_id", dtype=pymilvus.DataType.INT64)
 | 
			
		||||
        schema = pymilvus.CollectionSchema(fields=[field1, field2, field3])
 | 
			
		||||
        return pymilvus.Collection(name=collection_name, schema=schema)
 | 
			
		||||
    return pymilvus.Collection(name=collection_name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 检查索引是否存在, 不存在则创建, 并加载
 | 
			
		||||
#collection = get_collection(collection_name)
 | 
			
		||||
#if not collection.has_index():
 | 
			
		||||
#    default_index = {"index_type": "IVF_SQ8", "metric_type": 'L2', "params": {"nlist": 16384}}
 | 
			
		||||
#    collection.create_index(field_name="embedding", index_params=default_index)
 | 
			
		||||
#collection.load()
 | 
			
		||||
							
								
								
									
										16
									
								
								models/mysql.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								models/mysql.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,16 @@
 | 
			
		||||
import pymysql
 | 
			
		||||
from configs.config import MYSQL_HOST
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 连接 MySQL (开启 MySQL 服务)
 | 
			
		||||
conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 获取 MySQL 连接
 | 
			
		||||
def get_cursor():
 | 
			
		||||
    try:
 | 
			
		||||
        dx = conn.ping()
 | 
			
		||||
        return conn.cursor()
 | 
			
		||||
    except Exception:
 | 
			
		||||
        conn = pymysql.connect(host=MYSQL_HOST, user="gameui", port=3306, password="gameui@2022", database='gameui', local_infile=True, cursorclass=pymysql.cursors.DictCursor)
 | 
			
		||||
        return conn.cursor()
 | 
			
		||||
							
								
								
									
										9
									
								
								models/oss.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								models/oss.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
			
		||||
import oss2
 | 
			
		||||
from configs.config import OSS_HOST
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 连接 OSS
 | 
			
		||||
oss2.defaults.connection_pool_size = 100
 | 
			
		||||
auth = oss2.Auth('LTAI4GH3qP6VA3QpmTYCgXEW', 'r2wz4bJty8iYfGIcFmEqlY1yon2Ruy')
 | 
			
		||||
bucket_image2 = oss2.Bucket(auth, f'http://{OSS_HOST}', 'gameui-image2')
 | 
			
		||||
bucket_webp = oss2.Bucket(auth, f'http://{OSS_HOST}', 'gameui-webp')
 | 
			
		||||
							
								
								
									
										11
									
								
								models/resnet.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								models/resnet.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
			
		||||
import towhee
 | 
			
		||||
 | 
			
		||||
class Resnet50:
 | 
			
		||||
    def resnet50_extract_feat(self, img_path):
 | 
			
		||||
        feat = towhee.glob(img_path).image_decode().image_embedding.timm(model_name='resnet50').tensor_normalize().to_list()
 | 
			
		||||
        print(feat[0])
 | 
			
		||||
        return feat[0]
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    print('This script is running as the main program.')
 | 
			
		||||
    #resnet = 
 | 
			
		||||
							
								
								
									
										184
									
								
								models/task.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										184
									
								
								models/task.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,184 @@
 | 
			
		||||
import os
 | 
			
		||||
import sqlite3
 | 
			
		||||
import datetime
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from configs.config import SQLITE3_PATH
 | 
			
		||||
import uuid
 | 
			
		||||
import time
 | 
			
		||||
import requests
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
 | 
			
		||||
__taskdb.row_factory = sqlite3.Row
 | 
			
		||||
 | 
			
		||||
# 创建任务表
 | 
			
		||||
__taskdb.execute("""
 | 
			
		||||
CREATE TABLE IF NOT EXISTS tasks (
 | 
			
		||||
    id INTEGER PRIMARY KEY AUTOINCREMENT,
 | 
			
		||||
    name TEXT NOT NULL,
 | 
			
		||||
    user_id INTEGER NOT NULL,
 | 
			
		||||
    description TEXT NOT NULL,
 | 
			
		||||
    created_at TEXT NOT NULL,
 | 
			
		||||
    updated_at TEXT NOT NULL
 | 
			
		||||
)
 | 
			
		||||
""")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 任务表单模型
 | 
			
		||||
class TaskForm(BaseModel):
 | 
			
		||||
    name: str
 | 
			
		||||
    description: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Task(BaseModel):
 | 
			
		||||
    id: str
 | 
			
		||||
    name: str
 | 
			
		||||
    user_id: int
 | 
			
		||||
    description: str
 | 
			
		||||
    created_at: str
 | 
			
		||||
    updated_at: str
 | 
			
		||||
    status: str = "pending"
 | 
			
		||||
    progress: int = 0
 | 
			
		||||
    websockets: list = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TaskManager:
 | 
			
		||||
    __tasks = {}
 | 
			
		||||
    
 | 
			
		||||
    # 使用pydantic模型的写法
 | 
			
		||||
    def add_task(self, name: str, user_id: int, description: str):
 | 
			
		||||
        id = str(uuid.uuid4())
 | 
			
		||||
        date = str(datetime.datetime.now())
 | 
			
		||||
        task = Task(id=id, name=name, user_id=user_id, description=description, created_at=date, updated_at=date)
 | 
			
		||||
        self.__tasks[id] = task
 | 
			
		||||
        return task
 | 
			
		||||
 | 
			
		||||
    def get_task(self, task_id: str):
 | 
			
		||||
        task = self.__tasks.get(task_id)
 | 
			
		||||
        if task:
 | 
			
		||||
            return task.dict()
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
    
 | 
			
		||||
    # 直接返回对象的写法
 | 
			
		||||
    def get_tasks(self, user_id: int=None):
 | 
			
		||||
        if user_id:
 | 
			
		||||
            return [task for task in self.__tasks.values() if task.user_id == user_id]
 | 
			
		||||
        else:
 | 
			
		||||
            return [task for task in self.__tasks.values()]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def remove_task(self, task_id: str):
 | 
			
		||||
        if self.__tasks.get(task_id):
 | 
			
		||||
            self.__tasks.pop(task_id)
 | 
			
		||||
            return True
 | 
			
		||||
        else:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    # 正确的写法(不等待async函数直接返回结果)
 | 
			
		||||
    async def add_websocket(self, task_id: str, websocket):
 | 
			
		||||
        print(self.__tasks)
 | 
			
		||||
        # 每次调用连接数都是1, 这是错误的
 | 
			
		||||
        self.__tasks[task_id].websockets.append(websocket)
 | 
			
		||||
        print(f"当前连接数: {len(self.__tasks[task_id].websockets)}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        #if len(task.websockets) > 0 and task.progress == 0 and task.status == "pending":
 | 
			
		||||
        #    print("任务未被监听, 开启监听")
 | 
			
		||||
        #    await self.start_listen_task(task)
 | 
			
		||||
    
 | 
			
		||||
    # 移除 websocket
 | 
			
		||||
    def remove_websocket(self, task_id: str, websocket):
 | 
			
		||||
        task = self.__tasks.get(task_id)
 | 
			
		||||
        if task:
 | 
			
		||||
            print("连接已断开, 移除 websocket")
 | 
			
		||||
            task.websockets.remove(websocket)
 | 
			
		||||
            print(f"当前剩余连接数: {len(task.websockets)}")
 | 
			
		||||
    
 | 
			
		||||
    # 正确的写法
 | 
			
		||||
    async def start_listen_task(self, task: Task):
 | 
			
		||||
        while True:
 | 
			
		||||
            print(f"start listen task: {len(task.websockets)}")
 | 
			
		||||
            time.sleep(2.5)
 | 
			
		||||
            if len(task.websockets) == 0:
 | 
			
		||||
                break
 | 
			
		||||
            print(f"get task status:{task.id}")
 | 
			
		||||
            break
 | 
			
		||||
            #data = requests.get(f"http://localhost:5001/api/task/{task.id}").json()
 | 
			
		||||
            #print(data)
 | 
			
		||||
            ## 通过 http 请求获取任务的状态
 | 
			
		||||
            ## 如果任务状态发生变化, 则将任务状态发送给所有监听该任务的 websocket
 | 
			
		||||
            #if task.status != data.status:
 | 
			
		||||
            #    task.status = data.status
 | 
			
		||||
            #    print("send task status to websocket")
 | 
			
		||||
            #    for websocket in task.websockets:
 | 
			
		||||
            #        await websocket.send_json(task)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
## 任务实体模型
 | 
			
		||||
#class TaskModel(BaseModel):
 | 
			
		||||
#    id: str
 | 
			
		||||
#    name: str
 | 
			
		||||
#    user_id: int
 | 
			
		||||
#    description: str
 | 
			
		||||
#    created_at: str
 | 
			
		||||
#    updated_at: str
 | 
			
		||||
#    websockets: list=[]
 | 
			
		||||
#
 | 
			
		||||
## 向数据库中添加任务
 | 
			
		||||
#def add_task(name, user_id, description):
 | 
			
		||||
#    print(name, user_id, description)
 | 
			
		||||
#    date = datetime.datetime.now()
 | 
			
		||||
#    cursor = __taskdb.cursor()
 | 
			
		||||
#    cursor.execute("INSERT INTO tasks(name,user_id,description,created_at,updated_at) VALUES (?, ?, ?, ?, ?)", (name,user_id,description,date,date))
 | 
			
		||||
#    id = cursor.lastrowid
 | 
			
		||||
#    cursor.close()
 | 
			
		||||
#    __taskdb.commit()
 | 
			
		||||
#    return {"id": id,"name": name,"user_id": user_id,"description": description,"created_at": date,"updated_at": date}
 | 
			
		||||
#
 | 
			
		||||
#
 | 
			
		||||
## 从数据库中获取任务列表
 | 
			
		||||
#def get_tasks(user_id: int=None):
 | 
			
		||||
#    cursor = __taskdb.cursor()
 | 
			
		||||
#    if user_id:
 | 
			
		||||
#        cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,))
 | 
			
		||||
#    else:
 | 
			
		||||
#        cursor.execute("SELECT * FROM tasks")
 | 
			
		||||
#    tasks = cursor.fetchall()
 | 
			
		||||
#    cursor.close()
 | 
			
		||||
#    return [dict(task) for task in tasks]
 | 
			
		||||
#
 | 
			
		||||
#
 | 
			
		||||
## 从数据库中获取指定任务
 | 
			
		||||
#def get_task(task_id: int):
 | 
			
		||||
#    cursor = __taskdb.cursor()
 | 
			
		||||
#    cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
 | 
			
		||||
#    task = cursor.fetchone()
 | 
			
		||||
#    cursor.close()
 | 
			
		||||
#    return task
 | 
			
		||||
#
 | 
			
		||||
#
 | 
			
		||||
## 更新数据库中的任务
 | 
			
		||||
#def update_task(task: TaskModel):
 | 
			
		||||
#    cursor = __taskdb.cursor()
 | 
			
		||||
#    cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", (
 | 
			
		||||
#        task.name,
 | 
			
		||||
#        task.description,
 | 
			
		||||
#        datetime.datetime.now(),
 | 
			
		||||
#        task.id
 | 
			
		||||
#    ))
 | 
			
		||||
#    cursor.close()
 | 
			
		||||
#    __taskdb.commit()
 | 
			
		||||
#    return task
 | 
			
		||||
#
 | 
			
		||||
#
 | 
			
		||||
## 从数据库中删除任务
 | 
			
		||||
#def delete_task(task_id: int):
 | 
			
		||||
#    cursor = __taskdb.cursor()
 | 
			
		||||
#    cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
 | 
			
		||||
#    cursor.close()
 | 
			
		||||
#    __taskdb.commit()
 | 
			
		||||
#    return task_id
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user