标准Task模型
This commit is contained in:
@@ -3,6 +3,7 @@ from fastapi import WebSocket
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
# 使用字典来存储websocket连接
|
# 使用字典来存储websocket连接
|
||||||
class ConnectionManager:
|
class ConnectionManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -28,22 +29,56 @@ class ConnectionManager:
|
|||||||
for client_id, ws in self.active_connections.items():
|
for client_id, ws in self.active_connections.items():
|
||||||
await ws.send_text(message)
|
await ws.send_text(message)
|
||||||
|
|
||||||
|
'''
|
||||||
|
任务管理器(观察者模式)
|
||||||
|
使用字典来存储Task任务
|
||||||
|
'''
|
||||||
|
|
||||||
# Task 基本模型
|
# Task 基本模型
|
||||||
class Task(BaseModel):
|
class Task(BaseModel):
|
||||||
id: int
|
id: str=''
|
||||||
name: str
|
name: str=''
|
||||||
status: str
|
status: str='pending'
|
||||||
created_at: datetime
|
created_at: datetime=datetime.now()
|
||||||
updated_at: datetime
|
updated_at: datetime=datetime.now()
|
||||||
|
|
||||||
|
__observers = [] # 观察者列表
|
||||||
|
|
||||||
|
# 属性发生变化时,更新updated_at并通知观察者
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
super().__setattr__(name, value)
|
||||||
|
self.event_observer(f"Task {self.id} updated at {self.updated_at}")
|
||||||
|
|
||||||
|
# 添加观察者
|
||||||
|
def add_observer(self, websocket: WebSocket):
|
||||||
|
self.__observers.append(websocket)
|
||||||
|
|
||||||
|
# 移除观察者
|
||||||
|
def remove_observer(self, websocket: WebSocket):
|
||||||
|
self.__observers.remove(websocket)
|
||||||
|
|
||||||
|
# 通知观察者
|
||||||
|
def event_observer(self, message: str):
|
||||||
|
for observer in self.__observers:
|
||||||
|
observer.send_text(message)
|
||||||
|
|
||||||
|
|
||||||
# 使用字典来存储Task任务
|
# 使用字典来存储Task任务
|
||||||
class TaskManager(object):
|
class TaskManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.tasks = {}
|
self.tasks = {} # 任务ID的映射
|
||||||
# TOOD: 保持一个后台线程,定时检查任务的状态,并将任务的状态更新到数据中
|
# TOOD: 保持一个后台线程,定时检查任务的状态,并将任务的状态更新到数据中
|
||||||
|
|
||||||
def create(self, task: Task):
|
def add_observer(self, task_id: str, websocket: WebSocket):
|
||||||
|
self.tasks[task_id].add_observer(websocket)
|
||||||
|
|
||||||
|
def remove_observer(self, task_id: str, websocket: WebSocket):
|
||||||
|
self.tasks[task_id].remove_observer(websocket)
|
||||||
|
|
||||||
|
def has_task(self, task_id: str):
|
||||||
|
return task_id in self.tasks
|
||||||
|
|
||||||
|
def add(self, task: Task):
|
||||||
task.id = str(uuid.uuid4())
|
task.id = str(uuid.uuid4())
|
||||||
self.tasks[task.id] = task
|
self.tasks[task.id] = task
|
||||||
return task
|
return task
|
||||||
@@ -61,7 +96,7 @@ class TaskManager(object):
|
|||||||
return self.tasks
|
return self.tasks
|
||||||
|
|
||||||
# 使用字典来存储服务器池(使用腾讯云的API来管理服务器)
|
# 使用字典来存储服务器池(使用腾讯云的API来管理服务器)
|
||||||
class ServerManager(object):
|
class ServerManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.servers = {}
|
self.servers = {}
|
||||||
# 维护一个服务器池,每个服务器都有一个状态,状态有三种:空闲,运行中,异常
|
# 维护一个服务器池,每个服务器都有一个状态,状态有三种:空闲,运行中,异常
|
||||||
|
82
main.py
82
main.py
@@ -1,4 +1,5 @@
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
|
import sys
|
||||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Depends, HTTPException
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Depends, HTTPException
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from ObjectManager import ConnectionManager, TaskManager, ServerManager, Task
|
from ObjectManager import ConnectionManager, TaskManager, ServerManager, Task
|
||||||
@@ -26,10 +27,23 @@ html = """
|
|||||||
</form>
|
</form>
|
||||||
<ul id='messages'>
|
<ul id='messages'>
|
||||||
</ul>
|
</ul>
|
||||||
<script>
|
<script type="module">
|
||||||
var client_id = Date.now()
|
var task = await fetch('/tasks', {
|
||||||
document.querySelector("#ws-id").textContent = client_id;
|
method: 'POST',
|
||||||
var ws = new WebSocket(`ws://localhost:8000/ws/${client_id}`);
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
"name": "test",
|
||||||
|
"description": "test",
|
||||||
|
"status": "running",
|
||||||
|
"created_at": "2020-01-01 00:00:00",
|
||||||
|
"updated_at": "2020-01-01 00:00:00"
|
||||||
|
})
|
||||||
|
}).then(response => response.json())
|
||||||
|
console.log(task)
|
||||||
|
document.querySelector("#ws-id").textContent = task.id;
|
||||||
|
var ws = new WebSocket(`ws://localhost:8000/tasks/${task.id}`);
|
||||||
ws.onmessage = function(event) {
|
ws.onmessage = function(event) {
|
||||||
var messages = document.getElementById('messages')
|
var messages = document.getElementById('messages')
|
||||||
var message = document.createElement('li')
|
var message = document.createElement('li')
|
||||||
@@ -37,6 +51,9 @@ html = """
|
|||||||
message.appendChild(content)
|
message.appendChild(content)
|
||||||
messages.appendChild(message)
|
messages.appendChild(message)
|
||||||
};
|
};
|
||||||
|
ws.onclose = function(event) {
|
||||||
|
console.log('Socket is closed. Reconnect will be attempted in 1 second.', event.reason);
|
||||||
|
};
|
||||||
function sendMessage(event) {
|
function sendMessage(event) {
|
||||||
var input = document.getElementById("messageText")
|
var input = document.getElementById("messageText")
|
||||||
ws.send(input.value)
|
ws.send(input.value)
|
||||||
@@ -59,26 +76,6 @@ async def get():
|
|||||||
connection_manager = ConnectionManager()
|
connection_manager = ConnectionManager()
|
||||||
task_manager = TaskManager()
|
task_manager = TaskManager()
|
||||||
|
|
||||||
# 接收客户端的websocket连接
|
|
||||||
@app.websocket("/ws/{client_id}")
|
|
||||||
async def websocket_endpoint(websocket: WebSocket, client_id: int):
|
|
||||||
# TODO: 验证客户端的身份(使用TOKEN)
|
|
||||||
# 获取TOKEN (从数据库中获取用户的信息)
|
|
||||||
# token = request.headers.get('Authorization')
|
|
||||||
# if token is None:
|
|
||||||
# raise HTTPException(status_code=401, detail="Unauthorized")
|
|
||||||
|
|
||||||
await connection_manager.connect(websocket=websocket, client_id=client_id)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
data = await websocket.receive_text()
|
|
||||||
await connection_manager.send_personal_message(f"You wrote: {data}", client_id=client_id)
|
|
||||||
await connection_manager.broadcast(f"Client #{client_id} says: {data}")
|
|
||||||
# TODO: 处理客户端的请求变化(理论上并没有)
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
connection_manager.disconnect(client_id=client_id)
|
|
||||||
await connection_manager.broadcast(f"Client #{client_id} left the chat")
|
|
||||||
|
|
||||||
|
|
||||||
# 通知所有的ws客户端
|
# 通知所有的ws客户端
|
||||||
@app.post("/notify")
|
@app.post("/notify")
|
||||||
@@ -95,5 +92,38 @@ async def get_tasks():
|
|||||||
async def create_task(task: Task):
|
async def create_task(task: Task):
|
||||||
return task_manager.add(task)
|
return task_manager.add(task)
|
||||||
|
|
||||||
# 维护一个任务队列, 任务队列中的任务会被分发给worker节点
|
|
||||||
# 任务状态变化时通知对应的客户端
|
'''
|
||||||
|
监听任务进度
|
||||||
|
可能有多个客户端监听同一个任务(向任务的观察者列表中添加websocket连接)
|
||||||
|
可能有多个任务被同一客户端监听(向客户端的观察目标列表中添加任务)
|
||||||
|
应检查目标任务是否存在
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/tasks/{task_id}")
|
||||||
|
async def task_endpoint(websocket: WebSocket, task_id: str):
|
||||||
|
await websocket.accept()
|
||||||
|
if not task_manager.has_task(task_id):
|
||||||
|
await websocket.close()
|
||||||
|
print(f"close websocket: {task_id}")
|
||||||
|
return
|
||||||
|
task_manager.add_observer(task_id, websocket)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
data = await websocket.receive_text()
|
||||||
|
print(f"Client #says: {data}")
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
task_manager.remove_observer(task_id, websocket)
|
||||||
|
print(f"close websocket: {task_id}")
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
维护一个任务队列, 任务队列中的任务会被分发给worker节点
|
||||||
|
任务状态变化时通知对应的客户端
|
||||||
|
'''
|
||||||
|
|
||||||
|
# 启动服务
|
||||||
|
if __name__ == '__main__':
|
||||||
|
port = 8000 if len(sys.argv) < 2 else int(sys.argv[1])
|
||||||
|
uvicorn.run(app='main:app', host='0.0.0.0', port=port, reload=True, workers=1)
|
||||||
|
42
message.py
Normal file
42
message.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
'''
|
||||||
|
0. 消息模型
|
||||||
|
1. 消息盒子(储存消息)
|
||||||
|
2. 消息信道(收发消息)
|
||||||
|
'''
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
id: str # 消息ID(可基于此更新消息)
|
||||||
|
type: str # 消息类型(可基于此分发消息)
|
||||||
|
data: dict # 消息数据(用于展示的消息主体)
|
||||||
|
date: datetime=datetime.now() # 时间(消息被创建的时间)
|
||||||
|
|
||||||
|
class MessageBox:
|
||||||
|
def __init__(self, sender, receiver, content):
|
||||||
|
self.sender = sender
|
||||||
|
self.receiver = receiver
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
def send(self):
|
||||||
|
self.receiver.receive(self)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return 'From: %s' % self.sender
|
||||||
|
|
||||||
|
class MessageChannel:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
#def watch(self, websocket: WebSocket, task: Task):
|
||||||
|
# await websocket.accept()
|
||||||
|
# # 将websocket连接加入到任务的观察者列表中
|
||||||
|
# task.add_observer(websocket)
|
||||||
|
# await websocket.send_text(f"Task {task.id} is being watched")
|
||||||
|
# # 从websocket连接中读取消息
|
||||||
|
# await websocket.receive_text()
|
||||||
|
# # 将websocket连接从任务的观察者列表中移除
|
||||||
|
# task.remove_observer(websocket)
|
||||||
|
# await websocket.send_text(f"Task {task.id} is no longer being watched")
|
||||||
|
# await websocket.close()
|
Reference in New Issue
Block a user