From ab25fa026708084e07007de160e27931f9d7c38e Mon Sep 17 00:00:00 2001 From: Dane Johnson Date: Wed, 12 Oct 2022 10:08:23 -0500 Subject: [PATCH] Run server in tokio thread, use tokio sync objects --- src/main.rs | 104 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 90 insertions(+), 14 deletions(-) diff --git a/src/main.rs b/src/main.rs index f892a57..c235a3e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,18 @@ use std::{ - sync::{Arc, Mutex}, + sync::Arc, io::Error as IoError, collections::HashMap, }; -use tokio::net::{TcpListener}; +use tokio::{ + net::{TcpListener}, + sync::{mpsc, Mutex}, +}; + +use futures::{ + select, + FutureExt, +}; mod message; use message::{Message, MessageWebSocket}; @@ -14,11 +22,13 @@ use code_generator::CodeGenerator; #[tokio::main] async fn main() -> Result<(), IoError> { let global_state = Arc::new(GlobalState::default()); - + + // Bind socket let socket = TcpListener::bind("127.0.0.1:8080").await; let listener = socket.expect("Could not bind to localhost:8080"); println!("Server running"); + // Accept all incoming connections while let Ok((stream, addr)) = listener.accept().await { let global_state = Arc::clone(&global_state); let mut local_state = LocalState::default(); @@ -27,18 +37,30 @@ async fn main() -> Result<(), IoError> { let mut ws = MessageWebSocket(tokio_tungstenite::accept_async(stream) .await .expect("Could not establish connection")); - println!("Connected to {}", addr); - while let Ok(msg) = ws.next().await { - dispatch(msg, &global_state, &mut local_state, &mut ws).await; + loop { + match &mut local_state.channel { + Some(channel) => { + select! { + msg = channel.rx.recv().fuse() => + handle_server_msg(msg.unwrap(), &mut local_state, &mut ws).await, + msg = ws.next().fuse() => + handle_client_msg(msg.unwrap(), &global_state, &mut local_state, &mut ws).await, + }; + } + None => { + let msg = ws.next().await; + handle_client_msg(msg.unwrap(), &global_state, &mut local_state, &mut ws).await; + } + } } }); } Ok(()) } -async fn dispatch( +async fn handle_client_msg( msg: Message, global_state: &GlobalState, local_state: &mut LocalState, @@ -46,27 +68,81 @@ async fn dispatch( ) { match msg.command.as_str() { "HOST" => { - let room_code = global_state.code_generator.lock().unwrap().generate(); - let game_controller = GameController::default(); - global_state.rooms.lock().unwrap().insert(room_code, game_controller); + let room_code = global_state.code_generator.lock().await.generate(); + let mut game_controller = GameController::default(); + + let [client_channel, server_channel] = channel_pair(); + local_state.channel = Some(client_channel); + game_controller.channels.push(server_channel); + + let game_controller = Arc::new(Mutex::new(game_controller)); + global_state.rooms.lock().await.insert(room_code.clone(), Arc::clone(&game_controller)); + tokio::spawn(async move {game_loop(game_controller)}); + ws.send(msg!("ROOM_CODE", room_code)).await.unwrap(); } "JOIN" => { let room_code = &msg.args[0]; - let room = global_state.rooms.lock().unwrap().get(room_code); + let rooms = global_state.rooms.lock().await; + let room = rooms.get(room_code); + + match room { + Some(room) => { + let mut room = room.lock().await; + let [client_channel, server_channel] = channel_pair(); + local_state.channel = Some(client_channel); + room.channels.push(server_channel); + ws.send(msg!("JOIN_OK")).await.unwrap(); + } + None => { + ws.send(msg!("JOIN_INVALID")).await.unwrap(); + } + } + } + _ => if let Some(channel) = &local_state.channel { + // Forward message to the server + channel.tx.send(msg).await; } - _ => unimplemented!(), } } +async fn handle_server_msg( + msg: Message, + local_state: &mut LocalState, + ws: &mut MessageWebSocket +) { + todo!(); +} + +fn game_loop(game_controller: Arc>) { + todo!(); +} + #[derive(Default)] struct GlobalState { code_generator: Arc>, - rooms: Arc>>, + rooms: Arc>>>>, } #[derive(Default)] -struct GameController {} // TODO +struct GameController { + channels: Vec, +} #[derive(Default)] struct LocalState { + channel: Option, +} + +struct Channel { + tx: mpsc::Sender, + rx: mpsc::Receiver, +} + +fn channel_pair() -> [Channel; 2] { + let (atx, brx) = mpsc::channel(32); + let (btx, arx) = mpsc::channel(32); + [ + Channel { tx: atx, rx: arx }, + Channel { tx: btx, rx: brx }, + ] }