diff --git a/src/app/chat/chat.py b/src/app/chat/chat.py index 97d3287..a6d4b5c 100644 --- a/src/app/chat/chat.py +++ b/src/app/chat/chat.py @@ -1,32 +1,43 @@ import asyncio from fastapi import WebSocket +import asyncio class Chat: members: dict[str, WebSocket] """Represents the chatting room.""" + def __init__(self): self.members = {} - + self.illegal_ids = ["system"] + async def join(self, user: str, websocket: WebSocket): """Connect to the user and add user to the members on success.""" - ######### [ TODO ] ######### # TODO Establish connection - ############################ + await websocket.accept() + close_reason = "" ######### [ TODO ] ######### # TODO Close connection on duplicate and exit method. ############################ - + if user in self.members.keys(): + close_reason = "User {} already exists".format(user) + await websocket.close(reason=close_reason) + elif user in self.illegal_ids: + close_reason = "Cannot use Username {}".format(user) + await websocket.close(reason=close_reason) ######### [ TODO ] ######### # TODO Otherwise, add user to the members. ############################ - - ######### [ TODO ] ######### - # TODO Optionally, broadcast the system message that the user joined the chat. - ############################ + else: + self.members[user] = websocket + json = { + "from": "system", + "msg": "User {} joined the chat".format(user) + } + await self.broadcast(json) async def leave(self, user: str): """Remove user from the members.""" @@ -34,10 +45,16 @@ async def leave(self, user: str): ######### [ TODO ] ######### # TODO Remove user from the members. ############################ - + del self.members[user] + ######### [ TODO ] ######### # TODO Optionally, broadcast the system message that the user left the chat. ############################ + json = { + "from": "system", + "msg": "User {} left the chat".format(user) + } + await self.broadcast(json) async def handle_message(self, user: str, message: dict[str, str]): """Handler message from user.""" @@ -45,6 +62,34 @@ async def handle_message(self, user: str, message: dict[str, str]): ######### [ TODO ] ######### # TODO Check the message type and receiver. ############################ + msg_type = message["type"] + msg = message["msg"] + if msg_type == "direct": + dst_user = message["to"] + try: + dst_socket = self.members[dst_user] + await dst_socket.send_json( + { + "from": dst_user, + "msg": msg + } + ) + except KeyError: + src_socket = self.members[user] + await src_socket.send_json( + { + "from": "system", + "msg": "User {} does not exist.".format(dst_user) + } + ) + + else: + await self.broadcast( + { + "from":user, + "msg":msg + }, 1, user + ) ######### [ TODO ] ######### # TODO If the valid receiver does not exist, send error message back to the sender. @@ -54,3 +99,14 @@ async def handle_message(self, user: str, message: dict[str, str]): # TODO Otherwise, forward the message to the receiver. # If broadcasting the message, DO NOT send to the original sender. ############################ + + async def broadcast(self, msg: dict, type = 0, sender = None): + tasks = [] + for user, socket in self.members.items(): + if type == 1 and sender == user: + continue + task = asyncio.create_task(socket.send_json(msg)) + tasks.append(task) + + for task in tasks: + await task diff --git a/src/app/chat/router.py b/src/app/chat/router.py index 23ecb2b..28dbc68 100644 --- a/src/app/chat/router.py +++ b/src/app/chat/router.py @@ -19,19 +19,14 @@ async def get() -> HTMLResponse: @router.websocket("/ws") async def websocket(user: str, websocket: WebSocket): """Interact with user via connection.""" - - ######### [ TODO ] ######### # TODO Join the chat - ############################ + await chat.join(user, websocket) try: while True: - ######### [ TODO ] ######### # TODO Receive message from the user and handle - ############################ - pass + data = await websocket.receive_json() + await chat.handle_message(user, data) except WebSocketDisconnect: - ######### [ TODO ] ######### # TODO Handle disconnection: leave the chat - ############################ - pass + await chat.leave(user) diff --git a/src/app/main.py b/src/app/main.py index caf3226..eafc889 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -5,3 +5,8 @@ app = FastAPI() app.include_router(chat_router) + + +@app.get("/") +def root(): + return {"messgae": "Hello, world!"} \ No newline at end of file