diff --git a/Cargo.lock b/Cargo.lock index d92d34c..9fe4d46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -613,7 +613,7 @@ dependencies = [ [[package]] name = "glados" -version = "0.1.0" +version = "0.1.1" dependencies = [ "dotenv", "influx_db_client", diff --git a/Cargo.toml b/Cargo.toml index 2ac2a2a..75a1a68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "glados" -version = "0.1.0" +version = "0.1.1" edition = "2021" license = "MIT" diff --git a/src/discord.rs b/src/discord.rs index 22c00dc..ff6db1f 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -6,6 +6,9 @@ //! - `~invite Random#1234` (not possible) //! - Invite someone into the guild //! +//! - `~filter Active - Linked` +//! - Filter members based on assigned roles +//! //! - Member_add //! - Notify guild about new member //! diff --git a/src/discord/framework.rs b/src/discord/framework.rs index 77d43ef..56a1c43 100644 --- a/src/discord/framework.rs +++ b/src/discord/framework.rs @@ -13,6 +13,10 @@ use serenity::{ model::{channel::Message, id::UserId}, }; +mod filter; + +use filter::FILTER_COMMAND; + use crate::conf::ARGS; pub fn init() -> StandardFramework { @@ -30,7 +34,7 @@ pub fn init() -> StandardFramework { #[group] #[owners_only] -#[commands(ping, invite)] +#[commands(ping, invite, filter)] struct Owner; #[group] diff --git a/src/discord/framework/filter.rs b/src/discord/framework/filter.rs new file mode 100644 index 0000000..4747c76 --- /dev/null +++ b/src/discord/framework/filter.rs @@ -0,0 +1,227 @@ +use std::collections::{HashMap, HashSet}; + +use nom::{ + branch::alt, + character::complete::{char, multispace0, none_of}, + combinator::{all_consuming, map, opt, recognize, value}, + multi::{many0, many1}, + sequence::tuple, + IResult, +}; +use serenity::{ + client::Context, + framework::standard::{macros::command, Args, CommandResult}, + futures::StreamExt, + model::{channel::Message, guild::Member, id::RoleId}, +}; + +use crate::{ + conf::ARGS, + error::{Error, Result, ResultExt}, +}; + +#[command] +#[description = "Filter members based on roles"] +pub async fn filter(ctx: &Context, msg: &Message, args: Args) -> CommandResult { + // Assemble filter + let filter = match MemberFilter::parse(ctx, args.rest()).await { + Ok(filter) => filter, + Err(why) => { + msg.reply(ctx, "I don't understand your input").await?; + return Err(Box::from(why)); + } + }; + // Fetch guild members + let mut members = ARGS.guild_id().members_iter(ctx).boxed(); + // Member that match our filter + let mut matched = vec![]; + // Filter members + while let Some(member) = members.next().await { + // Notify about member fetching errors + let member = match member.log_warn("fetching guild members") { + Some(member) => member, + None => continue, + }; + // Apply Filter + if filter.matches(&member).await { + matched.push(member.user.id); + }; + } + // Assemble reply + let reply = matched + .iter() + .fold(String::from("Filtered members:"), |s, id| { + s + &format!(" {}", id) + }); + msg.reply(ctx, reply).await?; + Ok(()) +} + +#[derive(Debug, Default, Clone)] +struct MemberFilter { + add: HashSet, + sub: HashSet, +} + +#[derive(Debug, Default, Clone, PartialEq)] +struct RawMemberFilter<'i> { + add: Vec<&'i str>, + sub: Vec<&'i str>, +} + +#[derive(Debug, Clone, PartialEq)] +enum Sign { + Plus, + Minus, +} + +impl MemberFilter { + pub async fn parse(ctx: &Context, raw: &str) -> Result { + let raw_filter = all_consuming(parse_member_filter)(raw) + .map(|(_, filter)| filter) + .map_err(|why| Error::ParseMemberFilter(why.to_owned()))?; + Self::from_raw(ctx, raw_filter).await + } + + /// Does this filter match the given member? + /// + /// # When is this true? + /// - If `add` is not empty: + /// - `member` has a role in `add` and none in `sub` + /// - If `add` is empty but `sub` is not: + /// - `member` has no role in `sub` + /// - If `add` and `sub` are empty: + /// - YES! + pub async fn matches(&self, member: &Member) -> bool { + let roles: HashSet<_> = member.roles.clone().into_iter().collect(); + if !self.add.is_empty() { + let add_intersection: HashSet<_> = self.add.intersection(&roles).collect(); + let sub_intersection: HashSet<_> = self.sub.intersection(&roles).collect(); + !add_intersection.is_empty() && sub_intersection.is_empty() + } else if self.add.is_empty() && !self.sub.is_empty() { + let sub_intersection: HashSet<_> = self.sub.intersection(&roles).collect(); + sub_intersection.is_empty() + } else { + true + } + } + + async fn from_raw(ctx: &Context, raw_filter: RawMemberFilter<'_>) -> Result { + let RawMemberFilter { add, sub } = raw_filter; + // Map role names to role ids + let roles: HashMap<_, _> = ARGS + .guild_id() + .roles(ctx) + .await? + .into_iter() + .map(|(id, role)| (role.name, id)) + .collect(); + // Helper to map a role name to a role id, if that role exists + let to_id = |name: &str| -> Result { + let name = name.trim(); + match roles.get(name) { + Some(id) => Ok(*id), + None => Err(Error::UnknownRoleInFilter(name.to_owned())), + } + }; + // Helper to collect a Result> from results + let set_from_res = |mut set: HashSet<_>, id: Result<_>| -> Result<_> { + set.insert(id?); + Ok(set) + }; + Ok(MemberFilter { + add: add + .into_iter() + .map(to_id) + .try_fold(HashSet::new(), set_from_res)?, + sub: sub + .into_iter() + .map(to_id) + .try_fold(HashSet::new(), set_from_res)?, + }) + } +} + +fn parse_member_filter(inp: &str) -> IResult<&str, RawMemberFilter<'_>> { + map(many0(parse_role), |roles| { + let mut filter = RawMemberFilter::default(); + for (sign, role) in roles { + match sign { + Sign::Plus => filter.add.push(role), + Sign::Minus => filter.sub.push(role), + } + } + filter + })(inp.trim()) +} + +fn parse_role(inp: &str) -> IResult<&str, (Sign, &str)> { + let parse_sign = alt(( + value(Sign::Minus, char('-')), + value(Sign::Plus, opt(char('+'))), + )); + // TODO: No whitespaces/plus/minus allowed but Discord roles allow almost any char + let parse_role_name = recognize(many1(none_of(" +-\t\n\r"))); + map( + tuple((parse_sign, multispace0, parse_role_name, multispace0)), + |(sign, _, role, _)| (sign, role), + )(inp) +} + +#[cfg(test)] +mod tests { + use nom::combinator::all_consuming; + + use crate::discord::framework::filter::{RawMemberFilter, Sign}; + + macro_rules! parse { + ($fn:path, $str:literal, $ok_val:expr) => { + let res = all_consuming($fn)($str); + assert_eq!(res.unwrap().1, $ok_val) + }; + } + + #[test] + fn parse_role() { + use super::parse_role; + parse!(parse_role, "test", (Sign::Plus, "test")); + parse!(parse_role, "-test", (Sign::Minus, "test")); + } + + #[test] + fn parse_member_filter() { + use super::parse_member_filter; + parse!( + parse_member_filter, + "test", + RawMemberFilter { + add: vec!["test"], + sub: vec![], + } + ); + parse!( + parse_member_filter, + "+other", + RawMemberFilter { + add: vec!["other"], + sub: vec![], + } + ); + parse!( + parse_member_filter, + "-test -9+some_other\t +dis", + RawMemberFilter { + add: vec!["some_other", "dis"], + sub: vec!["test", "9"], + } + ); + parse!( + parse_member_filter, + "\t -test \t\n+other", + RawMemberFilter { + add: vec!["other"], + sub: vec!["test"], + } + ); + } +} diff --git a/src/error.rs b/src/error.rs index 3938c2b..58ea4cd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -24,6 +24,10 @@ pub enum Error { RconListCmd(#[source] nom::Err>), #[error("Parsing results of `whois` command: {_0}")] RconWhoIsCmd(#[source] nom::Err>), + #[error("Parsing member filter of `~filter` command: {_0}")] + ParseMemberFilter(#[source] nom::Err>), + #[error("Unknown role {_0:?} in filter")] + UnknownRoleInFilter(String), } pub trait ResultExt { @@ -66,3 +70,23 @@ where } } } + +impl ResultExt for E +where + E: fmt::Display, +{ + fn log_info(self, when: &str) -> Option { + tracing::info!("Error occured while {}: {}", when, self); + None + } + + fn log_warn(self, when: &str) -> Option { + tracing::warn!("Error occured while {}: {}", when, self); + None + } + + fn log_error(self, when: &str) -> Option { + tracing::warn!("Error occured while {}: {}", when, self); + None + } +}