Skip to content
Snippets Groups Projects
Verified Commit 7155cbb8 authored by Nicolas Werner's avatar Nicolas Werner
Browse files

Implement MSC3664, pushrules for related events

parent c7a13e79
No related branches found
No related tags found
No related merge requests found
Pipeline #3955 passed
......@@ -14,6 +14,7 @@
#include <variant>
#include <vector>
#include "mtx/events/common.hpp"
#include "mtx/events/power_levels.hpp"
namespace mtx {
......@@ -47,6 +48,12 @@ struct PushCondition
//! defaults to ==.
std::string is;
//! The relation type to match on. Only valid for `im.nheko.msc3664.related_event_match`
//! conditions.
mtx::common::RelationType rel_type = mtx::common::RelationType::Unsupported;
//! Wether to match fallback relations or not.
bool include_fallback = false;
friend void to_json(nlohmann::json &obj, const PushCondition &condition);
friend void from_json(const nlohmann::json &obj, PushCondition &condition);
};
......@@ -200,10 +207,14 @@ public:
//! Evaluate the pushrules for @event .
///
/// You need to have the room_id set for the event.
/// `relatedEvents` is a mapping of rel_type to event. Pass all the events that are related to
/// by this event here.
/// \returns the actions to apply.
[[nodiscard]] std::vector<actions::Action> evaluate(
const mtx::events::collections::TimelineEvent &event,
const RoomContext &ctx) const;
const RoomContext &ctx,
const std::vector<std::pair<mtx::common::Relation, mtx::events::collections::TimelineEvent>>
&relatedEvents) const;
private:
struct OptimizedRules;
......
......@@ -204,7 +204,8 @@ from_json(const json &obj, RelationType &type)
type = RelationType::Reference;
else if (obj.get<std::string>() == "m.replace")
type = RelationType::Replace;
else if (obj.get<std::string>() == "im.nheko.relations.v1.in_reply_to")
else if (obj.get<std::string>() == "im.nheko.relations.v1.in_reply_to" ||
obj.get<std::string>() == "m.in_reply_to")
type = RelationType::InReplyTo;
else if (obj.get<std::string>() == "m.thread")
type = RelationType::Thread;
......
......@@ -8,6 +8,15 @@
#include "mtx/events/collections.hpp"
#include "mtx/log.hpp"
namespace {
struct RelatedEvents
{
std::vector<std::unordered_map<std::string, std::string>>
fallbacks; //!< fallback related events
std::vector<std::unordered_map<std::string, std::string>> events; //!< related events
};
}
namespace mtx {
namespace pushrules {
......@@ -21,15 +30,19 @@ to_json(nlohmann::json &obj, const PushCondition &condition)
obj["pattern"] = condition.pattern;
if (!condition.is.empty())
obj["is"] = condition.is;
if (condition.rel_type != mtx::common::RelationType::Unsupported)
obj["rel_type"] = condition.rel_type;
}
void
from_json(const nlohmann::json &obj, PushCondition &condition)
{
condition.kind = obj["kind"].get<std::string>();
condition.key = obj.value("key", "");
condition.pattern = obj.value("pattern", "");
condition.is = obj.value("is", "");
condition.kind = obj["kind"].get<std::string>();
condition.key = obj.value("key", "");
condition.pattern = obj.value("pattern", "");
condition.is = obj.value("is", "");
condition.rel_type = obj.value("rel_type", mtx::common::RelationType::Unsupported);
condition.include_fallback = obj.value("include_fallback", false);
}
namespace actions {
......@@ -180,11 +193,40 @@ struct PushRuleEvaluator::OptimizedRules
//! a pattern condition to match
struct PatternCondition
{
std::unique_ptr<re2::RE2> pattern; //< the pattern
std::string field; //< the field to match with pattern
std::unique_ptr<re2::RE2> pattern; //!< the pattern
std::string field; //!< the field to match with pattern
bool matches(const std::unordered_map<std::string, std::string> &ev) const
{
if (auto it = ev.find(field); it != ev.end()) {
if (pattern) {
if (field == "content.body") {
if (!re2::RE2::PartialMatch(it->second, *pattern))
return false;
} else {
if (!re2::RE2::FullMatch(it->second, *pattern))
return false;
}
}
} else {
return false;
}
return true;
}
};
// TODO(Nico): Sort by field for faster matching?
std::vector<PatternCondition> patterns; //< conditions that match on a field
std::vector<PatternCondition> patterns; //!< conditions that match on a field
//! a pattern condition to match on a related event
struct RelatedEventCondition
{
PatternCondition ev_match;
mtx::common::RelationType rel_type = mtx::common::RelationType::Unsupported;
bool include_fallbacks = false;
};
std::vector<RelatedEventCondition>
related_event_patterns; //!< conditions that match on fields of the related event.
//! a member count condition
struct MemberCountCondition
......@@ -212,8 +254,10 @@ struct PushRuleEvaluator::OptimizedRules
std::vector<actions::Action> actions; //< the actions to apply on match
[[nodiscard]] bool matches(const std::unordered_map<std::string, std::string> &ev,
const PushRuleEvaluator::RoomContext &ctx) const
[[nodiscard]] bool matches(
const std::unordered_map<std::string, std::string> &ev,
const PushRuleEvaluator::RoomContext &ctx,
const std::map<mtx::common::RelationType, RelatedEvents> &relatedEventsFlat) const
{
for (const auto &cond : membercounts) {
if (![&cond, &ctx] {
......@@ -249,19 +293,34 @@ struct PushRuleEvaluator::OptimizedRules
}
for (const auto &cond : patterns) {
if (auto it = ev.find(cond.field); it != ev.end()) {
if (cond.pattern) {
if (cond.field == "content.body") {
if (!re2::RE2::PartialMatch(it->second, *cond.pattern))
return false;
} else {
if (!re2::RE2::FullMatch(it->second, *cond.pattern))
return false;
if (!cond.matches(ev))
return false;
}
for (const auto &cond : related_event_patterns) {
bool matched = false;
for (const auto &[rel_type, rel_ev] : relatedEventsFlat) {
if (cond.rel_type == rel_type) {
for (const auto &e : rel_ev.events) {
if (cond.ev_match.field.empty() || !cond.ev_match.pattern ||
cond.ev_match.matches(e)) {
matched = true;
break;
}
}
if (cond.include_fallbacks) {
for (const auto &e : rel_ev.fallbacks) {
if (cond.ev_match.field.empty() || !cond.ev_match.pattern ||
cond.ev_match.matches(e)) {
matched = true;
break;
}
}
}
}
} else {
return false;
}
if (!matched)
return false;
}
if (check_displayname) {
......@@ -325,6 +384,23 @@ PushRuleEvaluator::PushRuleEvaluator(const Ruleset &rules_)
c.pattern = construct_re_from_pattern(cond.pattern, cond.key);
if (c.pattern)
rule.patterns.push_back(std::move(c));
} else if (cond.kind == "im.nheko.msc3664.related_event_match") {
OptimizedRules::OptimizedRule::RelatedEventCondition c;
if (cond.rel_type != mtx::common::RelationType::Unsupported) {
c.rel_type = cond.rel_type;
c.include_fallbacks = cond.include_fallback;
if (!cond.key.empty() && !cond.pattern.empty()) {
c.ev_match.field = cond.key;
c.ev_match.pattern = construct_re_from_pattern(cond.pattern, cond.key);
}
rule.related_event_patterns.push_back(std::move(c));
} else {
mtx::utils::log::log()->info(
"Skipping rel_event_match rule with unknown rel_type.");
return false;
}
} else if (cond.kind == "contains_display_name") {
rule.check_displayname = true;
} else if (cond.kind == "room_member_count") {
......@@ -479,19 +555,33 @@ flatten_event(const nlohmann::json &j)
}
std::vector<actions::Action>
PushRuleEvaluator::evaluate(const mtx::events::collections::TimelineEvent &event,
const RoomContext &ctx) const
PushRuleEvaluator::evaluate(
const mtx::events::collections::TimelineEvent &event,
const RoomContext &ctx,
const std::vector<std::pair<mtx::common::Relation, mtx::events::collections::TimelineEvent>>
&relatedEvents) const
{
auto event_json = nlohmann::json(event);
auto flat_event = flatten_event(event_json);
std::map<mtx::common::RelationType, RelatedEvents> relatedEventsFlat;
for (const auto &[rel, ev] : relatedEvents) {
if (rel.rel_type != mtx::common::RelationType::Unsupported) {
if (rel.is_fallback)
relatedEventsFlat[rel.rel_type].fallbacks.push_back(
flatten_event(nlohmann::json(ev)));
else
relatedEventsFlat[rel.rel_type].events.push_back(flatten_event(nlohmann::json(ev)));
}
}
for (const auto &rule : rules->override_) {
if (rule.matches(flat_event, ctx))
if (rule.matches(flat_event, ctx, relatedEventsFlat))
return rule.actions;
}
for (const auto &rule : rules->content) {
if (rule.matches(flat_event, ctx))
if (rule.matches(flat_event, ctx, relatedEventsFlat))
return rule.actions;
}
......@@ -508,7 +598,7 @@ PushRuleEvaluator::evaluate(const mtx::events::collections::TimelineEvent &event
}
for (const auto &rule : rules->underride) {
if (rule.matches(flat_event, ctx))
if (rule.matches(flat_event, ctx, relatedEventsFlat))
return rule.actions;
}
return {};
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment