标准Task模型
This commit is contained in:
		@@ -3,6 +3,7 @@ from fastapi import WebSocket
 | 
				
			|||||||
from pydantic import BaseModel
 | 
					from pydantic import BaseModel
 | 
				
			||||||
from datetime import datetime
 | 
					from datetime import datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 使用字典来存储websocket连接
 | 
					# 使用字典来存储websocket连接
 | 
				
			||||||
class ConnectionManager:
 | 
					class ConnectionManager:
 | 
				
			||||||
    def __init__(self):
 | 
					    def __init__(self):
 | 
				
			||||||
@@ -28,22 +29,56 @@ class ConnectionManager:
 | 
				
			|||||||
        for client_id, ws in self.active_connections.items():
 | 
					        for client_id, ws in self.active_connections.items():
 | 
				
			||||||
            await ws.send_text(message)
 | 
					            await ws.send_text(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					任务管理器(观察者模式)
 | 
				
			||||||
 | 
					使用字典来存储Task任务
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Task 基本模型
 | 
					# Task 基本模型
 | 
				
			||||||
class Task(BaseModel):
 | 
					class Task(BaseModel):
 | 
				
			||||||
    id: int
 | 
					    id: str=''
 | 
				
			||||||
    name: str
 | 
					    name: str=''
 | 
				
			||||||
    status: str
 | 
					    status: str='pending'
 | 
				
			||||||
    created_at: datetime
 | 
					    created_at: datetime=datetime.now()
 | 
				
			||||||
    updated_at: datetime
 | 
					    updated_at: datetime=datetime.now()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    __observers = [] # 观察者列表
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 属性发生变化时,更新updated_at并通知观察者
 | 
				
			||||||
 | 
					    def __setattr__(self, name, value):
 | 
				
			||||||
 | 
					        super().__setattr__(name, value)
 | 
				
			||||||
 | 
					        self.event_observer(f"Task {self.id} updated at {self.updated_at}")
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # 添加观察者
 | 
				
			||||||
 | 
					    def add_observer(self, websocket: WebSocket):
 | 
				
			||||||
 | 
					        self.__observers.append(websocket)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # 移除观察者
 | 
				
			||||||
 | 
					    def remove_observer(self, websocket: WebSocket):
 | 
				
			||||||
 | 
					        self.__observers.remove(websocket)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # 通知观察者
 | 
				
			||||||
 | 
					    def event_observer(self, message: str):
 | 
				
			||||||
 | 
					        for observer in self.__observers:
 | 
				
			||||||
 | 
					            observer.send_text(message)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 使用字典来存储Task任务
 | 
					# 使用字典来存储Task任务
 | 
				
			||||||
class TaskManager(object):
 | 
					class TaskManager:
 | 
				
			||||||
    def __init__(self):
 | 
					    def __init__(self):
 | 
				
			||||||
        self.tasks = {}
 | 
					        self.tasks = {}       # 任务ID的映射
 | 
				
			||||||
        # TOOD: 保持一个后台线程,定时检查任务的状态,并将任务的状态更新到数据中
 | 
					        # TOOD: 保持一个后台线程,定时检查任务的状态,并将任务的状态更新到数据中
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def create(self, task: Task):
 | 
					    def add_observer(self, task_id: str, websocket: WebSocket):
 | 
				
			||||||
 | 
					        self.tasks[task_id].add_observer(websocket)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def remove_observer(self, task_id: str, websocket: WebSocket):
 | 
				
			||||||
 | 
					        self.tasks[task_id].remove_observer(websocket)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def has_task(self, task_id: str):
 | 
				
			||||||
 | 
					        return task_id in self.tasks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add(self, task: Task):
 | 
				
			||||||
        task.id = str(uuid.uuid4())
 | 
					        task.id = str(uuid.uuid4())
 | 
				
			||||||
        self.tasks[task.id] = task
 | 
					        self.tasks[task.id] = task
 | 
				
			||||||
        return task
 | 
					        return task
 | 
				
			||||||
@@ -61,7 +96,7 @@ class TaskManager(object):
 | 
				
			|||||||
        return self.tasks
 | 
					        return self.tasks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 使用字典来存储服务器池(使用腾讯云的API来管理服务器)
 | 
					# 使用字典来存储服务器池(使用腾讯云的API来管理服务器)
 | 
				
			||||||
class ServerManager(object):
 | 
					class ServerManager:
 | 
				
			||||||
    def __init__(self):
 | 
					    def __init__(self):
 | 
				
			||||||
        self.servers = {}
 | 
					        self.servers = {}
 | 
				
			||||||
        # 维护一个服务器池,每个服务器都有一个状态,状态有三种:空闲,运行中,异常
 | 
					        # 维护一个服务器池,每个服务器都有一个状态,状态有三种:空闲,运行中,异常
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										82
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										82
									
								
								main.py
									
									
									
									
									
								
							@@ -1,4 +1,5 @@
 | 
				
			|||||||
import uvicorn
 | 
					import uvicorn
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Depends, HTTPException
 | 
					from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Depends, HTTPException
 | 
				
			||||||
from fastapi.responses import HTMLResponse
 | 
					from fastapi.responses import HTMLResponse
 | 
				
			||||||
from ObjectManager import ConnectionManager, TaskManager, ServerManager, Task
 | 
					from ObjectManager import ConnectionManager, TaskManager, ServerManager, Task
 | 
				
			||||||
@@ -26,10 +27,23 @@ html = """
 | 
				
			|||||||
        </form>
 | 
					        </form>
 | 
				
			||||||
        <ul id='messages'>
 | 
					        <ul id='messages'>
 | 
				
			||||||
        </ul>
 | 
					        </ul>
 | 
				
			||||||
        <script>
 | 
					        <script type="module">
 | 
				
			||||||
            var client_id = Date.now()
 | 
					            var task = await fetch('/tasks', {
 | 
				
			||||||
            document.querySelector("#ws-id").textContent = client_id;
 | 
					                method: 'POST',
 | 
				
			||||||
            var ws = new WebSocket(`ws://localhost:8000/ws/${client_id}`);
 | 
					                headers: {
 | 
				
			||||||
 | 
					                    'Content-Type': 'application/json'
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					                body: JSON.stringify({
 | 
				
			||||||
 | 
					                    "name": "test",
 | 
				
			||||||
 | 
					                    "description": "test",
 | 
				
			||||||
 | 
					                    "status": "running",
 | 
				
			||||||
 | 
					                    "created_at": "2020-01-01 00:00:00",
 | 
				
			||||||
 | 
					                    "updated_at": "2020-01-01 00:00:00"
 | 
				
			||||||
 | 
					                })
 | 
				
			||||||
 | 
					            }).then(response => response.json())
 | 
				
			||||||
 | 
					            console.log(task)
 | 
				
			||||||
 | 
					            document.querySelector("#ws-id").textContent = task.id;
 | 
				
			||||||
 | 
					            var ws = new WebSocket(`ws://localhost:8000/tasks/${task.id}`);
 | 
				
			||||||
            ws.onmessage = function(event) {
 | 
					            ws.onmessage = function(event) {
 | 
				
			||||||
                var messages = document.getElementById('messages')
 | 
					                var messages = document.getElementById('messages')
 | 
				
			||||||
                var message = document.createElement('li')
 | 
					                var message = document.createElement('li')
 | 
				
			||||||
@@ -37,6 +51,9 @@ html = """
 | 
				
			|||||||
                message.appendChild(content)
 | 
					                message.appendChild(content)
 | 
				
			||||||
                messages.appendChild(message)
 | 
					                messages.appendChild(message)
 | 
				
			||||||
            };
 | 
					            };
 | 
				
			||||||
 | 
					            ws.onclose = function(event) {
 | 
				
			||||||
 | 
					                console.log('Socket is closed. Reconnect will be attempted in 1 second.', event.reason);
 | 
				
			||||||
 | 
					            };
 | 
				
			||||||
            function sendMessage(event) {
 | 
					            function sendMessage(event) {
 | 
				
			||||||
                var input = document.getElementById("messageText")
 | 
					                var input = document.getElementById("messageText")
 | 
				
			||||||
                ws.send(input.value)
 | 
					                ws.send(input.value)
 | 
				
			||||||
@@ -59,26 +76,6 @@ async def get():
 | 
				
			|||||||
connection_manager = ConnectionManager()
 | 
					connection_manager = ConnectionManager()
 | 
				
			||||||
task_manager = TaskManager()
 | 
					task_manager = TaskManager()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 接收客户端的websocket连接
 | 
					 | 
				
			||||||
@app.websocket("/ws/{client_id}")
 | 
					 | 
				
			||||||
async def websocket_endpoint(websocket: WebSocket, client_id: int):
 | 
					 | 
				
			||||||
    # TODO: 验证客户端的身份(使用TOKEN)
 | 
					 | 
				
			||||||
    # 获取TOKEN (从数据库中获取用户的信息)
 | 
					 | 
				
			||||||
    # token = request.headers.get('Authorization')
 | 
					 | 
				
			||||||
    # if token is None:
 | 
					 | 
				
			||||||
    #     raise HTTPException(status_code=401, detail="Unauthorized")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    await connection_manager.connect(websocket=websocket, client_id=client_id)
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        while True:
 | 
					 | 
				
			||||||
            data = await websocket.receive_text()
 | 
					 | 
				
			||||||
            await connection_manager.send_personal_message(f"You wrote: {data}", client_id=client_id)
 | 
					 | 
				
			||||||
            await connection_manager.broadcast(f"Client #{client_id} says: {data}")
 | 
					 | 
				
			||||||
            # TODO: 处理客户端的请求变化(理论上并没有)
 | 
					 | 
				
			||||||
    except WebSocketDisconnect:
 | 
					 | 
				
			||||||
        connection_manager.disconnect(client_id=client_id)
 | 
					 | 
				
			||||||
        await connection_manager.broadcast(f"Client #{client_id} left the chat")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 通知所有的ws客户端
 | 
					# 通知所有的ws客户端
 | 
				
			||||||
@app.post("/notify")
 | 
					@app.post("/notify")
 | 
				
			||||||
@@ -95,5 +92,38 @@ async def get_tasks():
 | 
				
			|||||||
async def create_task(task: Task):
 | 
					async def create_task(task: Task):
 | 
				
			||||||
    return task_manager.add(task)
 | 
					    return task_manager.add(task)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 维护一个任务队列, 任务队列中的任务会被分发给worker节点
 | 
					
 | 
				
			||||||
# 任务状态变化时通知对应的客户端
 | 
					'''
 | 
				
			||||||
 | 
					监听任务进度
 | 
				
			||||||
 | 
					可能有多个客户端监听同一个任务(向任务的观察者列表中添加websocket连接)
 | 
				
			||||||
 | 
					可能有多个任务被同一客户端监听(向客户端的观察目标列表中添加任务)
 | 
				
			||||||
 | 
					应检查目标任务是否存在
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.websocket("/tasks/{task_id}")
 | 
				
			||||||
 | 
					async def task_endpoint(websocket: WebSocket, task_id: str):
 | 
				
			||||||
 | 
					    await websocket.accept()
 | 
				
			||||||
 | 
					    if not task_manager.has_task(task_id):
 | 
				
			||||||
 | 
					        await websocket.close()
 | 
				
			||||||
 | 
					        print(f"close websocket: {task_id}")
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    task_manager.add_observer(task_id, websocket)
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            data = await websocket.receive_text()
 | 
				
			||||||
 | 
					            print(f"Client #says: {data}")
 | 
				
			||||||
 | 
					    except WebSocketDisconnect:
 | 
				
			||||||
 | 
					        task_manager.remove_observer(task_id, websocket)
 | 
				
			||||||
 | 
					        print(f"close websocket: {task_id}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					维护一个任务队列, 任务队列中的任务会被分发给worker节点
 | 
				
			||||||
 | 
					任务状态变化时通知对应的客户端
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 启动服务
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    port = 8000 if len(sys.argv) < 2 else int(sys.argv[1])
 | 
				
			||||||
 | 
					    uvicorn.run(app='main:app', host='0.0.0.0', port=port, reload=True, workers=1)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										42
									
								
								message.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								message.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,42 @@
 | 
				
			|||||||
 | 
					'''
 | 
				
			||||||
 | 
					0. 消息模型
 | 
				
			||||||
 | 
					1. 消息盒子(储存消息)
 | 
				
			||||||
 | 
					2. 消息信道(收发消息)
 | 
				
			||||||
 | 
					'''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from pydantic import BaseModel
 | 
				
			||||||
 | 
					from datetime import datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Message(BaseModel):
 | 
				
			||||||
 | 
					    id:   str                     # 消息ID(可基于此更新消息)
 | 
				
			||||||
 | 
					    type: str                     # 消息类型(可基于此分发消息)
 | 
				
			||||||
 | 
					    data: dict                    # 消息数据(用于展示的消息主体)
 | 
				
			||||||
 | 
					    date: datetime=datetime.now() # 时间(消息被创建的时间)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MessageBox:
 | 
				
			||||||
 | 
					    def __init__(self, sender, receiver, content):
 | 
				
			||||||
 | 
					        self.sender = sender
 | 
				
			||||||
 | 
					        self.receiver = receiver
 | 
				
			||||||
 | 
					        self.content = content
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def send(self):
 | 
				
			||||||
 | 
					        self.receiver.receive(self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __str__(self):
 | 
				
			||||||
 | 
					        return 'From: %s' % self.sender
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MessageChannel:
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#def watch(self, websocket: WebSocket, task: Task):
 | 
				
			||||||
 | 
					#    await websocket.accept()
 | 
				
			||||||
 | 
					#    # 将websocket连接加入到任务的观察者列表中
 | 
				
			||||||
 | 
					#    task.add_observer(websocket)
 | 
				
			||||||
 | 
					#    await websocket.send_text(f"Task {task.id} is being watched")
 | 
				
			||||||
 | 
					#    # 从websocket连接中读取消息
 | 
				
			||||||
 | 
					#    await websocket.receive_text()
 | 
				
			||||||
 | 
					#    # 将websocket连接从任务的观察者列表中移除
 | 
				
			||||||
 | 
					#    task.remove_observer(websocket)
 | 
				
			||||||
 | 
					#    await websocket.send_text(f"Task {task.id} is no longer being watched")
 | 
				
			||||||
 | 
					#    await websocket.close()
 | 
				
			||||||
		Reference in New Issue
	
	Block a user