diff --git a/src-tauri/src/commands/chat.rs b/src-tauri/src/commands/chat.rs index d8668a3..b98d204 100644 --- a/src-tauri/src/commands/chat.rs +++ b/src-tauri/src/commands/chat.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::{Arc, Mutex}; use futures_util::StreamExt; use surrealdb::types::{RecordId, RecordIdKey}; @@ -13,6 +14,25 @@ use crate::AppState; const DEFAULT_PAGE_SIZE: i64 = 50; const MAX_PAGE_SIZE: i64 = 100; const MAX_MESSAGE_LEN: usize = 4000; +const MAX_CACHED_ROOMS: usize = 5; + +fn cache_put( + cache: &Arc>>>, + order: &Arc>>, + room_id: &str, + messages: Vec, +) { + let mut c = cache.lock().unwrap(); + let mut o = order.lock().unwrap(); + c.insert(room_id.to_string(), messages); + o.retain(|id| id != room_id); + o.insert(0, room_id.to_string()); + while o.len() > MAX_CACHED_ROOMS { + if let Some(evicted) = o.pop() { + c.remove(&evicted); + } + } +} const MAX_ROOM_NAME_LEN: usize = 80; /// Wrapper emitted to the frontend for each LIVE query notification. @@ -427,7 +447,24 @@ pub async fn send_message( .ok_or_else(|| into_err(AppError::NotFound("message after create".into()))) } +/// Return cached messages for a room without hitting the remote DB. +/// Returns an empty vec if the room has not been cached yet. +#[tauri::command] +pub async fn get_cached_messages( + state: State<'_, AppState>, + room_id: String, +) -> Result, String> { + Ok(state + .msg_cache + .lock() + .unwrap() + .get(&room_id) + .cloned() + .unwrap_or_default()) +} + /// Fetch a bounded page of messages in a room, oldest first. +/// Also updates the in-process message cache. #[tauri::command] pub async fn get_messages( state: State<'_, AppState>, @@ -451,11 +488,11 @@ pub async fn get_messages( let mut response = state .db .query(query) - .bind(("room_id", room_id)) + .bind(("room_id", room_id.clone())) .bind(("limit", limit)); - if let Some(before) = before { - response = response.bind(("before", before)); + if let Some(ref before) = before { + response = response.bind(("before", before.clone())); } let mut result: Vec = response @@ -467,6 +504,18 @@ pub async fn get_messages( result.reverse(); let user = current_user(&state).await?; hydrate_reactions(&state, &user, &mut result).await?; + + if before.is_none() { + cache_put(&state.msg_cache, &state.cache_order, &room_id, result.clone()); + } else { + let mut c = state.msg_cache.lock().unwrap(); + if let Some(existing) = c.get_mut(&room_id) { + let mut merged = result.clone(); + merged.extend_from_slice(existing); + *existing = merged; + } + } + Ok(result) } @@ -563,7 +612,8 @@ pub async fn mark_room_read(state: State<'_, AppState>, room_id: String) -> Resu } /// Start a LIVE query for new messages in a room. -/// Spawns a background tokio task that emits "chat:message" Tauri events. +/// Spawns a background tokio task that emits "chat:message" Tauri events +/// and keeps the in-process message cache in sync. /// /// Returns a local subscription UUID — pass it to `unsubscribe_room` on cleanup. /// Aborting the JoinHandle drops the stream, which closes the LIVE query automatically. @@ -574,6 +624,9 @@ pub async fn subscribe_room( room_id: String, ) -> Result { let db = state.db.clone(); + let msg_cache = Arc::clone(&state.msg_cache); + let cache_order = Arc::clone(&state.cache_order); + let room_id_cache = room_id.clone(); let mut stream = db .query("LIVE SELECT * FROM message WHERE room = type::record('room', $room_id)") @@ -587,10 +640,39 @@ pub async fn subscribe_room( let handle = tokio::spawn(async move { while let Some(Ok(notification)) = stream.next().await { + let action = format!("{:?}", notification.action); + let data = notification.data.clone(); + + { + let mut c = msg_cache.lock().unwrap(); + let mut o = cache_order.lock().unwrap(); + if let Some(msgs) = c.get_mut(&room_id_cache) { + match action.as_str() { + "Create" => msgs.push(data.clone()), + "Update" => { + if let Some(m) = msgs.iter_mut().find(|m| m.id == data.id) { + *m = data.clone(); + } + } + "Delete" => msgs.retain(|m| m.id != data.id), + _ => {} + } + } else if action == "Create" { + c.insert(room_id_cache.clone(), vec![data.clone()]); + o.retain(|id| id != &room_id_cache); + o.insert(0, room_id_cache.clone()); + while o.len() > MAX_CACHED_ROOMS { + if let Some(evicted) = o.pop() { + c.remove(&evicted); + } + } + } + } + let _ = app_handle.emit( "chat:message", &LiveMessageEvent { - action: format!("{:?}", notification.action), + action, data: ¬ification.data, }, ); diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 9ee136f..cd0b4a2 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -13,9 +13,15 @@ mod error; mod models; use db::{init_db, SURREAL_DB, SURREAL_NS, SURREAL_URL}; +use models::Message; pub struct AppState { pub db: Arc>, + /// In-process message cache keyed by room_id string. Arc so the live-event + /// task in subscribe_room can hold a reference without borrowing AppState. + pub msg_cache: Arc>>>, + /// LRU order of cached room IDs (front = most recent). Evicts beyond 5. + pub cache_order: Arc>>, /// std::sync::Mutex is intentional: guards are never held across .await points. pub subscriptions: Mutex>>, } @@ -38,6 +44,8 @@ pub fn run() { let state = AppState { db: Arc::new(surreal), + msg_cache: Arc::new(Mutex::new(HashMap::new())), + cache_order: Arc::new(Mutex::new(Vec::new())), subscriptions: Mutex::new(HashMap::new()), }; @@ -61,6 +69,7 @@ pub fn run() { commands::chat::get_or_create_direct_room, commands::chat::send_message, commands::chat::get_messages, + commands::chat::get_cached_messages, commands::chat::delete_message, commands::chat::edit_message, commands::chat::toggle_reaction, diff --git a/src/routes/+page.svelte b/src/routes/+page.svelte index cd6f887..01a522f 100644 --- a/src/routes/+page.svelte +++ b/src/routes/+page.svelte @@ -136,11 +136,19 @@ activeRoom = room; replyTo = null; - messages = await cmd("get_messages", { + + const cached = await cmd("get_cached_messages", { roomId: sid(room.id) }); + if (cached.length > 0) { + messages = cached; + hasOlderMessages = false; + } + + const fresh = await cmd("get_messages", { roomId: sid(room.id), limit: 50, }); - hasOlderMessages = messages.length === 50; + messages = fresh; + hasOlderMessages = fresh.length === 50; unreadCounts = { ...unreadCounts, [sid(room.id)]: 0 }; await cmd("mark_room_read", { roomId: sid(room.id) }).catch(() => {});