From 397b9878e1c942f9ee9e9afa81c57563900adeba Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sun, 6 Oct 2024 13:22:13 -0700 Subject: [PATCH] fix: temp fix for spam attack via notification bug abuse, but git actually commits the changes this time --- .../database/src/models/messages/model.rs | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/crates/core/database/src/models/messages/model.rs b/crates/core/database/src/models/messages/model.rs index 29710ed31..e813fa94f 100644 --- a/crates/core/database/src/models/messages/model.rs +++ b/crates/core/database/src/models/messages/model.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::{collections::HashSet, hash::RandomState}; use indexmap::{IndexMap, IndexSet}; use iso8601_timestamp::Timestamp; @@ -305,6 +305,18 @@ impl Message { ..Default::default() }; + // Parse mentions in message. + let mut mentions = HashSet::new(); + if allow_mentions { + if let Some(content) = &data.content { + for capture in RE_MENTION.captures_iter(content) { + if let Some(mention) = capture.get(1) { + mentions.insert(mention.as_str().to_string()); + } + } + } + } + // Verify replies are valid. let mut replies = HashSet::new(); if let Some(entries) = data.replies { @@ -325,29 +337,27 @@ impl Message { } } - // Parse mentions in message. - let mut mentions = HashSet::new(); - if allow_mentions { - if let Some(content) = &data.content { - for capture in RE_MENTION.captures_iter(content) { - if let Some(mention) = capture.get(1) { - mentions.insert(mention.as_str().to_string()); - } - } - } - } - if !mentions.is_empty() { // FIXME: temp fix to stop spam attacks match channel { - Channel::DirectMessage { recipients, .. } | Channel::Group { recipients, .. } => { - mentions = mentions.intersection(recipients); + Channel::DirectMessage { ref recipients, .. } + | Channel::Group { ref recipients, .. } => { + let recipients_hash: HashSet<&String, RandomState> = + HashSet::from_iter(recipients.iter()); + + mentions.retain(|m| recipients_hash.contains(m)); } - Channel::TextChannel { server, .. } | Channel::VoiceChannel { server, .. } => { - let valid_members = db.fetch_members(server.into(), mentions).await; + Channel::TextChannel { ref server, .. } + | Channel::VoiceChannel { ref server, .. } => { + let mentions_vec = Vec::from_iter(mentions.iter().cloned()); + let valid_members = db.fetch_members(server.as_str(), &mentions_vec[..]).await; if let Ok(valid_members) = valid_members { - let valid_ids = valid_members.iter().map(|member| member.id.user); - mentions = mentions.intersection(valid_ids); + let valid_ids: HashSet = HashSet::from_iter( + valid_members.iter().map(|member| member.id.user.clone()), + ); + mentions.retain(|m| valid_ids.contains(m)); + } else { + revolt_config::capture_error(&valid_members.unwrap_err()); } } Channel::SavedMessages { .. } => mentions.clear(),