mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-24 18:07:17 +01:00
feat(flags): add support for matching static cohort membership (#25942)
This commit is contained in:
parent
4ce7e9c781
commit
f5a567f01a
@ -108,6 +108,8 @@ pub enum FlagError {
|
||||
CohortFiltersParsingError,
|
||||
#[error("Cohort dependency cycle")]
|
||||
CohortDependencyCycle(String),
|
||||
#[error("Person not found")]
|
||||
PersonNotFound,
|
||||
}
|
||||
|
||||
impl IntoResponse for FlagError {
|
||||
@ -212,6 +214,9 @@ impl IntoResponse for FlagError {
|
||||
tracing::error!("Cohort dependency cycle: {}", msg);
|
||||
(StatusCode::BAD_REQUEST, msg)
|
||||
}
|
||||
FlagError::PersonNotFound => {
|
||||
(StatusCode::BAD_REQUEST, "Person not found. Please check your distinct_id and try again.".to_string())
|
||||
}
|
||||
}
|
||||
.into_response()
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ use petgraph::algo::{is_cyclic_directed, toposort};
|
||||
use petgraph::graph::DiGraph;
|
||||
use serde_json::Value;
|
||||
use sha1::{Digest, Sha1};
|
||||
use sqlx::{postgres::PgQueryResult, Acquire, FromRow};
|
||||
use sqlx::{postgres::PgQueryResult, Acquire, FromRow, Row};
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
@ -26,6 +26,7 @@ use tokio::time::{sleep, timeout};
|
||||
use tracing::{error, info};
|
||||
|
||||
pub type TeamId = i32;
|
||||
pub type PersonId = i32;
|
||||
pub type GroupTypeIndex = i32;
|
||||
pub type PostgresReader = Arc<dyn DatabaseClient + Send + Sync>;
|
||||
pub type PostgresWriter = Arc<dyn DatabaseClient + Send + Sync>;
|
||||
@ -176,6 +177,7 @@ impl GroupTypeMappingCache {
|
||||
/// to fetch the properties from the DB each time.
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct PropertiesCache {
|
||||
person_id: Option<PersonId>,
|
||||
person_properties: Option<HashMap<String, Value>>,
|
||||
group_properties: HashMap<GroupTypeIndex, HashMap<String, Value>>,
|
||||
}
|
||||
@ -217,9 +219,18 @@ impl FeatureFlagMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate feature flags for a given distinct_id
|
||||
/// - Returns a map of feature flag keys to their values
|
||||
/// - If an error occurs while evaluating a flag, it will be logged and the flag will be omitted from the result
|
||||
/// Evaluates all feature flags for the current matcher context.
|
||||
///
|
||||
/// ## Arguments
|
||||
///
|
||||
/// * `feature_flags` - The list of feature flags to evaluate.
|
||||
/// * `person_property_overrides` - Any overrides for person properties.
|
||||
/// * `group_property_overrides` - Any overrides for group properties.
|
||||
/// * `hash_key_override` - Optional hash key overrides for experience continuity.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// * `FlagsResponse` - The result containing flag evaluations and any errors.
|
||||
pub async fn evaluate_all_feature_flags(
|
||||
&mut self,
|
||||
feature_flags: FeatureFlagList,
|
||||
@ -746,22 +757,29 @@ impl FeatureFlagMatcher {
|
||||
.partition(|prop| prop.is_cohort());
|
||||
|
||||
// Get the properties we need to check for in this condition match from the flag + any overrides
|
||||
let target_properties = self
|
||||
let person_or_group_properties = self
|
||||
.get_properties_to_check(feature_flag, property_overrides, &non_cohort_filters)
|
||||
.await?;
|
||||
|
||||
// Evaluate non-cohort filters first, since they're cheaper to evaluate and we can return early if they don't match
|
||||
if !all_properties_match(&non_cohort_filters, &target_properties) {
|
||||
if !all_properties_match(&non_cohort_filters, &person_or_group_properties) {
|
||||
return Ok((false, FeatureFlagMatchReason::NoConditionMatch));
|
||||
}
|
||||
|
||||
// Evaluate cohort filters, if any.
|
||||
if !cohort_filters.is_empty()
|
||||
&& !self
|
||||
.evaluate_cohort_filters(&cohort_filters, &target_properties)
|
||||
if !cohort_filters.is_empty() {
|
||||
// Get the person ID for the current distinct ID – this value should be cached at this point, but as a fallback we fetch from the database
|
||||
let person_id = self.get_person_id().await?;
|
||||
if !self
|
||||
.evaluate_cohort_filters(
|
||||
&cohort_filters,
|
||||
&person_or_group_properties,
|
||||
person_id,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
return Ok((false, FeatureFlagMatchReason::NoConditionMatch));
|
||||
{
|
||||
return Ok((false, FeatureFlagMatchReason::NoConditionMatch));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -809,6 +827,31 @@ impl FeatureFlagMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves the `PersonId` from the properties cache.
|
||||
/// If the cache does not contain a `PersonId`, it fetches it from the database
|
||||
/// and updates the cache accordingly.
|
||||
async fn get_person_id(&mut self) -> Result<PersonId, FlagError> {
|
||||
match self.properties_cache.person_id {
|
||||
Some(id) => Ok(id),
|
||||
None => {
|
||||
let id = self.get_person_id_from_db().await?;
|
||||
self.properties_cache.person_id = Some(id);
|
||||
Ok(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetches the `PersonId` from the database based on the current `distinct_id` and `team_id`.
|
||||
/// This method is called when the `PersonId` is not present in the properties cache.
|
||||
async fn get_person_id_from_db(&mut self) -> Result<PersonId, FlagError> {
|
||||
let postgres_reader = self.postgres_reader.clone();
|
||||
let distinct_id = self.distinct_id.clone();
|
||||
let team_id = self.team_id;
|
||||
fetch_person_properties_from_db(postgres_reader, distinct_id, team_id)
|
||||
.await
|
||||
.map(|(_, person_id)| person_id)
|
||||
}
|
||||
|
||||
/// Get person properties from cache or database.
|
||||
///
|
||||
/// This function attempts to retrieve person properties either from a cache or directly from the database.
|
||||
@ -836,26 +879,45 @@ impl FeatureFlagMatcher {
|
||||
&self,
|
||||
cohort_property_filters: &[PropertyFilter],
|
||||
target_properties: &HashMap<String, Value>,
|
||||
person_id: PersonId,
|
||||
) -> Result<bool, FlagError> {
|
||||
// At the start of the request, fetch all of the cohorts for the team from the cache
|
||||
// This method also caches the cohorts in memory for the duration of the application, so we don't need to fetch from
|
||||
// the database again until we restart the application.
|
||||
// This method also caches any cohorts for a given team in memory for the duration of the application, so we don't need to fetch from
|
||||
// the database again until we restart the application. See the CohortCacheManager for more details.
|
||||
let cohorts = self.cohort_cache.get_cohorts_for_team(self.team_id).await?;
|
||||
|
||||
// Store cohort match results in a HashMap to avoid re-evaluating the same cohort multiple times,
|
||||
// since the same cohort could appear in multiple property filters. This is especially important
|
||||
// because evaluating a cohort requires evaluating all of its dependencies, which can be expensive.
|
||||
// Split the cohorts into static and dynamic, since the dynamic ones have property filters
|
||||
// and we need to evaluate them based on the target properties, whereas the static ones are
|
||||
// purely based on person properties and are membership-based.
|
||||
let (static_cohorts, dynamic_cohorts): (Vec<_>, Vec<_>) =
|
||||
cohorts.iter().partition(|c| c.is_static);
|
||||
|
||||
// Store all cohort match results in a HashMap to avoid re-evaluating the same cohort multiple times,
|
||||
// since the same cohort could appear in multiple property filters.
|
||||
let mut cohort_matches = HashMap::new();
|
||||
for filter in cohort_property_filters {
|
||||
let cohort_id = filter
|
||||
.get_cohort_id()
|
||||
.ok_or(FlagError::CohortFiltersParsingError)?;
|
||||
let match_result =
|
||||
evaluate_cohort_dependencies(cohort_id, target_properties, cohorts.clone())?;
|
||||
cohort_matches.insert(cohort_id, match_result);
|
||||
|
||||
if !static_cohorts.is_empty() {
|
||||
let results = evaluate_static_cohorts(
|
||||
self.postgres_reader.clone(),
|
||||
person_id,
|
||||
static_cohorts.iter().map(|c| c.id).collect(),
|
||||
)
|
||||
.await?;
|
||||
cohort_matches.extend(results);
|
||||
}
|
||||
|
||||
// Apply cohort membership logic (IN|NOT_IN)
|
||||
if !dynamic_cohorts.is_empty() {
|
||||
for filter in cohort_property_filters {
|
||||
let cohort_id = filter
|
||||
.get_cohort_id()
|
||||
.ok_or(FlagError::CohortFiltersParsingError)?;
|
||||
let match_result =
|
||||
evaluate_dynamic_cohorts(cohort_id, target_properties, cohorts.clone())?;
|
||||
cohort_matches.insert(cohort_id, match_result);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply cohort membership logic (IN|NOT_IN) to the cohort match results
|
||||
apply_cohort_membership_logic(cohort_property_filters, &cohort_matches)
|
||||
}
|
||||
|
||||
@ -971,11 +1033,12 @@ impl FeatureFlagMatcher {
|
||||
let postgres_reader = self.postgres_reader.clone();
|
||||
let distinct_id = self.distinct_id.clone();
|
||||
let team_id = self.team_id;
|
||||
let db_properties =
|
||||
let (db_properties, person_id) =
|
||||
fetch_person_properties_from_db(postgres_reader, distinct_id, team_id).await?;
|
||||
|
||||
// once the properties are fetched, cache them so we don't need to fetch again in a given request
|
||||
// once the properties and person ID are fetched, cache them so we don't need to fetch again in a given request
|
||||
self.properties_cache.person_properties = Some(db_properties.clone());
|
||||
self.properties_cache.person_id = Some(person_id);
|
||||
|
||||
Ok(db_properties)
|
||||
}
|
||||
@ -1102,10 +1165,49 @@ impl FeatureFlagMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluates a single cohort and its dependencies.
|
||||
/// Evaluate static cohort filters by checking if the person is in each cohort.
|
||||
async fn evaluate_static_cohorts(
|
||||
postgres_reader: PostgresReader,
|
||||
person_id: i32, // Change this parameter from distinct_id to person_id
|
||||
cohort_ids: Vec<CohortId>,
|
||||
) -> Result<Vec<(CohortId, bool)>, FlagError> {
|
||||
let mut conn = postgres_reader.get_connection().await?;
|
||||
|
||||
let query = r#"
|
||||
WITH cohort_membership AS (
|
||||
SELECT c.cohort_id,
|
||||
CASE WHEN pc.cohort_id IS NOT NULL THEN true ELSE false END AS is_member
|
||||
FROM unnest($1::integer[]) AS c(cohort_id)
|
||||
LEFT JOIN posthog_cohortpeople AS pc
|
||||
ON pc.person_id = $2
|
||||
AND pc.cohort_id = c.cohort_id
|
||||
)
|
||||
SELECT cohort_id, is_member
|
||||
FROM cohort_membership
|
||||
"#;
|
||||
|
||||
let rows = sqlx::query(query)
|
||||
.bind(&cohort_ids)
|
||||
.bind(person_id) // Bind person_id directly
|
||||
.fetch_all(&mut *conn)
|
||||
.await?;
|
||||
|
||||
let result = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
let cohort_id: CohortId = row.get("cohort_id");
|
||||
let is_member: bool = row.get("is_member");
|
||||
(cohort_id, is_member)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Evaluates a dynamic cohort and its dependencies.
|
||||
/// This uses a topological sort to evaluate dependencies first, which is necessary
|
||||
/// because a cohort can depend on another cohort, and we need to respect the dependency order.
|
||||
fn evaluate_cohort_dependencies(
|
||||
fn evaluate_dynamic_cohorts(
|
||||
initial_cohort_id: CohortId,
|
||||
target_properties: &HashMap<String, Value>,
|
||||
cohorts: Vec<Cohort>,
|
||||
@ -1221,6 +1323,16 @@ fn build_cohort_dependency_graph(
|
||||
let mut graph = DiGraph::new();
|
||||
let mut node_map = HashMap::new();
|
||||
let mut queue = VecDeque::new();
|
||||
|
||||
let initial_cohort = cohorts
|
||||
.iter()
|
||||
.find(|c| c.id == initial_cohort_id)
|
||||
.ok_or(FlagError::CohortNotFound(initial_cohort_id.to_string()))?;
|
||||
|
||||
if initial_cohort.is_static {
|
||||
return Ok(graph);
|
||||
}
|
||||
|
||||
// This implements a breadth-first search (BFS) traversal to build a directed graph of cohort dependencies.
|
||||
// Starting from the initial cohort, we:
|
||||
// 1. Add each cohort as a node in the graph
|
||||
@ -1283,32 +1395,52 @@ async fn fetch_and_locally_cache_all_properties(
|
||||
|
||||
let query = r#"
|
||||
SELECT
|
||||
(SELECT "posthog_person"."properties"
|
||||
FROM "posthog_person"
|
||||
INNER JOIN "posthog_persondistinctid"
|
||||
ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id")
|
||||
WHERE ("posthog_persondistinctid"."distinct_id" = $1
|
||||
AND "posthog_persondistinctid"."team_id" = $2
|
||||
AND "posthog_person"."team_id" = $2)
|
||||
LIMIT 1) as person_properties,
|
||||
|
||||
(SELECT json_object_agg("posthog_group"."group_type_index", "posthog_group"."group_properties")
|
||||
FROM "posthog_group"
|
||||
WHERE ("posthog_group"."team_id" = $2
|
||||
AND "posthog_group"."group_type_index" = ANY($3))) as group_properties
|
||||
person.person_id,
|
||||
person.person_properties,
|
||||
group_properties.group_properties
|
||||
FROM (
|
||||
SELECT
|
||||
"posthog_person"."id" AS person_id,
|
||||
"posthog_person"."properties" AS person_properties
|
||||
FROM "posthog_person"
|
||||
INNER JOIN "posthog_persondistinctid"
|
||||
ON "posthog_person"."id" = "posthog_persondistinctid"."person_id"
|
||||
WHERE
|
||||
"posthog_persondistinctid"."distinct_id" = $1
|
||||
AND "posthog_persondistinctid"."team_id" = $2
|
||||
AND "posthog_person"."team_id" = $2
|
||||
LIMIT 1
|
||||
) AS person,
|
||||
(
|
||||
SELECT
|
||||
json_object_agg(
|
||||
"posthog_group"."group_type_index",
|
||||
"posthog_group"."group_properties"
|
||||
) AS group_properties
|
||||
FROM "posthog_group"
|
||||
WHERE
|
||||
"posthog_group"."team_id" = $2
|
||||
AND "posthog_group"."group_type_index" = ANY($3)
|
||||
) AS group_properties
|
||||
"#;
|
||||
|
||||
let group_type_indexes_vec: Vec<GroupTypeIndex> = group_type_indexes.iter().cloned().collect();
|
||||
|
||||
let row: (Option<Value>, Option<Value>) = sqlx::query_as(query)
|
||||
let row: (Option<i32>, Option<Value>, Option<Value>) = sqlx::query_as(query)
|
||||
.bind(&distinct_id)
|
||||
.bind(team_id)
|
||||
.bind(&group_type_indexes_vec)
|
||||
.fetch_optional(&mut *conn)
|
||||
.await?
|
||||
.unwrap_or((None, None));
|
||||
.unwrap_or((None, None, None));
|
||||
|
||||
if let Some(person_props) = row.0 {
|
||||
let (person_id, person_props, group_props) = row;
|
||||
|
||||
if let Some(person_id) = person_id {
|
||||
properties_cache.person_id = Some(person_id);
|
||||
}
|
||||
|
||||
if let Some(person_props) = person_props {
|
||||
properties_cache.person_properties = Some(
|
||||
person_props
|
||||
.as_object()
|
||||
@ -1319,7 +1451,7 @@ async fn fetch_and_locally_cache_all_properties(
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(group_props) = row.1 {
|
||||
if let Some(group_props) = group_props {
|
||||
let group_props_map: HashMap<GroupTypeIndex, HashMap<String, Value>> = group_props
|
||||
.as_object()
|
||||
.unwrap_or(&serde_json::Map::new())
|
||||
@ -1342,7 +1474,7 @@ async fn fetch_and_locally_cache_all_properties(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fetch person properties from the database for a given distinct ID and team ID.
|
||||
/// Fetch person properties and person ID from the database for a given distinct ID and team ID.
|
||||
///
|
||||
/// This function constructs and executes a SQL query to fetch the person properties for a specified distinct ID and team ID.
|
||||
/// It returns the fetched properties as a HashMap.
|
||||
@ -1350,31 +1482,37 @@ async fn fetch_person_properties_from_db(
|
||||
postgres_reader: PostgresReader,
|
||||
distinct_id: String,
|
||||
team_id: TeamId,
|
||||
) -> Result<HashMap<String, Value>, FlagError> {
|
||||
) -> Result<(HashMap<String, Value>, i32), FlagError> {
|
||||
let mut conn = postgres_reader.as_ref().get_connection().await?;
|
||||
|
||||
let query = r#"
|
||||
SELECT "posthog_person"."properties" as person_properties
|
||||
FROM "posthog_person"
|
||||
INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id")
|
||||
WHERE ("posthog_persondistinctid"."distinct_id" = $1
|
||||
AND "posthog_persondistinctid"."team_id" = $2
|
||||
AND "posthog_person"."team_id" = $2)
|
||||
LIMIT 1
|
||||
"#;
|
||||
SELECT "posthog_person"."id" as person_id, "posthog_person"."properties" as person_properties
|
||||
FROM "posthog_person"
|
||||
INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id")
|
||||
WHERE ("posthog_persondistinctid"."distinct_id" = $1
|
||||
AND "posthog_persondistinctid"."team_id" = $2
|
||||
AND "posthog_person"."team_id" = $2)
|
||||
LIMIT 1
|
||||
"#;
|
||||
|
||||
let row: Option<Value> = sqlx::query_scalar(query)
|
||||
let row: Option<(i32, Value)> = sqlx::query_as(query)
|
||||
.bind(&distinct_id)
|
||||
.bind(team_id)
|
||||
.fetch_optional(&mut *conn)
|
||||
.await?;
|
||||
|
||||
Ok(row
|
||||
.and_then(|v| v.as_object().cloned())
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.clone()))
|
||||
.collect())
|
||||
match row {
|
||||
Some((person_id, person_props)) => {
|
||||
let properties_map = person_props
|
||||
.as_object()
|
||||
.unwrap_or(&serde_json::Map::new())
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect();
|
||||
Ok((properties_map, person_id))
|
||||
}
|
||||
None => Err(FlagError::PersonNotFound),
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch group properties from the database for a given team ID and group type index.
|
||||
@ -1436,11 +1574,11 @@ fn locally_computable_property_overrides(
|
||||
/// Check if all properties match the given filters
|
||||
fn all_properties_match(
|
||||
flag_condition_properties: &[PropertyFilter],
|
||||
target_properties: &HashMap<String, Value>,
|
||||
matching_property_values: &HashMap<String, Value>,
|
||||
) -> bool {
|
||||
flag_condition_properties
|
||||
.iter()
|
||||
.all(|property| match_property(property, target_properties, false).unwrap_or(false))
|
||||
.all(|property| match_property(property, matching_property_values, false).unwrap_or(false))
|
||||
}
|
||||
|
||||
async fn get_feature_flag_hash_key_overrides(
|
||||
@ -1663,8 +1801,9 @@ mod tests {
|
||||
OperatorType,
|
||||
},
|
||||
test_utils::{
|
||||
insert_cohort_for_team_in_pg, insert_flag_for_team_in_pg, insert_new_team_in_pg,
|
||||
insert_person_for_team_in_pg, setup_pg_reader_client, setup_pg_writer_client,
|
||||
add_person_to_cohort, get_person_id_by_distinct_id, insert_cohort_for_team_in_pg,
|
||||
insert_flag_for_team_in_pg, insert_new_team_in_pg, insert_person_for_team_in_pg,
|
||||
setup_pg_reader_client, setup_pg_writer_client,
|
||||
},
|
||||
};
|
||||
|
||||
@ -1750,6 +1889,7 @@ mod tests {
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
// Matcher for a matching distinct_id
|
||||
let mut matcher = FeatureFlagMatcher::new(
|
||||
distinct_id.clone(),
|
||||
team.id,
|
||||
@ -1763,6 +1903,7 @@ mod tests {
|
||||
assert!(match_result.matches);
|
||||
assert_eq!(match_result.variant, None);
|
||||
|
||||
// Matcher for a non-matching distinct_id
|
||||
let mut matcher = FeatureFlagMatcher::new(
|
||||
not_matching_distinct_id.clone(),
|
||||
team.id,
|
||||
@ -1776,6 +1917,7 @@ mod tests {
|
||||
assert!(!match_result.matches);
|
||||
assert_eq!(match_result.variant, None);
|
||||
|
||||
// Matcher for a distinct_id that does not exist
|
||||
let mut matcher = FeatureFlagMatcher::new(
|
||||
"other_distinct_id".to_string(),
|
||||
team.id,
|
||||
@ -1785,9 +1927,10 @@ mod tests {
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let match_result = matcher.get_match(&flag, None, None).await.unwrap();
|
||||
assert!(!match_result.matches);
|
||||
assert_eq!(match_result.variant, None);
|
||||
let match_result = matcher.get_match(&flag, None, None).await;
|
||||
|
||||
// Expecting an error for non-existent distinct_id
|
||||
assert!(match_result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@ -3106,6 +3249,19 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
insert_person_for_team_in_pg(postgres_reader.clone(), team.id, "lil_id".to_string(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
insert_person_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
"another_id".to_string(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut matcher_test_id = FeatureFlagMatcher::new(
|
||||
"test_id".to_string(),
|
||||
team.id,
|
||||
@ -3265,6 +3421,19 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
insert_person_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
"another_id".to_string(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
insert_person_for_team_in_pg(postgres_reader.clone(), team.id, "lil_id".to_string(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let flag = create_test_flag(
|
||||
Some(1),
|
||||
Some(team.id),
|
||||
@ -3852,6 +4021,344 @@ mod tests {
|
||||
assert!(!result.matches);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_static_cohort_matching_user_in_cohort() {
|
||||
let postgres_reader = setup_pg_reader_client(None).await;
|
||||
let postgres_writer = setup_pg_writer_client(None).await;
|
||||
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
|
||||
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert a static cohort
|
||||
let cohort = insert_cohort_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
Some("Static Cohort".to_string()),
|
||||
json!({}), // Static cohorts don't have property filters
|
||||
true, // is_static = true
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert a person
|
||||
let distinct_id = "static_user".to_string();
|
||||
insert_person_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
distinct_id.clone(),
|
||||
Some(json!({"email": "static@user.com"})),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Retrieve the person's ID
|
||||
let person_id =
|
||||
get_person_id_by_distinct_id(postgres_reader.clone(), team.id, &distinct_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Associate the person with the static cohort
|
||||
add_person_to_cohort(postgres_reader.clone(), person_id, cohort.id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Define a flag with an 'In' cohort filter
|
||||
let flag = create_test_flag(
|
||||
None,
|
||||
Some(team.id),
|
||||
None,
|
||||
None,
|
||||
Some(FlagFilters {
|
||||
groups: vec![FlagGroupType {
|
||||
properties: Some(vec![PropertyFilter {
|
||||
key: "id".to_string(),
|
||||
value: json!(cohort.id),
|
||||
operator: Some(OperatorType::In),
|
||||
prop_type: "cohort".to_string(),
|
||||
group_type_index: None,
|
||||
negation: Some(false),
|
||||
}]),
|
||||
rollout_percentage: Some(100.0),
|
||||
variant: None,
|
||||
}],
|
||||
multivariate: None,
|
||||
aggregation_group_type_index: None,
|
||||
payloads: None,
|
||||
super_groups: None,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let mut matcher = FeatureFlagMatcher::new(
|
||||
distinct_id.clone(),
|
||||
team.id,
|
||||
postgres_reader.clone(),
|
||||
postgres_writer.clone(),
|
||||
cohort_cache.clone(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let result = matcher.get_match(&flag, None, None).await.unwrap();
|
||||
|
||||
assert!(
|
||||
result.matches,
|
||||
"User should match the static cohort and flag"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_static_cohort_matching_user_not_in_cohort() {
|
||||
let postgres_reader = setup_pg_reader_client(None).await;
|
||||
let postgres_writer = setup_pg_writer_client(None).await;
|
||||
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
|
||||
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert a static cohort
|
||||
let cohort = insert_cohort_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
Some("Another Static Cohort".to_string()),
|
||||
json!({}), // Static cohorts don't have property filters
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert a person
|
||||
let distinct_id = "non_static_user".to_string();
|
||||
insert_person_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
distinct_id.clone(),
|
||||
Some(json!({"email": "nonstatic@user.com"})),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Note: Do NOT associate the person with the static cohort
|
||||
|
||||
// Define a flag with an 'In' cohort filter
|
||||
let flag = create_test_flag(
|
||||
None,
|
||||
Some(team.id),
|
||||
None,
|
||||
None,
|
||||
Some(FlagFilters {
|
||||
groups: vec![FlagGroupType {
|
||||
properties: Some(vec![PropertyFilter {
|
||||
key: "id".to_string(),
|
||||
value: json!(cohort.id),
|
||||
operator: Some(OperatorType::In),
|
||||
prop_type: "cohort".to_string(),
|
||||
group_type_index: None,
|
||||
negation: Some(false),
|
||||
}]),
|
||||
rollout_percentage: Some(100.0),
|
||||
variant: None,
|
||||
}],
|
||||
multivariate: None,
|
||||
aggregation_group_type_index: None,
|
||||
payloads: None,
|
||||
super_groups: None,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let mut matcher = FeatureFlagMatcher::new(
|
||||
distinct_id.clone(),
|
||||
team.id,
|
||||
postgres_reader.clone(),
|
||||
postgres_writer.clone(),
|
||||
cohort_cache.clone(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let result = matcher.get_match(&flag, None, None).await.unwrap();
|
||||
|
||||
assert!(
|
||||
!result.matches,
|
||||
"User should not match the static cohort and flag"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_static_cohort_not_in_matching_user_not_in_cohort() {
|
||||
let postgres_reader = setup_pg_reader_client(None).await;
|
||||
let postgres_writer = setup_pg_writer_client(None).await;
|
||||
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
|
||||
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert a static cohort
|
||||
let cohort = insert_cohort_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
Some("Static Cohort NotIn".to_string()),
|
||||
json!({}), // Static cohorts don't have property filters
|
||||
true, // is_static = true
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert a person
|
||||
let distinct_id = "not_in_static_user".to_string();
|
||||
insert_person_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
distinct_id.clone(),
|
||||
Some(json!({"email": "notinstatic@user.com"})),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// No association with the static cohort
|
||||
|
||||
// Define a flag with a 'NotIn' cohort filter
|
||||
let flag = create_test_flag(
|
||||
None,
|
||||
Some(team.id),
|
||||
None,
|
||||
None,
|
||||
Some(FlagFilters {
|
||||
groups: vec![FlagGroupType {
|
||||
properties: Some(vec![PropertyFilter {
|
||||
key: "id".to_string(),
|
||||
value: json!(cohort.id),
|
||||
operator: Some(OperatorType::NotIn),
|
||||
prop_type: "cohort".to_string(),
|
||||
group_type_index: None,
|
||||
negation: Some(false),
|
||||
}]),
|
||||
rollout_percentage: Some(100.0),
|
||||
variant: None,
|
||||
}],
|
||||
multivariate: None,
|
||||
aggregation_group_type_index: None,
|
||||
payloads: None,
|
||||
super_groups: None,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let mut matcher = FeatureFlagMatcher::new(
|
||||
distinct_id.clone(),
|
||||
team.id,
|
||||
postgres_reader.clone(),
|
||||
postgres_writer.clone(),
|
||||
cohort_cache.clone(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let result = matcher.get_match(&flag, None, None).await.unwrap();
|
||||
|
||||
assert!(
|
||||
result.matches,
|
||||
"User not in the static cohort should match the 'NotIn' flag"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_static_cohort_not_in_matching_user_in_cohort() {
|
||||
let postgres_reader = setup_pg_reader_client(None).await;
|
||||
let postgres_writer = setup_pg_writer_client(None).await;
|
||||
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
|
||||
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert a static cohort
|
||||
let cohort = insert_cohort_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
Some("Static Cohort NotIn User In".to_string()),
|
||||
json!({}), // Static cohorts don't have property filters
|
||||
true, // is_static = true
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert a person
|
||||
let distinct_id = "in_not_in_static_user".to_string();
|
||||
insert_person_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
distinct_id.clone(),
|
||||
Some(json!({"email": "innotinstatic@user.com"})),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Retrieve the person's ID
|
||||
let person_id =
|
||||
get_person_id_by_distinct_id(postgres_reader.clone(), team.id, &distinct_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Associate the person with the static cohort
|
||||
add_person_to_cohort(postgres_reader.clone(), person_id, cohort.id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Define a flag with a 'NotIn' cohort filter
|
||||
let flag = create_test_flag(
|
||||
None,
|
||||
Some(team.id),
|
||||
None,
|
||||
None,
|
||||
Some(FlagFilters {
|
||||
groups: vec![FlagGroupType {
|
||||
properties: Some(vec![PropertyFilter {
|
||||
key: "id".to_string(),
|
||||
value: json!(cohort.id),
|
||||
operator: Some(OperatorType::NotIn),
|
||||
prop_type: "cohort".to_string(),
|
||||
group_type_index: None,
|
||||
negation: Some(false),
|
||||
}]),
|
||||
rollout_percentage: Some(100.0),
|
||||
variant: None,
|
||||
}],
|
||||
multivariate: None,
|
||||
aggregation_group_type_index: None,
|
||||
payloads: None,
|
||||
super_groups: None,
|
||||
}),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let mut matcher = FeatureFlagMatcher::new(
|
||||
distinct_id.clone(),
|
||||
team.id,
|
||||
postgres_reader.clone(),
|
||||
postgres_writer.clone(),
|
||||
cohort_cache.clone(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
let result = matcher.get_match(&flag, None, None).await.unwrap();
|
||||
|
||||
assert!(
|
||||
!result.matches,
|
||||
"User in the static cohort should not match the 'NotIn' flag"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_feature_flag_hash_key_overrides_success() {
|
||||
let postgres_reader = setup_pg_reader_client(None).await;
|
||||
@ -4095,7 +4602,6 @@ mod tests {
|
||||
.unwrap();
|
||||
let distinct_id = "user4".to_string();
|
||||
|
||||
// Insert person
|
||||
insert_person_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
@ -4168,7 +4674,6 @@ mod tests {
|
||||
.unwrap();
|
||||
let distinct_id = "user5".to_string();
|
||||
|
||||
// Insert person
|
||||
insert_person_for_team_in_pg(
|
||||
postgres_reader.clone(),
|
||||
team.id,
|
||||
|
@ -97,6 +97,7 @@ pub async fn process_request(context: RequestContext) -> Result<FlagsResponse, F
|
||||
let team = request
|
||||
.get_team_from_cache_or_pg(&token, state.redis.clone(), state.postgres_reader.clone())
|
||||
.await?;
|
||||
|
||||
let distinct_id = request.extract_distinct_id()?;
|
||||
let groups = request.groups.clone();
|
||||
let team_id = team.id;
|
||||
|
@ -1,7 +1,7 @@
|
||||
use anyhow::Error;
|
||||
use axum::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, Postgres};
|
||||
use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, Postgres, Row};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
@ -317,7 +317,8 @@ pub async fn insert_person_for_team_in_pg(
|
||||
team_id: i32,
|
||||
distinct_id: String,
|
||||
properties: Option<Value>,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<i32, Error> {
|
||||
// Changed return type to Result<i32, Error>
|
||||
let payload = match properties {
|
||||
Some(value) => value,
|
||||
None => json!({
|
||||
@ -329,7 +330,7 @@ pub async fn insert_person_for_team_in_pg(
|
||||
let uuid = Uuid::now_v7();
|
||||
|
||||
let mut conn = client.get_connection().await?;
|
||||
let res = sqlx::query(
|
||||
let row = sqlx::query(
|
||||
r#"
|
||||
WITH inserted_person AS (
|
||||
INSERT INTO posthog_person (
|
||||
@ -337,10 +338,11 @@ pub async fn insert_person_for_team_in_pg(
|
||||
properties_last_operation, team_id, is_user_id, is_identified, uuid, version
|
||||
)
|
||||
VALUES ('2023-04-05', $1, '{}', '{}', $2, NULL, true, $3, 0)
|
||||
RETURNING *
|
||||
RETURNING id
|
||||
)
|
||||
INSERT INTO posthog_persondistinctid (distinct_id, person_id, team_id, version)
|
||||
VALUES ($4, (SELECT id FROM inserted_person), $5, 0)
|
||||
RETURNING person_id
|
||||
"#,
|
||||
)
|
||||
.bind(&payload)
|
||||
@ -348,12 +350,11 @@ pub async fn insert_person_for_team_in_pg(
|
||||
.bind(uuid)
|
||||
.bind(&distinct_id)
|
||||
.bind(team_id)
|
||||
.execute(&mut *conn)
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
|
||||
assert_eq!(res.rows_affected(), 1);
|
||||
|
||||
Ok(())
|
||||
let person_id: i32 = row.get::<i32, _>("person_id");
|
||||
Ok(person_id)
|
||||
}
|
||||
|
||||
pub async fn insert_cohort_for_team_in_pg(
|
||||
@ -410,3 +411,48 @@ pub async fn insert_cohort_for_team_in_pg(
|
||||
|
||||
Ok(Cohort { id, ..cohort })
|
||||
}
|
||||
|
||||
pub async fn get_person_id_by_distinct_id(
|
||||
client: Arc<dyn Client + Send + Sync>,
|
||||
team_id: i32,
|
||||
distinct_id: &str,
|
||||
) -> Result<i32, Error> {
|
||||
let mut conn = client.get_connection().await?;
|
||||
let row: (i32,) = sqlx::query_as(
|
||||
r#"SELECT id FROM posthog_person
|
||||
WHERE team_id = $1 AND id = (
|
||||
SELECT person_id FROM posthog_persondistinctid
|
||||
WHERE team_id = $1 AND distinct_id = $2
|
||||
LIMIT 1
|
||||
)
|
||||
LIMIT 1"#,
|
||||
)
|
||||
.bind(team_id)
|
||||
.bind(distinct_id)
|
||||
.fetch_one(&mut *conn)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Person not found"))?;
|
||||
|
||||
Ok(row.0)
|
||||
}
|
||||
|
||||
pub async fn add_person_to_cohort(
|
||||
client: Arc<dyn Client + Send + Sync>,
|
||||
person_id: i32,
|
||||
cohort_id: i32,
|
||||
) -> Result<(), Error> {
|
||||
let mut conn = client.get_connection().await?;
|
||||
let res = sqlx::query(
|
||||
r#"INSERT INTO posthog_cohortpeople (cohort_id, person_id)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING"#,
|
||||
)
|
||||
.bind(cohort_id)
|
||||
.bind(person_id)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
assert!(res.rows_affected() > 0, "Failed to add person to cohort");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user