"""In-memory task registry with websocket subscriber tracking.

On import this module opens ``tasks.db`` under ``SQLITE3_PATH`` and creates
the ``tasks`` table if it is missing. The SQLite-backed CRUD helpers at the
bottom of the file are currently commented out, so the live ``TaskManager``
keeps tasks only in process memory.
"""
import asyncio
import datetime
import os
import sqlite3
import time
import uuid
from typing import Optional

import requests
from pydantic import BaseModel, Field

from configs.config import SQLITE3_PATH

# Module-level SQLite connection. check_same_thread=False allows the single
# connection to be used from threads other than the one that created it.
__taskdb = sqlite3.connect(os.path.join(SQLITE3_PATH, 'tasks.db'), check_same_thread=False)
# Return rows as sqlite3.Row so they can be converted with dict(row).
__taskdb.row_factory = sqlite3.Row

# Create the tasks table
__taskdb.execute("""
CREATE TABLE IF NOT EXISTS tasks (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    user_id INTEGER NOT NULL,
    description TEXT NOT NULL,
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL
)
""")


# Task form model
class TaskForm(BaseModel):
    """User-supplied fields for a task: its name and description."""

    name: str
    description: str


class Task(BaseModel):
    """In-memory task record managed by ``TaskManager``."""

    id: str
    name: str
    user_id: int
    description: str
    created_at: str
    updated_at: str
    status: str = "pending"
    progress: int = 0
    # Websocket connections subscribed to this task. default_factory gives each
    # Task its own list explicitly instead of relying on pydantic copying a
    # shared [] default.
    websockets: list = Field(default_factory=list)


class TaskManager:
    """Keep tasks in memory and track the websockets subscribed to each one.

    ``__tasks`` is a class attribute, so every ``TaskManager`` instance in the
    same process shares a single task registry.
    """

    # task id -> Task; class-level, hence shared by all instances.
    __tasks = {}

    # Variant that builds and stores pydantic models
    def add_task(self, name: str, user_id: int, description: str) -> Task:
        """Create a task with a fresh UUID4 id, register it, and return it.

        ``created_at`` and ``updated_at`` are both set to ``str()`` of the
        current local (naive) ``datetime``.
        """
        task_id = str(uuid.uuid4())
        now = str(datetime.datetime.now())
        task = Task(
            id=task_id,
            name=name,
            user_id=user_id,
            description=description,
            created_at=now,
            updated_at=now,
        )
        self.__tasks[task_id] = task
        return task

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return the task ``task_id`` as a dict (``Task.dict()``), or None if unknown."""
        task = self.__tasks.get(task_id)
        return task.dict() if task is not None else None

    # Variant that returns the model objects directly
    def get_tasks(self, user_id: Optional[int] = None) -> list:
        """Return registered ``Task`` objects, optionally filtered by owner.

        Args:
            user_id: only return tasks whose ``user_id`` equals this value;
                ``None`` returns every task.
        """
        # Compare with None rather than truthiness so user_id=0 still filters.
        if user_id is not None:
            return [task for task in self.__tasks.values() if task.user_id == user_id]
        return list(self.__tasks.values())

    def remove_task(self, task_id: str) -> bool:
        """Unregister task ``task_id``; return True if it existed, False otherwise."""
        return self.__tasks.pop(task_id, None) is not None

    # Registers the websocket and returns without awaiting any other coroutine.
    async def add_websocket(self, task_id: str, websocket) -> None:
        """Subscribe ``websocket`` to task ``task_id``.

        Raises:
            KeyError: if no task with ``task_id`` is registered.
        """
        print(self.__tasks)
        # NOTE(review): the original author observed that the count printed
        # below is always 1. __tasks is shared class-wide within one process,
        # so this may point to several worker processes, or to the socket being
        # removed between calls. TODO investigate.
        task = self.__tasks[task_id]
        task.websockets.append(websocket)
        print(f"当前连接数: {len(task.websockets)}")
        # Start listening when the first subscriber arrives (disabled). Awaiting
        # start_listen_task here would keep add_websocket from returning until
        # listening stops; consider asyncio.create_task instead.
        # if len(task.websockets) > 0 and task.progress == 0 and task.status == "pending":
        #     print("任务未被监听, 开启监听")
        #     await self.start_listen_task(task)

    # Remove a websocket
    def remove_websocket(self, task_id: str, websocket) -> None:
        """Unsubscribe ``websocket`` from task ``task_id``.

        Unknown task ids are ignored. A websocket that is not (or no longer)
        subscribed is ignored instead of raising ``ValueError``, so a repeated
        disconnect does not crash the caller.
        """
        task = self.__tasks.get(task_id)
        if task is None:
            return
        print("连接已断开, 移除 websocket")
        if websocket in task.websockets:
            task.websockets.remove(websocket)
        print(f"当前剩余连接数: {len(task.websockets)}")

    async def start_listen_task(self, task: Task) -> None:
        """Poll ``task`` while it has subscribers (work in progress).

        Currently performs one iteration: it waits 2.5 s, stops if nobody is
        subscribed, otherwise logs the task id and stops. The status fetch and
        broadcast below are still commented out.
        """
        while True:
            print(f"start listen task: {len(task.websockets)}")
            # asyncio.sleep yields to the event loop; time.sleep here would
            # freeze every other coroutine for the full 2.5 s.
            await asyncio.sleep(2.5)
            if not task.websockets:
                break
            print(f"get task status:{task.id}")
            # Placeholder: exit after a single poll until the logic below is enabled.
            break
            # NOTE(review) before enabling:
            #   - requests.get is synchronous and would block the event loop.
            #   - .json() presumably returns a dict, so use data["status"], not data.status.
            #   - send_json needs a JSON-serialisable payload, not the Task model,
            #     which also holds the websocket objects.
            # data = requests.get(f"http://localhost:5001/api/task/{task.id}").json()
            # print(data)
            # # Fetch the task status over HTTP.
            # # If the status changed, push it to every websocket listening on this task.
            # if task.status != data.status:
            #     task.status = data.status
            #     print("send task status to websocket")
            #     for websocket in task.websockets:
            #         await websocket.send_json(task)


# --- SQLite-backed implementation (currently disabled) ---------------------
#
# # Task entity model
# class TaskModel(BaseModel):
#     id: str
#     name: str
#     user_id: int
#     description: str
#     created_at: str
#     updated_at: str
#     websockets: list=[]
#
# # Insert a task into the database
# def add_task(name, user_id, description):
#     print(name, user_id, description)
#     date = datetime.datetime.now()
#     cursor = __taskdb.cursor()
#     cursor.execute("INSERT INTO tasks(name,user_id,description,created_at,updated_at) VALUES (?, ?, ?, ?, ?)", (name,user_id,description,date,date))
#     id = cursor.lastrowid
#     cursor.close()
#     __taskdb.commit()
#     return {"id": id,"name": name,"user_id": user_id,"description": description,"created_at": date,"updated_at": date}
#
# # Fetch the task list from the database
# def get_tasks(user_id: int=None):
#     cursor = __taskdb.cursor()
#     if user_id:
#         cursor.execute("SELECT * FROM tasks WHERE user_id = ?", (user_id,))
#     else:
#         cursor.execute("SELECT * FROM tasks")
#     tasks = cursor.fetchall()
#     cursor.close()
#     return [dict(task) for task in tasks]
#
# # Fetch a single task from the database
# def get_task(task_id: int):
#     cursor = __taskdb.cursor()
#     cursor.execute("SELECT * FROM tasks WHERE id = ?", (task_id,))
#     task = cursor.fetchone()
#     cursor.close()
#     return task
#
# # Update a task in the database
# def update_task(task: TaskModel):
#     cursor = __taskdb.cursor()
#     cursor.execute("UPDATE tasks SET name = ?, description = ?, updated_at = ? WHERE id = ?", (
#         task.name,
#         task.description,
#         datetime.datetime.now(),
#         task.id
#     ))
#     cursor.close()
#     __taskdb.commit()
#     return task
#
# # Delete a task from the database
# def delete_task(task_id: int):
#     cursor = __taskdb.cursor()
#     cursor.execute("DELETE FROM tasks WHERE id = ?", (task_id,))
#     cursor.close()
#     __taskdb.commit()
#     return task_id