diff --git a/cypress/e2e/signup.cy.ts b/cypress/e2e/signup.cy.ts index bb1a4cf0468..9774236ef81 100644 --- a/cypress/e2e/signup.cy.ts +++ b/cypress/e2e/signup.cy.ts @@ -77,6 +77,9 @@ describe('Signup', () => { cy.get('[data-attr=password]').type(VALID_PASSWORD).should('have.value', VALID_PASSWORD) cy.get('[data-attr=signup-start]').click() cy.get('[data-attr=signup-name]').type('Alice Bob').should('have.value', 'Alice Bob') + cy.get('[data-attr=signup-role-at-organization]').click() + cy.get('.Popover li:first-child').click() + cy.get('[data-attr=signup-role-at-organization]').contains('Engineering') cy.get('[data-attr=signup-submit]').click() cy.wait('@signupRequest').then((interception) => { @@ -93,6 +96,9 @@ describe('Signup', () => { cy.get('[data-attr=password]').type(VALID_PASSWORD).should('have.value', VALID_PASSWORD) cy.get('[data-attr=signup-start]').click() cy.get('[data-attr=signup-name]').type('Alice Bob').should('have.value', 'Alice Bob') + cy.get('[data-attr=signup-role-at-organization]').click() + cy.get('.Popover li:first-child').click() + cy.get('[data-attr=signup-role-at-organization]').contains('Engineering') cy.get('[data-attr=signup-submit]').click() cy.wait('@signupRequest').then(() => { @@ -105,6 +111,9 @@ describe('Signup', () => { const newEmail = `new_user+${Math.floor(Math.random() * 10000)}@posthog.com` cy.get('[data-attr=signup-email]').clear().type(newEmail).should('have.value', newEmail) cy.get('[data-attr=signup-start]').click() + cy.get('[data-attr=signup-role-at-organization]').click() + cy.get('.Popover li:first-child').click() + cy.get('[data-attr=signup-role-at-organization]').contains('Engineering') cy.get('[data-attr=signup-submit]').click() cy.wait('@signupRequest').then((interception) => { diff --git a/ee/clickhouse/views/experiments.py b/ee/clickhouse/views/experiments.py index dc4a3170b93..644445067c4 100644 --- a/ee/clickhouse/views/experiments.py +++ b/ee/clickhouse/views/experiments.py @@ -328,6 +328,7 @@ class ExperimentSerializer(serializers.ModelSerializer): "name": f'Feature Flag for Experiment {validated_data["name"]}', "filters": filters, "active": not is_draft, + "creation_context": "experiments", }, context=self.context, ) diff --git a/frontend/__snapshots__/components-playerinspector--default--dark.png b/frontend/__snapshots__/components-playerinspector--default--dark.png new file mode 100644 index 00000000000..e1d088184e1 Binary files /dev/null and b/frontend/__snapshots__/components-playerinspector--default--dark.png differ diff --git a/frontend/__snapshots__/components-playerinspector--default--light.png b/frontend/__snapshots__/components-playerinspector--default--light.png new file mode 100644 index 00000000000..e9aabac97d7 Binary files /dev/null and b/frontend/__snapshots__/components-playerinspector--default--light.png differ diff --git a/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-breakdown-edit--light.png b/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-breakdown-edit--light.png index 78ea79ea3f7..25ab8cd06fd 100644 Binary files a/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-breakdown-edit--light.png and b/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-breakdown-edit--light.png differ diff --git a/frontend/src/lib/components/SignupRoleSelect.tsx b/frontend/src/lib/components/SignupRoleSelect.tsx index f99d00f1f47..0272145cbcc 100644 --- a/frontend/src/lib/components/SignupRoleSelect.tsx +++ b/frontend/src/lib/components/SignupRoleSelect.tsx @@ -3,7 
+3,7 @@ import { LemonSelect } from 'lib/lemon-ui/LemonSelect' export default function SignupRoleSelect({ className }: { className?: string }): JSX.Element { return ( - + ([ } if (props.monaco) { - const defaultQuery = values.featureFlags[FEATURE_FLAGS.SQL_EDITOR] - ? '' - : 'SELECT event FROM events LIMIT 100' + const defaultQuery = 'SELECT event FROM events LIMIT 100' const uri = props.monaco.Uri.parse(currentModelCount.toString()) const model = props.monaco.editor.createModel(defaultQuery, props.language, uri) props.editor?.setModel(model) diff --git a/frontend/src/scenes/authentication/signup/signupForm/signupLogic.ts b/frontend/src/scenes/authentication/signup/signupForm/signupLogic.ts index 1683fdeae13..d1257595c38 100644 --- a/frontend/src/scenes/authentication/signup/signupForm/signupLogic.ts +++ b/frontend/src/scenes/authentication/signup/signupForm/signupLogic.ts @@ -67,7 +67,7 @@ export const signupLogic = kea([ password: !values.preflight?.demo ? !password ? 'Please enter your password to continue' - : values.validatedPassword.feedback + : values.validatedPassword.feedback || undefined : undefined, }), submit: async () => { @@ -83,8 +83,9 @@ export const signupLogic = kea([ role_at_organization: '', referral_source: '', } as SignupForm, - errors: ({ name }) => ({ + errors: ({ name, role_at_organization }) => ({ name: !name ? 'Please enter your name' : undefined, + role_at_organization: !role_at_organization ? 'Please select your role in the organization' : undefined, }), submit: async (payload, breakpoint) => { breakpoint() diff --git a/frontend/src/scenes/billing/Billing.tsx b/frontend/src/scenes/billing/Billing.tsx index fd06e8a6785..5e950359b5f 100644 --- a/frontend/src/scenes/billing/Billing.tsx +++ b/frontend/src/scenes/billing/Billing.tsx @@ -21,6 +21,7 @@ import { BillingCTAHero } from './BillingCTAHero' import { billingLogic } from './billingLogic' import { BillingProduct } from './BillingProduct' import { CreditCTAHero } from './CreditCTAHero' +import { PaymentEntryModal } from './PaymentEntryModal' import { UnsubscribeCard } from './UnsubscribeCard' export const scene: SceneExport = { @@ -82,6 +83,8 @@ export function Billing(): JSX.Element { const platformAndSupportProduct = products?.find((product) => product.type === 'platform_and_support') return (
+ + {showLicenseDirectInput && ( <>
diff --git a/frontend/src/scenes/billing/PaymentEntryModal.tsx b/frontend/src/scenes/billing/PaymentEntryModal.tsx index c4580450167..52092c47d0e 100644 --- a/frontend/src/scenes/billing/PaymentEntryModal.tsx +++ b/frontend/src/scenes/billing/PaymentEntryModal.tsx @@ -1,12 +1,13 @@ import { LemonButton, LemonModal, Spinner } from '@posthog/lemon-ui' import { Elements, PaymentElement, useElements, useStripe } from '@stripe/react-stripe-js' -import { loadStripe } from '@stripe/stripe-js' import { useActions, useValues } from 'kea' -import { useEffect } from 'react' +import { WavingHog } from 'lib/components/hedgehogs' +import { useEffect, useState } from 'react' +import { urls } from 'scenes/urls' import { paymentEntryLogic } from './paymentEntryLogic' -const stripePromise = loadStripe(window.STRIPE_PUBLIC_KEY!) +const stripeJs = async (): Promise => await import('@stripe/stripe-js') export const PaymentForm = (): JSX.Element => { const { error, isLoading } = useValues(paymentEntryLogic) @@ -34,13 +35,17 @@ export const PaymentForm = (): JSX.Element => { setLoading(false) setError(result.error.message) } else { - pollAuthorizationStatus() + pollAuthorizationStatus(result.paymentIntent.id) } } return (
+

+ Your card will not be charged, but we place a $0.50 hold on it to verify your card. The hold will be
+ released in 7 days.
+

{error &&
{error}
}
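The hunk above pairs with the module-level change in this file: `loadStripe` is no longer invoked at import time; instead `@stripe/stripe-js` is pulled in with a dynamic `import()` once the modal actually mounts, keeping Stripe out of the initial bundle. A minimal sketch of that lazy-loading pattern in isolation — the hook name and its wiring are illustrative, not code from this PR:

import type { Stripe } from '@stripe/stripe-js' // type-only import, erased at build time
import { useEffect, useState } from 'react'

// Hypothetical hook: resolves the Stripe client only when a component that
// needs it actually mounts, so the SDK chunk is code-split out of the main bundle.
export function useLazyStripe(publicKey: string): Stripe | null {
    const [stripe, setStripe] = useState<Stripe | null>(null)

    useEffect(() => {
        let cancelled = false
        void (async () => {
            // The dynamic import fetches the @stripe/stripe-js chunk on demand
            const { loadStripe } = await import('@stripe/stripe-js')
            const client = await loadStripe(publicKey)
            if (!cancelled) {
                setStripe(client)
            }
        })()
        return () => {
            cancelled = true
        }
    }, [publicKey])

    return stripe
}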
@@ -58,21 +63,34 @@ interface PaymentEntryModalProps { redirectPath?: string | null } -export const PaymentEntryModal = ({ redirectPath = null }: PaymentEntryModalProps): JSX.Element | null => { +export const PaymentEntryModal = ({ + redirectPath = urls.organizationBilling(), +}: PaymentEntryModalProps): JSX.Element => { const { clientSecret, paymentEntryModalOpen } = useValues(paymentEntryLogic) const { hidePaymentEntryModal, initiateAuthorization } = useActions(paymentEntryLogic) + const [stripePromise, setStripePromise] = useState(null) + + useEffect(() => { + // Load Stripe.js asynchronously + const loadStripeJs = async (): Promise => { + const { loadStripe } = await stripeJs() + const publicKey = window.STRIPE_PUBLIC_KEY! + setStripePromise(await loadStripe(publicKey)) + } + void loadStripeJs() + }, []) useEffect(() => { initiateAuthorization(redirectPath) - }, [redirectPath]) + }, [initiateAuthorization, redirectPath]) return (
{clientSecret ? ( @@ -80,9 +98,13 @@ export const PaymentEntryModal = ({ redirectPath = null }: PaymentEntryModalProp ) : ( -
-
- +
+

We're contacting the Hedgehogs for approval.

+
+
+ +
+
)} diff --git a/frontend/src/scenes/billing/paymentEntryLogic.ts b/frontend/src/scenes/billing/paymentEntryLogic.ts index ebedbfe8b8a..ad2b84d0f80 100644 --- a/frontend/src/scenes/billing/paymentEntryLogic.ts +++ b/frontend/src/scenes/billing/paymentEntryLogic.ts @@ -12,7 +12,7 @@ export const paymentEntryLogic = kea({ setLoading: (loading) => ({ loading }), setError: (error) => ({ error }), initiateAuthorization: (redirectPath: string | null) => ({ redirectPath }), - pollAuthorizationStatus: true, + pollAuthorizationStatus: (paymentIntentId?: string) => ({ paymentIntentId }), setAuthorizationStatus: (status: string | null) => ({ status }), showPaymentEntryModal: true, hidePaymentEntryModal: true, @@ -73,7 +73,7 @@ export const paymentEntryLogic = kea({ } }, - pollAuthorizationStatus: async () => { + pollAuthorizationStatus: async ({ paymentIntentId }) => { const pollInterval = 2000 // Poll every 2 seconds const maxAttempts = 30 // Max 1 minute of polling (30 * 2 seconds) let attempts = 0 @@ -81,9 +81,9 @@ export const paymentEntryLogic = kea({ const poll = async (): Promise => { try { const urlParams = new URLSearchParams(window.location.search) - const paymentIntentId = urlParams.get('payment_intent') + const searchPaymentIntentId = urlParams.get('payment_intent') const response = await api.create('api/billing/activate/authorize/status', { - payment_intent_id: paymentIntentId, + payment_intent_id: paymentIntentId || searchPaymentIntentId, }) const status = response.status diff --git a/frontend/src/scenes/data-warehouse/editor/QueryWindow.tsx b/frontend/src/scenes/data-warehouse/editor/QueryWindow.tsx index 1fd177989be..c0dac586cce 100644 --- a/frontend/src/scenes/data-warehouse/editor/QueryWindow.tsx +++ b/frontend/src/scenes/data-warehouse/editor/QueryWindow.tsx @@ -1,19 +1,10 @@ import { Monaco } from '@monaco-editor/react' -import { BindLogic, useActions, useValues } from 'kea' +import { useActions, useValues } from 'kea' import { router } from 'kea-router' -import { - activemodelStateKey, - codeEditorLogic, - CodeEditorLogicProps, - editorModelsStateKey, -} from 'lib/monaco/codeEditorLogic' -import type { editor as importedEditor, Uri } from 'monaco-editor' -import { useCallback, useEffect, useState } from 'react' - -import { dataNodeLogic } from '~/queries/nodes/DataNode/dataNodeLogic' -import { hogQLQueryEditorLogic } from '~/queries/nodes/HogQLQuery/hogQLQueryEditorLogic' -import { HogQLQuery, NodeKind } from '~/queries/schema' +import type { editor as importedEditor } from 'monaco-editor' +import { useState } from 'react' +import { multitabEditorLogic } from './multitabEditorLogic' import { QueryPane } from './QueryPane' import { QueryTabs } from './QueryTabs' import { ResultPane } from './ResultPane' @@ -24,152 +15,54 @@ export function QueryWindow(): JSX.Element { ) const [monaco, editor] = monacoAndEditor ?? 
[] - const key = router.values.location.pathname - - const [query, setActiveQueryInput] = useState({ - kind: NodeKind.HogQLQuery, - query: '', - }) - - const hogQLQueryEditorLogicProps = { - query, - setQuery: (query: HogQLQuery) => { - setActiveQueryInput(query) - }, - onChange: () => {}, - key, - } - const logic = hogQLQueryEditorLogic(hogQLQueryEditorLogicProps) - const { queryInput, promptError } = useValues(logic) - const { setQueryInput, saveQuery, saveAsView } = useActions(logic) - const codeEditorKey = `hogQLQueryEditor/${router.values.location.pathname}` - const codeEditorLogicProps: CodeEditorLogicProps = { + const logic = multitabEditorLogic({ key: codeEditorKey, - sourceQuery: query, - query: queryInput, - language: 'hogQL', - metadataFilters: query.filters, monaco, editor, - multitab: true, - } - const { activeModelUri, allModels, hasErrors, error, isValidView } = useValues( - codeEditorLogic(codeEditorLogicProps) - ) - - const { createModel, setModel, deleteModel, setModels, addModel, updateState } = useActions( - codeEditorLogic(codeEditorLogicProps) - ) - - const modelKey = `hogQLQueryEditor/${activeModelUri?.path}` - - useEffect(() => { - if (monaco && activeModelUri) { - const _model = monaco.editor.getModel(activeModelUri) - const val = _model?.getValue() - setQueryInput(val ?? '') - saveQuery() - } - }, [activeModelUri]) - - const onAdd = useCallback(() => { - createModel() - }, [createModel]) + }) + const { allTabs, activeModelUri, queryInput, activeQuery, activeTabKey, hasErrors, error, isValidView } = + useValues(logic) + const { selectTab, deleteTab, createTab, setQueryInput, runQuery, saveAsView } = useActions(logic) return (
{ setQueryInput(v ?? '') - updateState() }, onMount: (editor, monaco) => { setMonacoAndEditor([monaco, editor]) - - const allModelQueries = localStorage.getItem(editorModelsStateKey(codeEditorKey)) - const activeModelUri = localStorage.getItem(activemodelStateKey(codeEditorKey)) - - if (allModelQueries) { - // clear existing models - monaco.editor.getModels().forEach((model) => { - model.dispose() - }) - - const models = JSON.parse(allModelQueries || '[]') - const newModels: Uri[] = [] - - models.forEach((model: Record) => { - if (monaco) { - const uri = monaco.Uri.parse(model.path) - const newModel = monaco.editor.createModel(model.query, 'hogQL', uri) - editor?.setModel(newModel) - newModels.push(uri) - } - }) - - setModels(newModels) - - if (activeModelUri) { - const uri = monaco.Uri.parse(activeModelUri) - const activeModel = monaco.editor - .getModels() - .find((model) => model.uri.path === uri.path) - activeModel && editor?.setModel(activeModel) - const val = activeModel?.getValue() - if (val) { - setQueryInput(val) - saveQuery() - } - setModel(uri) - } else if (newModels.length) { - setModel(newModels[0]) - } - } else { - const model = editor.getModel() - - if (model) { - addModel(model.uri) - setModel(model.uri) - } - } }, onPressCmdEnter: (value, selectionType) => { if (value && selectionType === 'selection') { - saveQuery(value) + runQuery(value) } else { - saveQuery() + runQuery() } }, }} /> - - - +
) } diff --git a/frontend/src/scenes/data-warehouse/editor/ResultPane.tsx b/frontend/src/scenes/data-warehouse/editor/ResultPane.tsx index 215a116d07a..40dfee342d2 100644 --- a/frontend/src/scenes/data-warehouse/editor/ResultPane.tsx +++ b/frontend/src/scenes/data-warehouse/editor/ResultPane.tsx @@ -7,6 +7,7 @@ import DataGrid from 'react-data-grid' import { themeLogic } from '~/layout/navigation-3000/themeLogic' import { dataNodeLogic } from '~/queries/nodes/DataNode/dataNodeLogic' +import { NodeKind } from '~/queries/schema' enum ResultsTab { Results = 'results', @@ -17,11 +18,28 @@ interface ResultPaneProps { onSave: () => void saveDisabledReason?: string onQueryInputChange: () => void + logicKey: string + query: string } -export function ResultPane({ onQueryInputChange, onSave, saveDisabledReason }: ResultPaneProps): JSX.Element { +export function ResultPane({ + onQueryInputChange, + onSave, + saveDisabledReason, + logicKey, + query, +}: ResultPaneProps): JSX.Element { const { isDarkModeOn } = useValues(themeLogic) - const { response, responseLoading } = useValues(dataNodeLogic) + const { response, responseLoading } = useValues( + dataNodeLogic({ + key: logicKey, + query: { + kind: NodeKind.HogQLQuery, + query, + }, + doNotLoad: !query, + }) + ) const columns = useMemo(() => { return ( diff --git a/frontend/src/scenes/data-warehouse/editor/multitabEditorLogic.tsx b/frontend/src/scenes/data-warehouse/editor/multitabEditorLogic.tsx new file mode 100644 index 00000000000..7a4a3d4e84e --- /dev/null +++ b/frontend/src/scenes/data-warehouse/editor/multitabEditorLogic.tsx @@ -0,0 +1,332 @@ +import { Monaco } from '@monaco-editor/react' +import { LemonDialog, LemonInput } from '@posthog/lemon-ui' +import { actions, kea, listeners, path, props, propsChanged, reducers, selectors } from 'kea' +import { subscriptions } from 'kea-subscriptions' +import { LemonField } from 'lib/lemon-ui/LemonField' +import { ModelMarker } from 'lib/monaco/codeEditorLogic' +import { editor, MarkerSeverity, Uri } from 'monaco-editor' + +import { dataNodeLogic } from '~/queries/nodes/DataNode/dataNodeLogic' +import { performQuery } from '~/queries/query' +import { HogLanguage, HogQLMetadata, HogQLMetadataResponse, HogQLNotice, HogQLQuery, NodeKind } from '~/queries/schema' + +import { dataWarehouseViewsLogic } from '../saved_queries/dataWarehouseViewsLogic' +import type { multitabEditorLogicType } from './multitabEditorLogicType' + +export interface MultitabEditorLogicProps { + key: string + monaco?: Monaco | null + editor?: editor.IStandaloneCodeEditor | null +} + +export const editorModelsStateKey = (key: string | number): string => `${key}/editorModelQueries` +export const activemodelStateKey = (key: string | number): string => `${key}/activeModelUri` + +export const multitabEditorLogic = kea([ + path(['data-warehouse', 'editor', 'multitabEditorLogic']), + props({} as MultitabEditorLogicProps), + actions({ + setQueryInput: (queryInput: string) => ({ queryInput }), + updateState: true, + runQuery: (queryOverride?: string) => ({ queryOverride }), + setActiveQuery: (query: string) => ({ query }), + setTabs: (tabs: Uri[]) => ({ tabs }), + addTab: (tab: Uri) => ({ tab }), + createTab: () => null, + deleteTab: (tab: Uri) => ({ tab }), + removeTab: (tab: Uri) => ({ tab }), + selectTab: (tab: Uri) => ({ tab }), + setLocalState: (key: string, value: any) => ({ key, value }), + initialize: true, + saveAsView: true, + saveAsViewSuccess: (name: string) => ({ name }), + reloadMetadata: true, + setMetadata: (query: string, 
metadata: HogQLMetadataResponse) => ({ query, metadata }), + }), + propsChanged(({ actions }, oldProps) => { + if (!oldProps.monaco && !oldProps.editor) { + actions.initialize() + } + }), + reducers(({ props }) => ({ + queryInput: [ + '', + { + setQueryInput: (_, { queryInput }) => queryInput, + }, + ], + activeQuery: [ + null as string | null, + { + setActiveQuery: (_, { query }) => query, + }, + ], + activeModelUri: [ + null as Uri | null, + { + selectTab: (_, { tab }) => tab, + }, + ], + allTabs: [ + [] as Uri[], + { + addTab: (state, { tab }) => { + const newTabs = [...state, tab] + return newTabs + }, + removeTab: (state, { tab: tabToRemove }) => { + const newModels = state.filter((tab) => tab.toString() !== tabToRemove.toString()) + return newModels + }, + setTabs: (_, { tabs }) => tabs, + }, + ], + metadata: [ + null as null | [string, HogQLMetadataResponse], + { + setMetadata: (_, { query, metadata }) => [query, metadata], + }, + ], + modelMarkers: [ + [] as ModelMarker[], + { + setMetadata: (_, { query, metadata }) => { + const model = props.editor?.getModel() + if (!model || !metadata) { + return [] + } + const markers: ModelMarker[] = [] + const metadataResponse = metadata + + function noticeToMarker(error: HogQLNotice, severity: MarkerSeverity): ModelMarker { + const start = model!.getPositionAt(error.start ?? 0) + const end = model!.getPositionAt(error.end ?? query.length) + return { + start: error.start ?? 0, + startLineNumber: start.lineNumber, + startColumn: start.column, + end: error.end ?? query.length, + endLineNumber: end.lineNumber, + endColumn: end.column, + message: error.message ?? 'Unknown error', + severity: severity, + hogQLFix: error.fix, + } + } + + for (const notice of metadataResponse?.errors ?? []) { + markers.push(noticeToMarker(notice, 8 /* MarkerSeverity.Error */)) + } + for (const notice of metadataResponse?.warnings ?? []) { + markers.push(noticeToMarker(notice, 4 /* MarkerSeverity.Warning */)) + } + for (const notice of metadataResponse?.notices ?? 
[]) { + markers.push(noticeToMarker(notice, 1 /* MarkerSeverity.Hint */)) + } + + props.monaco?.editor.setModelMarkers(model, 'hogql', markers) + return markers + }, + }, + ], + })), + listeners(({ values, props, actions }) => ({ + createTab: () => { + let currentModelCount = 1 + const allNumbers = values.allTabs.map((tab) => parseInt(tab.path.split('/').pop() || '0')) + while (allNumbers.includes(currentModelCount)) { + currentModelCount++ + } + + if (props.monaco) { + const uri = props.monaco.Uri.parse(currentModelCount.toString()) + const model = props.monaco.editor.createModel('', 'hogQL', uri) + props.editor?.setModel(model) + actions.addTab(uri) + actions.selectTab(uri) + + const queries = values.allTabs.map((tab) => { + return { + query: props.monaco?.editor.getModel(tab)?.getValue() || '', + path: tab.path.split('/').pop(), + } + }) + actions.setLocalState(editorModelsStateKey(props.key), JSON.stringify(queries)) + } + }, + selectTab: ({ tab }) => { + if (props.monaco) { + const model = props.monaco.editor.getModel(tab) + props.editor?.setModel(model) + } + + const path = tab.path.split('/').pop() + path && actions.setLocalState(activemodelStateKey(props.key), path) + }, + deleteTab: ({ tab: tabToRemove }) => { + if (props.monaco) { + const model = props.monaco.editor.getModel(tabToRemove) + if (tabToRemove == values.activeModelUri) { + const indexOfModel = values.allTabs.findIndex((tab) => tab.toString() === tabToRemove.toString()) + const nextModel = + values.allTabs[indexOfModel + 1] || values.allTabs[indexOfModel - 1] || values.allTabs[0] // there will always be one + actions.selectTab(nextModel) + } + model?.dispose() + actions.removeTab(tabToRemove) + const queries = values.allTabs.map((tab) => { + return { + query: props.monaco?.editor.getModel(tab)?.getValue() || '', + path: tab.path.split('/').pop(), + } + }) + actions.setLocalState(editorModelsStateKey(props.key), JSON.stringify(queries)) + } + }, + setLocalState: ({ key, value }) => { + localStorage.setItem(key, value) + }, + initialize: () => { + const allModelQueries = localStorage.getItem(editorModelsStateKey(props.key)) + const activeModelUri = localStorage.getItem(activemodelStateKey(props.key)) + + if (allModelQueries) { + // clear existing models + props.monaco?.editor.getModels().forEach((model: editor.ITextModel) => { + model.dispose() + }) + + const models = JSON.parse(allModelQueries || '[]') + const newModels: Uri[] = [] + + models.forEach((model: Record) => { + if (props.monaco) { + const uri = props.monaco.Uri.parse(model.path) + const newModel = props.monaco.editor.createModel(model.query, 'hogQL', uri) + props.editor?.setModel(newModel) + newModels.push(uri) + } + }) + + actions.setTabs(newModels) + + if (activeModelUri) { + const uri = props.monaco?.Uri.parse(activeModelUri) + const activeModel = props.monaco?.editor + .getModels() + .find((model: editor.ITextModel) => model.uri.path === uri?.path) + activeModel && props.editor?.setModel(activeModel) + const val = activeModel?.getValue() + if (val) { + actions.setQueryInput(val) + actions.runQuery() + } + uri && actions.selectTab(uri) + } else if (newModels.length) { + actions.selectTab(newModels[0]) + } + } else { + const model = props.editor?.getModel() + + if (model) { + actions.createTab() + } + } + }, + setQueryInput: () => { + actions.updateState() + }, + updateState: async (_, breakpoint) => { + await breakpoint(100) + const queries = values.allTabs.map((model) => { + return { + query: props.monaco?.editor.getModel(model)?.getValue() || '', + 
path: model.path.split('/').pop(), + } + }) + localStorage.setItem(editorModelsStateKey(props.key), JSON.stringify(queries)) + }, + runQuery: ({ queryOverride }) => { + actions.setActiveQuery(queryOverride || values.queryInput) + }, + saveAsView: async () => { + LemonDialog.openForm({ + title: 'Save as view', + initialValues: { viewName: '' }, + content: ( + + + + ), + errors: { + viewName: (name) => (!name ? 'You must enter a name' : undefined), + }, + onSubmit: ({ viewName }) => actions.saveAsViewSuccess(viewName), + }) + }, + saveAsViewSuccess: async ({ name }) => { + const query: HogQLQuery = { + kind: NodeKind.HogQLQuery, + query: values.queryInput, + } + await dataWarehouseViewsLogic.asyncActions.createDataWarehouseSavedQuery({ name, query }) + }, + reloadMetadata: async (_, breakpoint) => { + const model = props.editor?.getModel() + if (!model || !props.monaco) { + return + } + await breakpoint(300) + const query = values.queryInput + if (query === '') { + return + } + + const response = await performQuery({ + kind: NodeKind.HogQLMetadata, + language: HogLanguage.hogQL, + query: query, + }) + breakpoint() + actions.setMetadata(query, response) + }, + })), + subscriptions(({ props, actions, values }) => ({ + activeModelUri: (activeModelUri) => { + if (props.monaco) { + const _model = props.monaco.editor.getModel(activeModelUri) + const val = _model?.getValue() + actions.setQueryInput(val ?? '') + actions.runQuery() + dataNodeLogic({ + key: values.activeTabKey, + query: { + kind: NodeKind.HogQLQuery, + query: val ?? '', + }, + doNotLoad: !val, + }).mount() + } + }, + queryInput: () => { + actions.reloadMetadata() + }, + })), + selectors({ + activeTabKey: [(s) => [s.activeModelUri], (activeModelUri) => `hogQLQueryEditor/${activeModelUri?.path}`], + isValidView: [(s) => [s.metadata], (metadata) => !!(metadata && metadata[1]?.isValidView)], + hasErrors: [ + (s) => [s.modelMarkers], + (modelMarkers) => !!(modelMarkers ?? []).filter((e) => e.severity === 8 /* MarkerSeverity.Error */).length, + ], + error: [ + (s) => [s.hasErrors, s.modelMarkers], + (hasErrors, modelMarkers) => { + const firstError = modelMarkers.find((e) => e.severity === 8 /* MarkerSeverity.Error */) + return hasErrors && firstError + ? 
`Error on line ${firstError.startLineNumber}, column ${firstError.startColumn}` + : null + }, + ], + }), +]) diff --git a/frontend/src/scenes/feature-flags/featureFlagLogic.ts b/frontend/src/scenes/feature-flags/featureFlagLogic.ts index 8485d628ed0..875b6f56cf8 100644 --- a/frontend/src/scenes/feature-flags/featureFlagLogic.ts +++ b/frontend/src/scenes/feature-flags/featureFlagLogic.ts @@ -302,7 +302,7 @@ export const featureFlagLogic = kea([ }), forms(({ actions, values }) => ({ featureFlag: { - defaults: { ...NEW_FLAG } as FeatureFlagType, + defaults: { ...NEW_FLAG }, errors: ({ key, filters }) => { return { key: validateFeatureFlagKey(key), diff --git a/frontend/src/scenes/onboarding/Onboarding.tsx b/frontend/src/scenes/onboarding/Onboarding.tsx index 522e28c569a..a4106a69173 100644 --- a/frontend/src/scenes/onboarding/Onboarding.tsx +++ b/frontend/src/scenes/onboarding/Onboarding.tsx @@ -26,6 +26,7 @@ import { ExperimentsSDKInstructions } from './sdks/experiments/ExperimentsSDKIns import { FeatureFlagsSDKInstructions } from './sdks/feature-flags/FeatureFlagsSDKInstructions' import { ProductAnalyticsSDKInstructions } from './sdks/product-analytics/ProductAnalyticsSDKInstructions' import { SDKs } from './sdks/SDKs' +import { sdksLogic } from './sdks/sdksLogic' import { SessionReplaySDKInstructions } from './sdks/session-replay/SessionReplaySDKInstructions' import { SurveysSDKInstructions } from './sdks/surveys/SurveysSDKInstructions' @@ -105,12 +106,16 @@ const OnboardingWrapper = ({ children }: { children: React.ReactNode }): JSX.Ele const ProductAnalyticsOnboarding = (): JSX.Element => { const { currentTeam } = useValues(teamLogic) const { featureFlags } = useValues(featureFlagLogic) + const { combinedSnippetAndLiveEventsHosts } = useValues(sdksLogic) + // mount the logic here so that it stays mounted for the entire onboarding flow // not sure if there is a better way to do this useValues(newDashboardLogic) const showTemplateSteps = - featureFlags[FEATURE_FLAGS.ONBOARDING_DASHBOARD_TEMPLATES] == 'test' && window.innerWidth > 1000 + featureFlags[FEATURE_FLAGS.ONBOARDING_DASHBOARD_TEMPLATES] == 'test' && + window.innerWidth > 1000 && + combinedSnippetAndLiveEventsHosts.length > 0 const options: ProductConfigOption[] = [ { diff --git a/frontend/src/scenes/onboarding/productAnalyticsSteps/DashboardTemplateConfigureStep.tsx b/frontend/src/scenes/onboarding/productAnalyticsSteps/DashboardTemplateConfigureStep.tsx index c2740d598ca..9d30f984a4c 100644 --- a/frontend/src/scenes/onboarding/productAnalyticsSteps/DashboardTemplateConfigureStep.tsx +++ b/frontend/src/scenes/onboarding/productAnalyticsSteps/DashboardTemplateConfigureStep.tsx @@ -52,8 +52,8 @@ const UrlInput = ({ iframeRef }: { iframeRef: React.RefObject return (
setInputValue(v)} diff --git a/frontend/src/scenes/session-recordings/player/inspector/PlayerInspector.stories.tsx b/frontend/src/scenes/session-recordings/player/inspector/PlayerInspector.stories.tsx new file mode 100644 index 00000000000..1c3376a829f --- /dev/null +++ b/frontend/src/scenes/session-recordings/player/inspector/PlayerInspector.stories.tsx @@ -0,0 +1,88 @@ +import { Meta, StoryFn, StoryObj } from '@storybook/react' +import { BindLogic, useActions, useValues } from 'kea' +import { useEffect } from 'react' +import recordingEventsJson from 'scenes/session-recordings/__mocks__/recording_events_query' +import recordingMetaJson from 'scenes/session-recordings/__mocks__/recording_meta.json' +import { snapshotsAsJSONLines } from 'scenes/session-recordings/__mocks__/recording_snapshots' +import { PlayerInspector } from 'scenes/session-recordings/player/inspector/PlayerInspector' +import { sessionRecordingDataLogic } from 'scenes/session-recordings/player/sessionRecordingDataLogic' +import { sessionRecordingPlayerLogic } from 'scenes/session-recordings/player/sessionRecordingPlayerLogic' + +import { mswDecorator } from '~/mocks/browser' + +type Story = StoryObj +const meta: Meta = { + title: 'Components/PlayerInspector', + component: PlayerInspector, + decorators: [ + mswDecorator({ + get: { + '/api/environments/:team_id/session_recordings/:id': recordingMetaJson, + '/api/environments/:team_id/session_recordings/:id/snapshots': (req, res, ctx) => { + // with no sources, returns sources... + if (req.url.searchParams.get('source') === 'blob') { + return res(ctx.text(snapshotsAsJSONLines())) + } + // with no source requested should return sources + return [ + 200, + { + sources: [ + { + source: 'blob', + start_timestamp: '2023-08-11T12:03:36.097000Z', + end_timestamp: '2023-08-11T12:04:52.268000Z', + blob_key: '1691755416097-1691755492268', + }, + ], + }, + ] + }, + }, + post: { + '/api/environments/:team_id/query': (req, res, ctx) => { + const body = req.body as Record + if (body.query.kind === 'EventsQuery' && body.query.properties.length === 1) { + return res(ctx.json(recordingEventsJson)) + } + + // default to an empty response or we duplicate information + return res(ctx.json({ results: [] })) + }, + }, + }), + ], +} +export default meta + +const BasicTemplate: StoryFn = () => { + const dataLogic = sessionRecordingDataLogic({ sessionRecordingId: '12345', playerKey: 'story-template' }) + const { sessionPlayerMetaData } = useValues(dataLogic) + + const { loadSnapshots, loadEvents } = useActions(dataLogic) + loadSnapshots() + + // TODO you have to call actions in a particular order + // and only when some other data has already been loaded + // 🫠 + useEffect(() => { + loadEvents() + }, [sessionPlayerMetaData]) + + return ( +
+ + + +
+ ) +} + +export const Default: Story = BasicTemplate.bind({}) +Default.args = {} diff --git a/frontend/src/scenes/session-recordings/player/inspector/PlayerInspector.tsx b/frontend/src/scenes/session-recordings/player/inspector/PlayerInspector.tsx new file mode 100644 index 00000000000..dc8c712cef4 --- /dev/null +++ b/frontend/src/scenes/session-recordings/player/inspector/PlayerInspector.tsx @@ -0,0 +1,11 @@ +import { PlayerInspectorControls } from 'scenes/session-recordings/player/inspector/PlayerInspectorControls' +import { PlayerInspectorList } from 'scenes/session-recordings/player/inspector/PlayerInspectorList' + +export function PlayerInspector(): JSX.Element { + return ( + <> + + + + ) +} diff --git a/frontend/src/scenes/session-recordings/player/sidebar/PlayerSidebarTab.tsx b/frontend/src/scenes/session-recordings/player/sidebar/PlayerSidebarTab.tsx index 8f2c12055f2..9c69a46274d 100644 --- a/frontend/src/scenes/session-recordings/player/sidebar/PlayerSidebarTab.tsx +++ b/frontend/src/scenes/session-recordings/player/sidebar/PlayerSidebarTab.tsx @@ -1,9 +1,8 @@ import { useValues } from 'kea' +import { PlayerInspector } from 'scenes/session-recordings/player/inspector/PlayerInspector' import { SessionRecordingSidebarTab } from '~/types' -import { PlayerInspectorControls } from '../inspector/PlayerInspectorControls' -import { PlayerInspectorList } from '../inspector/PlayerInspectorList' import { PlayerSidebarDebuggerTab } from './PlayerSidebarDebuggerTab' import { playerSidebarLogic } from './playerSidebarLogic' import { PlayerSidebarOverviewTab } from './PlayerSidebarOverviewTab' @@ -15,12 +14,7 @@ export function PlayerSidebarTab(): JSX.Element | null { case SessionRecordingSidebarTab.OVERVIEW: return case SessionRecordingSidebarTab.INSPECTOR: - return ( - <> - - - - ) + return case SessionRecordingSidebarTab.DEBUGGER: return default: diff --git a/frontend/src/scenes/surveys/SurveyEditQuestionRow.tsx b/frontend/src/scenes/surveys/SurveyEditQuestionRow.tsx index 41e5805fc05..237aaa78597 100644 --- a/frontend/src/scenes/surveys/SurveyEditQuestionRow.tsx +++ b/frontend/src/scenes/surveys/SurveyEditQuestionRow.tsx @@ -10,7 +10,7 @@ import { Group } from 'kea-forms' import { SortableDragIcon } from 'lib/lemon-ui/icons' import { LemonField } from 'lib/lemon-ui/LemonField' -import { Survey, SurveyQuestionType } from '~/types' +import { Survey, SurveyQuestionType, SurveyType } from '~/types' import { defaultSurveyFieldValues, NewSurvey, SurveyQuestionLabel } from './constants' import { QuestionBranchingInput } from './QuestionBranchingInput' @@ -315,7 +315,7 @@ export function SurveyEditQuestionGroup({ index, question }: { index: number; qu ) })}
- {(value || []).length < 6 && ( + {((value || []).length < 6 || survey.type != SurveyType.Popover) && ( <> } diff --git a/funnel-udf/src/steps.rs b/funnel-udf/src/steps.rs index 9330caf4e1c..21d6fb84e62 100644 --- a/funnel-udf/src/steps.rs +++ b/funnel-udf/src/steps.rs @@ -55,7 +55,7 @@ const DEFAULT_ENTERED_TIMESTAMP: EnteredTimestamp = EnteredTimestamp { }; pub fn process_line(line: &str) -> Value { - let args = parse_args(&line); + let args = parse_args(line); let mut aggregate_funnel_row = AggregateFunnelRow { results: Vec::with_capacity(args.prop_vals.len()), breakdown_step: Option::None, @@ -112,7 +112,7 @@ impl AggregateFunnelRow { self.process_event( args, &mut vars, - &events_with_same_timestamp[0], + events_with_same_timestamp[0], prop_val, false ); @@ -147,7 +147,7 @@ impl AggregateFunnelRow { args, &mut vars, &event, - &prop_val, + prop_val, true ); } @@ -261,4 +261,4 @@ impl AggregateFunnelRow { } } } -} \ No newline at end of file +} diff --git a/funnel-udf/src/trends.rs b/funnel-udf/src/trends.rs index fa7dc162c12..42356dc06d1 100644 --- a/funnel-udf/src/trends.rs +++ b/funnel-udf/src/trends.rs @@ -81,7 +81,7 @@ const DEFAULT_ENTERED_TIMESTAMP: EnteredTimestamp = EnteredTimestamp { }; pub fn process_line(line: &str) -> Value { - let args = parse_args(&line); + let args = parse_args(line); let mut aggregate_funnel_row = AggregateFunnelRow { results: HashMap::new(), breakdown_step: Option::None, @@ -128,7 +128,7 @@ impl AggregateFunnelRow { self.process_event( args, &mut vars, - &event, + event, prop_val, ); } @@ -242,4 +242,4 @@ impl AggregateFunnelRow { } } } -} \ No newline at end of file +} diff --git a/posthog/api/early_access_feature.py b/posthog/api/early_access_feature.py index 57885666fde..004725393b4 100644 --- a/posthog/api/early_access_feature.py +++ b/posthog/api/early_access_feature.py @@ -203,6 +203,7 @@ class EarlyAccessFeatureSerializerCreateOnly(EarlyAccessFeatureSerializer): "key": feature_flag_key, "name": f"Feature Flag for Feature {validated_data['name']}", "filters": filters, + "creation_context": "early_access_features", }, context=self.context, ) diff --git a/posthog/api/feature_flag.py b/posthog/api/feature_flag.py index d24ee4499a4..5c1485a6d86 100644 --- a/posthog/api/feature_flag.py +++ b/posthog/api/feature_flag.py @@ -115,6 +115,14 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo ) can_edit = serializers.SerializerMethodField() + CREATION_CONTEXT_CHOICES = ("feature_flags", "experiments", "surveys", "early_access_features", "web_experiments") + creation_context = serializers.ChoiceField( + choices=CREATION_CONTEXT_CHOICES, + write_only=True, + required=False, + help_text="Indicates the origin product of the feature flag. 
Choices: 'feature_flags', 'experiments', 'surveys', 'early_access_features', 'web_experiments'.", + ) + class Meta: model = FeatureFlag fields = [ @@ -139,6 +147,7 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo "usage_dashboard", "analytics_dashboards", "has_enriched_analytics", + "creation_context", ] def get_can_edit(self, feature_flag: FeatureFlag) -> bool: @@ -317,6 +326,9 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo validated_data["created_by"] = request.user validated_data["team_id"] = self.context["team_id"] tags = validated_data.pop("tags", None) # tags are created separately below as global tag relationships + creation_context = validated_data.pop( + "creation_context", "feature_flags" + ) # default to "feature_flags" if an alternative value is not provided self._update_filters(validated_data) @@ -347,7 +359,9 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo _create_usage_dashboard(instance, request.user) - report_user_action(request.user, "feature flag created", instance.get_analytics_metadata()) + analytics_metadata = instance.get_analytics_metadata() + analytics_metadata["creation_context"] = creation_context + report_user_action(request.user, "feature flag created", analytics_metadata) return instance diff --git a/posthog/api/survey.py b/posthog/api/survey.py index 1cd91b881c0..3d13981a867 100644 --- a/posthog/api/survey.py +++ b/posthog/api/survey.py @@ -640,6 +640,7 @@ class SurveySerializerCreateUpdateOnly(serializers.ModelSerializer): "name": f"Targeting flag for survey {name}", "filters": filters, "active": active, + "creation_context": "surveys", }, context=self.context, ) diff --git a/posthog/api/test/test_early_access_feature.py b/posthog/api/test/test_early_access_feature.py index 89d1d7369d0..311fbae3cb1 100644 --- a/posthog/api/test/test_early_access_feature.py +++ b/posthog/api/test/test_early_access_feature.py @@ -3,6 +3,7 @@ from unittest.mock import ANY from rest_framework import status from django.core.cache import cache from django.test.client import Client +from unittest.mock import patch from posthog.models.early_access_feature import EarlyAccessFeature from posthog.models import FeatureFlag, Person @@ -520,6 +521,36 @@ class TestEarlyAccessFeature(APIBaseTest): ], } + @patch("posthog.api.feature_flag.report_user_action") + def test_creation_context_is_set_to_early_access_features(self, mock_capture): + response = self.client.post( + f"/api/projects/{self.team.id}/early_access_feature/", + data={ + "name": "Hick bondoogling", + "description": 'Boondoogle your hicks with one click. 
Just click "bazinga"!', + "stage": "concept", + }, + format="json", + ) + response_data = response.json() + ff_instance = FeatureFlag.objects.get(id=response_data["feature_flag"]["id"]) + mock_capture.assert_called_once_with( + ANY, + "feature flag created", + { + "groups_count": 1, + "has_variants": False, + "variants_count": 0, + "has_rollout_percentage": False, + "has_filters": False, + "filter_count": 0, + "created_at": ff_instance.created_at, + "aggregating_by_groups": False, + "payload_count": 0, + "creation_context": "early_access_features", + }, + ) + class TestPreviewList(BaseTest, QueryMatchingTest): def setUp(self): diff --git a/posthog/api/test/test_feature_flag.py b/posthog/api/test/test_feature_flag.py index b5a2cfd6d18..2d4745313b9 100644 --- a/posthog/api/test/test_feature_flag.py +++ b/posthog/api/test/test_feature_flag.py @@ -300,6 +300,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin): "created_at": instance.created_at, "aggregating_by_groups": True, "payload_count": 0, + "creation_context": "feature_flags", }, ) @@ -334,6 +335,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin): "created_at": instance.created_at, "aggregating_by_groups": False, "payload_count": 0, + "creation_context": "feature_flags", }, ) @@ -385,6 +387,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin): "created_at": instance.created_at, "aggregating_by_groups": False, "payload_count": 0, + "creation_context": "feature_flags", }, ) @@ -438,6 +441,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin): "created_at": instance.created_at, "aggregating_by_groups": False, "payload_count": 0, + "creation_context": "feature_flags", }, ) diff --git a/posthog/api/test/test_survey.py b/posthog/api/test/test_survey.py index cb124c9b970..ee1cc97a696 100644 --- a/posthog/api/test/test_survey.py +++ b/posthog/api/test/test_survey.py @@ -60,6 +60,59 @@ class TestSurvey(APIBaseTest): ] assert response_data["created_by"]["id"] == self.user.id + @patch("posthog.api.feature_flag.report_user_action") + def test_creation_context_is_set_to_surveys(self, mock_capture): + response = self.client.post( + f"/api/projects/{self.team.id}/surveys/", + data={ + "name": "survey with targeting", + "type": "popover", + "targeting_flag_filters": { + "groups": [ + { + "variant": None, + "rollout_percentage": None, + "properties": [ + { + "key": "billing_plan", + "value": ["cloud"], + "operator": "exact", + "type": "person", + } + ], + } + ] + }, + "conditions": {"url": "https://app.posthog.com/notebooks"}, + }, + format="json", + ) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + response_data = response.json() + + # Ensure that a FeatureFlag has been created + ff_instance = FeatureFlag.objects.get(id=response_data["internal_targeting_flag"]["id"]) + self.assertIsNotNone(ff_instance) + + # Verify that report_user_action was called for the feature flag creation + mock_capture.assert_any_call( + ANY, + "feature flag created", + { + "groups_count": 1, + "has_variants": False, + "variants_count": 0, + "has_rollout_percentage": True, + "has_filters": True, + "filter_count": 2, + "created_at": ff_instance.created_at, + "aggregating_by_groups": False, + "payload_count": 0, + "creation_context": "surveys", + }, + ) + def test_create_adds_user_interactivity_filters(self): response = self.client.post( f"/api/projects/{self.team.id}/surveys/", diff --git a/posthog/api/test/test_web_experiment.py b/posthog/api/test/test_web_experiment.py index 679df4411c7..7b53e2ce0fa 100644 --- 
a/posthog/api/test/test_web_experiment.py +++ b/posthog/api/test/test_web_experiment.py @@ -1,6 +1,7 @@ from datetime import datetime, timedelta from rest_framework import status +from unittest.mock import ANY, patch from posthog.models import WebExperiment from posthog.test.base import APIBaseTest @@ -30,7 +31,8 @@ class TestWebExperiment(APIBaseTest): format="json", ) - def test_can_create_basic_web_experiment(self): + @patch("posthog.api.feature_flag.report_user_action") + def test_can_create_basic_web_experiment(self, mock_capture): response = self._create_web_experiment() response_data = response.json() assert response.status_code == status.HTTP_201_CREATED, response_data @@ -53,6 +55,22 @@ class TestWebExperiment(APIBaseTest): assert web_experiment.type == "web" assert web_experiment.variants.get("control") is not None assert web_experiment.variants.get("test") is not None + mock_capture.assert_called_once_with( + ANY, + "feature flag created", + { + "groups_count": 1, + "has_variants": True, + "variants_count": 2, + "has_rollout_percentage": True, + "has_filters": False, + "filter_count": 0, + "created_at": linked_flag.created_at, + "aggregating_by_groups": False, + "payload_count": 0, + "creation_context": "web_experiments", + }, + ) def test_can_list_active_web_experiments(self): response = self._create_web_experiment("active_web_experiment") diff --git a/posthog/api/web_experiment.py b/posthog/api/web_experiment.py index 81aae23f2da..d90d400404d 100644 --- a/posthog/api/web_experiment.py +++ b/posthog/api/web_experiment.py @@ -98,6 +98,7 @@ class WebExperimentsAPISerializer(serializers.ModelSerializer): "name": f'Feature Flag for Experiment {validated_data["name"]}', "filters": filters, "active": False, + "creation_context": "web_experiments", }, context=self.context, ) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index b99943cc4e5..9a263a87878 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1808,7 +1808,9 @@ dependencies = [ "futures", "health", "maxminddb", + "moka", "once_cell", + "petgraph", "rand", "redis", "regex", @@ -3046,9 +3048,13 @@ version = "0.12.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32cf62eb4dd975d2dde76432fb1075c49e3ee2331cf36f1f8fd4b66550d32b6f" dependencies = [ + "async-lock 3.4.0", + "async-trait", "crossbeam-channel", "crossbeam-epoch", "crossbeam-utils", + "event-listener 5.3.1", + "futures-util", "once_cell", "parking_lot", "quanta 0.12.2", diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 4cf4016767b..4099fd8ab06 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -39,6 +39,8 @@ health = { path = "../common/health" } common-metrics = { path = "../common/metrics" } tower = { workspace = true } derive_builder = "0.20.1" +petgraph = "0.6.5" +moka = { version = "0.12.8", features = ["future"] } [lints] workspace = true diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index 4430476d28a..be21c1c37f5 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -89,7 +89,7 @@ pub enum FlagError { #[error("Row not found in postgres")] RowNotFound, #[error("failed to parse redis cache data")] - DataParsingError, + RedisDataParsingError, #[error("failed to update redis cache")] CacheUpdateError, #[error("redis unavailable")] @@ -102,6 +102,14 @@ pub enum FlagError { TimeoutError, #[error("No group type mappings")] NoGroupTypeMappings, + #[error("Cohort not found")] + CohortNotFound(String), + #[error("Failed to 
parse cohort filters")] + CohortFiltersParsingError, + #[error("Cohort dependency cycle")] + CohortDependencyCycle(String), + #[error("Person not found")] + PersonNotFound, } impl IntoResponse for FlagError { @@ -138,7 +146,7 @@ impl IntoResponse for FlagError { FlagError::TokenValidationError => { (StatusCode::UNAUTHORIZED, "The provided API key is invalid or has expired. Please check your API key and try again.".to_string()) } - FlagError::DataParsingError => { + FlagError::RedisDataParsingError => { tracing::error!("Data parsing error: {:?}", self); ( StatusCode::SERVICE_UNAVAILABLE, @@ -194,6 +202,21 @@ impl IntoResponse for FlagError { "The requested row was not found in the database. Please try again later or contact support if the problem persists.".to_string(), ) } + FlagError::CohortNotFound(msg) => { + tracing::error!("Cohort not found: {}", msg); + (StatusCode::NOT_FOUND, msg) + } + FlagError::CohortFiltersParsingError => { + tracing::error!("Failed to parse cohort filters: {:?}", self); + (StatusCode::BAD_REQUEST, "Failed to parse cohort filters. Please try again later or contact support if the problem persists.".to_string()) + } + FlagError::CohortDependencyCycle(msg) => { + tracing::error!("Cohort dependency cycle: {}", msg); + (StatusCode::BAD_REQUEST, msg) + } + FlagError::PersonNotFound => { + (StatusCode::BAD_REQUEST, "Person not found. Please check your distinct_id and try again.".to_string()) + } } .into_response() } @@ -205,7 +228,7 @@ impl From for FlagError { CustomRedisError::NotFound => FlagError::TokenValidationError, CustomRedisError::PickleError(e) => { tracing::error!("failed to fetch data: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError } CustomRedisError::Timeout(_) => FlagError::TimeoutError, CustomRedisError::Other(e) => { diff --git a/rust/feature-flags/src/cohort_cache.rs b/rust/feature-flags/src/cohort_cache.rs new file mode 100644 index 00000000000..68894c19f88 --- /dev/null +++ b/rust/feature-flags/src/cohort_cache.rs @@ -0,0 +1,221 @@ +use crate::api::FlagError; +use crate::cohort_models::Cohort; +use crate::flag_matching::{PostgresReader, TeamId}; +use moka::future::Cache; +use std::time::Duration; + +/// CohortCacheManager manages the in-memory cache of cohorts using `moka` for caching. +/// +/// Features: +/// - **TTL**: Each cache entry expires after 5 minutes. +/// - **Size-based eviction**: The cache evicts least recently used entries when the maximum capacity is reached. 
+///
+/// ```text
+/// CohortCacheManager {
+///     postgres_reader: PostgresReader,
+///     per_team_cohorts: Cache<TeamId, Vec<Cohort>> {
+///         // Example:
+///         2: [
+///             Cohort { id: 1, name: "Power Users", filters: {...} },
+///             Cohort { id: 2, name: "Churned", filters: {...} }
+///         ],
+///         5: [
+///             Cohort { id: 3, name: "Beta Users", filters: {...} }
+///         ]
+///     }
+/// }
+/// ```
+///
+#[derive(Clone)]
+pub struct CohortCacheManager {
+    postgres_reader: PostgresReader,
+    per_team_cohort_cache: Cache<TeamId, Vec<Cohort>>,
+}
+
+impl CohortCacheManager {
+    pub fn new(
+        postgres_reader: PostgresReader,
+        max_capacity: Option<u64>,
+        ttl_seconds: Option<u64>,
+    ) -> Self {
+        // We use the size of the cohort list (i.e., the number of cohorts for a given team) as the weight of the entry
+        let weigher =
+            |_: &TeamId, value: &Vec<Cohort>| -> u32 { value.len().try_into().unwrap_or(u32::MAX) };
+
+        let cache = Cache::builder()
+            .time_to_live(Duration::from_secs(ttl_seconds.unwrap_or(300))) // Default to 5 minutes
+            .weigher(weigher)
+            .max_capacity(max_capacity.unwrap_or(10_000)) // Default to 10,000 cohorts
+            .build();
+
+        Self {
+            postgres_reader,
+            per_team_cohort_cache: cache,
+        }
+    }
+
+    /// Retrieves cohorts for a given team.
+    ///
+    /// If the cohorts are not present in the cache or have expired, it fetches them from the database,
+    /// caches the result upon successful retrieval, and then returns it.
+    pub async fn get_cohorts_for_team(&self, team_id: TeamId) -> Result<Vec<Cohort>, FlagError> {
+        if let Some(cached_cohorts) = self.per_team_cohort_cache.get(&team_id).await {
+            return Ok(cached_cohorts.clone());
+        }
+        let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?;
+        self.per_team_cohort_cache
+            .insert(team_id, fetched_cohorts.clone())
+            .await;
+
+        Ok(fetched_cohorts)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::cohort_models::Cohort;
+    use crate::database::Client;
+    use crate::test_utils::{
+        insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client,
+        setup_pg_writer_client,
+    };
+    use std::sync::Arc;
+    use tokio::time::{sleep, Duration};
+
+    /// Helper function to set up a new team for testing.
+    async fn setup_test_team(
+        writer_client: Arc<dyn Client + Send + Sync>,
+    ) -> Result<TeamId, anyhow::Error> {
+        let team = crate::test_utils::insert_new_team_in_pg(writer_client, None).await?;
+        Ok(team.id)
+    }
+
+    /// Helper function to insert a cohort for a team.
+    async fn setup_test_cohort(
+        writer_client: Arc<dyn Client + Send + Sync>,
+        team_id: TeamId,
+        name: Option<String>,
+    ) -> Result<Cohort, anyhow::Error> {
+        let filters = serde_json::json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}});
+        insert_cohort_for_team_in_pg(writer_client, team_id, name, filters, false).await
+    }
+
+    /// Tests that cache entries expire after the specified TTL.
+ #[tokio::test] + async fn test_cache_expiry() -> Result<(), anyhow::Error> { + let writer_client = setup_pg_writer_client(None).await; + let reader_client = setup_pg_reader_client(None).await; + + let team_id = setup_test_team(writer_client.clone()).await?; + let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?; + + // Initialize CohortCacheManager with a short TTL for testing + let cohort_cache = CohortCacheManager::new( + reader_client.clone(), + Some(100), + Some(1), // 1-second TTL + ); + + let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?; + assert_eq!(cohorts.len(), 1); + assert_eq!(cohorts[0].team_id, team_id); + + let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await; + assert!(cached_cohorts.is_some()); + + // Wait for TTL to expire + sleep(Duration::from_secs(2)).await; + + // Attempt to retrieve from cache again + let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await; + assert!(cached_cohorts.is_none(), "Cache entry should have expired"); + + Ok(()) + } + + /// Tests that the cache correctly evicts least recently used entries based on the weigher. + #[tokio::test] + async fn test_cache_weigher() -> Result<(), anyhow::Error> { + let writer_client = setup_pg_writer_client(None).await; + let reader_client = setup_pg_reader_client(None).await; + + // Define a smaller max_capacity for testing + let max_capacity: u64 = 3; + + let cohort_cache = CohortCacheManager::new(reader_client.clone(), Some(max_capacity), None); + + let mut inserted_team_ids = Vec::new(); + + // Insert multiple teams and their cohorts + for _ in 0..max_capacity { + let team = insert_new_team_in_pg(writer_client.clone(), None).await?; + let team_id = team.id; + inserted_team_ids.push(team_id); + setup_test_cohort(writer_client.clone(), team_id, None).await?; + cohort_cache.get_cohorts_for_team(team_id).await?; + } + + cohort_cache.per_team_cohort_cache.run_pending_tasks().await; + let cache_size = cohort_cache.per_team_cohort_cache.entry_count(); + assert_eq!( + cache_size, max_capacity, + "Cache size should be equal to max_capacity" + ); + + let new_team = insert_new_team_in_pg(writer_client.clone(), None).await?; + let new_team_id = new_team.id; + setup_test_cohort(writer_client.clone(), new_team_id, None).await?; + cohort_cache.get_cohorts_for_team(new_team_id).await?; + + cohort_cache.per_team_cohort_cache.run_pending_tasks().await; + let cache_size_after = cohort_cache.per_team_cohort_cache.entry_count(); + assert_eq!( + cache_size_after, max_capacity, + "Cache size should remain equal to max_capacity after eviction" + ); + + let evicted_team_id = &inserted_team_ids[0]; + let cached_cohorts = cohort_cache + .per_team_cohort_cache + .get(evicted_team_id) + .await; + assert!( + cached_cohorts.is_none(), + "Least recently used cache entry should have been evicted" + ); + + let cached_new_team = cohort_cache.per_team_cohort_cache.get(&new_team_id).await; + assert!( + cached_new_team.is_some(), + "Newly added cache entry should be present" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_get_cohorts_for_team() -> Result<(), anyhow::Error> { + let writer_client = setup_pg_writer_client(None).await; + let reader_client = setup_pg_reader_client(None).await; + let team_id = setup_test_team(writer_client.clone()).await?; + let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?; + let cohort_cache = CohortCacheManager::new(reader_client.clone(), None, None); + + let cached_cohorts = 
cohort_cache.per_team_cohort_cache.get(&team_id).await;
+        assert!(cached_cohorts.is_none(), "Cache should initially be empty");
+
+        let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?;
+        assert_eq!(cohorts.len(), 1);
+        assert_eq!(cohorts[0].team_id, team_id);
+
+        let cached_cohorts = cohort_cache
+            .per_team_cohort_cache
+            .get(&team_id)
+            .await
+            .unwrap();
+        assert_eq!(cached_cohorts.len(), 1);
+        assert_eq!(cached_cohorts[0].team_id, team_id);
+
+        Ok(())
+    }
+}
diff --git a/rust/feature-flags/src/cohort_models.rs b/rust/feature-flags/src/cohort_models.rs
new file mode 100644
index 00000000000..d1099839017
--- /dev/null
+++ b/rust/feature-flags/src/cohort_models.rs
@@ -0,0 +1,50 @@
+use crate::flag_definitions::PropertyFilter;
+use serde::{Deserialize, Serialize};
+use sqlx::FromRow;
+
+#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
+pub struct Cohort {
+    pub id: i32,
+    pub name: String,
+    pub description: Option<String>,
+    pub team_id: i32,
+    pub deleted: bool,
+    pub filters: serde_json::Value,
+    pub query: Option<serde_json::Value>,
+    pub version: Option<i32>,
+    pub pending_version: Option<i32>,
+    pub count: Option<i32>,
+    pub is_calculating: bool,
+    pub is_static: bool,
+    pub errors_calculating: i32,
+    pub groups: serde_json::Value,
+    pub created_by_id: Option<i32>,
+}
+
+pub type CohortId = i32;
+
+#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
+#[serde(rename_all = "UPPERCASE")]
+pub enum CohortPropertyType {
+    AND,
+    OR,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct CohortProperty {
+    pub properties: InnerCohortProperty,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct InnerCohortProperty {
+    #[serde(rename = "type")]
+    pub prop_type: CohortPropertyType,
+    pub values: Vec<CohortValues>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct CohortValues {
+    #[serde(rename = "type")]
+    pub prop_type: String,
+    pub values: Vec<PropertyFilter>,
+}
diff --git a/rust/feature-flags/src/cohort_operations.rs b/rust/feature-flags/src/cohort_operations.rs
new file mode 100644
index 00000000000..ea4214ccdc0
--- /dev/null
+++ b/rust/feature-flags/src/cohort_operations.rs
@@ -0,0 +1,369 @@
+use std::collections::HashSet;
+use std::sync::Arc;
+use tracing::instrument;
+
+use crate::cohort_models::{Cohort, CohortId, CohortProperty, InnerCohortProperty};
+use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter};
+
+impl Cohort {
+    /// Returns a cohort from postgres given a cohort_id and team_id
+    #[instrument(skip_all)]
+    pub async fn from_pg(
+        client: Arc<dyn DatabaseClient + Send + Sync>,
+        cohort_id: i32,
+        team_id: i32,
+    ) -> Result<Cohort, FlagError> {
+        let mut conn = client.get_connection().await.map_err(|e| {
+            tracing::error!("Failed to get database connection: {}", e);
+            // TODO should I model my errors more generally? Like, yes, everything behind this API is technically a FlagError,
+            // but I'm not sure if accessing Cohort definitions should be a FlagError (vs idk, a CohortError? A more general API error?)
+            FlagError::DatabaseUnavailable
+        })?;
+
+        let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE id = $1 AND team_id = $2";
+        let cohort = sqlx::query_as::<_, Cohort>(query)
+            .bind(cohort_id)
+            .bind(team_id)
+            .fetch_optional(&mut *conn)
+            .await
+            .map_err(|e| {
+                tracing::error!("Failed to fetch cohort from database: {}", e);
+                FlagError::Internal(format!("Database query error: {}", e))
+            })?;
+
+        cohort.ok_or_else(|| {
+            FlagError::CohortNotFound(format!(
+                "Cohort with id {} not found for team {}",
+                cohort_id, team_id
+            ))
+        })
+    }
+
+    /// Returns all cohorts for a given team
+    #[instrument(skip_all)]
+    pub async fn list_from_pg(
+        client: Arc<dyn DatabaseClient + Send + Sync>,
+        team_id: i32,
+    ) -> Result<Vec<Cohort>, FlagError> {
+        let mut conn = client.get_connection().await.map_err(|e| {
+            tracing::error!("Failed to get database connection: {}", e);
+            FlagError::DatabaseUnavailable
+        })?;
+
+        let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE team_id = $1";
+        let cohorts = sqlx::query_as::<_, Cohort>(query)
+            .bind(team_id)
+            .fetch_all(&mut *conn)
+            .await
+            .map_err(|e| {
+                tracing::error!("Failed to fetch cohorts from database: {}", e);
+                FlagError::Internal(format!("Database query error: {}", e))
+            })?;
+
+        Ok(cohorts)
+    }
+
+    /// Parses the filters JSON into a CohortProperty structure
+    // TODO: this doesn't handle the deprecated "groups" field, see
+    // https://github.com/PostHog/posthog/blob/feat/dynamic-cohorts-rust/posthog/models/cohort/cohort.py#L114-L169
+    // I'll handle that in a separate PR.
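+    // NB: cohort-membership filters themselves (key == "id", type == "cohort") are
+    // dropped by this method rather than returned, since they are evaluated separately
+    // via the cohort dependency graph in flag_matching.rs.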
+    pub fn parse_filters(&self) -> Result<Vec<PropertyFilter>, FlagError> {
+        let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone())
+            .map_err(|e| {
+                tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e);
+                FlagError::CohortFiltersParsingError
+            })?;
+        Ok(cohort_property
+            .properties
+            .to_property_filters()
+            .into_iter()
+            .filter(|f| !(f.key == "id" && f.prop_type == "cohort"))
+            .collect())
+    }
+
+    /// Extracts dependent CohortIds from the cohort's filters
+    pub fn extract_dependencies(&self) -> Result<HashSet<CohortId>, FlagError> {
+        let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone())
+            .map_err(|e| {
+                tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e);
+                FlagError::CohortFiltersParsingError
+            })?;
+
+        let mut dependencies = HashSet::new();
+        Self::traverse_filters(&cohort_property.properties, &mut dependencies)?;
+        Ok(dependencies)
+    }
+
+    /// Recursively traverses the filter tree to find cohort dependencies
+    ///
+    /// Example filter tree structure:
+    /// ```json
+    /// {
+    ///   "properties": {
+    ///     "type": "OR",
+    ///     "values": [
+    ///       {
+    ///         "type": "OR",
+    ///         "values": [
+    ///           {
+    ///             "key": "id",
+    ///             "value": 123,
+    ///             "type": "cohort",
+    ///             "operator": "exact"
+    ///           },
+    ///           {
+    ///             "key": "email",
+    ///             "value": "@posthog.com",
+    ///             "type": "person",
+    ///             "operator": "icontains"
+    ///           }
+    ///         ]
+    ///       }
+    ///     ]
+    ///   }
+    /// }
+    /// ```
+    fn traverse_filters(
+        inner: &InnerCohortProperty,
+        dependencies: &mut HashSet<CohortId>,
+    ) -> Result<(), FlagError> {
+        for cohort_values in &inner.values {
+            for filter in &cohort_values.values {
+                if filter.is_cohort() {
+                    // Assuming the value is a single integer CohortId
+                    if let Some(cohort_id) = filter.value.as_i64() {
+                        dependencies.insert(cohort_id as CohortId);
+                    } else {
+                        return Err(FlagError::CohortFiltersParsingError);
+                    }
+                }
+                // NB: we don't support nested cohort properties, so we don't need to traverse further
+            }
+        }
+        Ok(())
+    }
+}
+
+impl InnerCohortProperty {
+    /// Flattens the nested cohort property structure into a list of property filters.
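+    /// (This simply concatenates the `values` of every inner group; the top-level
+    /// AND/OR combinator is not preserved by the flattening.)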
+    ///
+    /// The cohort property structure in Postgres looks like:
+    /// ```json
+    /// {
+    ///   "type": "OR",
+    ///   "values": [
+    ///     {
+    ///       "type": "OR",
+    ///       "values": [
+    ///         {
+    ///           "key": "email",
+    ///           "value": "@posthog.com",
+    ///           "type": "person",
+    ///           "operator": "icontains"
+    ///         },
+    ///         {
+    ///           "key": "age",
+    ///           "value": 25,
+    ///           "type": "person",
+    ///           "operator": "gt"
+    ///         }
+    ///       ]
+    ///     }
+    ///   ]
+    /// }
+    /// ```
+    pub fn to_property_filters(&self) -> Vec<PropertyFilter> {
+        self.values
+            .iter()
+            .flat_map(|value| &value.values)
+            .cloned()
+            .collect()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        cohort_models::{CohortPropertyType, CohortValues},
+        test_utils::{
+            insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client,
+            setup_pg_writer_client,
+        },
+    };
+    use serde_json::json;
+
+    #[tokio::test]
+    async fn test_cohort_from_pg() {
+        let postgres_reader = setup_pg_reader_client(None).await;
+        let postgres_writer = setup_pg_writer_client(None).await;
+
+        let team = insert_new_team_in_pg(postgres_reader.clone(), None)
+            .await
+            .expect("Failed to insert team");
+
+        let cohort = insert_cohort_for_team_in_pg(
+            postgres_writer.clone(),
+            team.id,
+            None,
+            json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$initial_browser_version", "type": "person", "value": ["125"], "negation": false, "operator": "exact"}]}]}}),
+            false,
+        )
+        .await
+        .expect("Failed to insert cohort");
+
+        let fetched_cohort = Cohort::from_pg(postgres_reader, cohort.id, team.id)
+            .await
+            .expect("Failed to fetch cohort");
+
+        assert_eq!(fetched_cohort.id, cohort.id);
+        assert_eq!(fetched_cohort.name, "Test Cohort");
+        assert_eq!(fetched_cohort.team_id, team.id);
+    }
+
+    #[tokio::test]
+    async fn test_list_from_pg() {
+        let postgres_reader = setup_pg_reader_client(None).await;
+        let postgres_writer = setup_pg_writer_client(None).await;
+
+        let team = insert_new_team_in_pg(postgres_reader.clone(), None)
+            .await
+            .expect("Failed to insert team");
+
+        // Insert multiple cohorts for the team
+        insert_cohort_for_team_in_pg(
+            postgres_writer.clone(),
+            team.id,
+            Some("Cohort 1".to_string()),
+            json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "age", "type": "person", "value": [30], "negation": false, "operator": "gt"}]}]}}),
+            false,
+        )
+        .await
+        .expect("Failed to insert cohort1");
+
+        insert_cohort_for_team_in_pg(
+            postgres_writer.clone(),
+            team.id,
+            Some("Cohort 2".to_string()),
+            json!({"properties": {"type": "OR", "values": [{"type": "property", "values": [{"key": "country", "type": "person", "value": ["USA"], "negation": false, "operator": "exact"}]}]}}),
+            false,
+        )
+        .await
+        .expect("Failed to insert cohort2");
+
+        let cohorts = Cohort::list_from_pg(postgres_reader, team.id)
+            .await
+            .expect("Failed to list cohorts");
+
+        assert_eq!(cohorts.len(), 2);
+        let names: HashSet<String> = cohorts.into_iter().map(|c| c.name).collect();
+        assert!(names.contains("Cohort 1"));
+        assert!(names.contains("Cohort 2"));
+    }
+
+    #[test]
+    fn test_cohort_parse_filters() {
+        let cohort = Cohort {
+            id: 1,
+            name: "Test Cohort".to_string(),
+            description: None,
+            team_id: 1,
+            deleted: false,
+            filters: json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$initial_browser_version", "type": "person", "value": ["125"], "negation": false, "operator": "exact"}]}]}}),
+            query: None,
+            version: None,
+            pending_version: None,
+            count: None,
+            is_calculating: false,
+            is_static: false,
+            errors_calculating: 0,
+            groups: json!({}),
+            created_by_id: None,
+        };
+
+        let result = cohort.parse_filters().unwrap();
+        assert_eq!(result.len(), 1);
+        assert_eq!(result[0].key, "$initial_browser_version");
+        assert_eq!(result[0].value, json!(["125"]));
+        assert_eq!(result[0].prop_type, "person");
+    }
+
+    #[test]
+    fn test_cohort_property_to_property_filters() {
+        let cohort_property = InnerCohortProperty {
+            prop_type: CohortPropertyType::AND,
+            values: vec![CohortValues {
+                prop_type: "property".to_string(),
+                values: vec![
+                    PropertyFilter {
+                        key: "email".to_string(),
+                        value: json!("test@example.com"),
+                        operator: None,
+                        prop_type: "person".to_string(),
+                        group_type_index: None,
+                        negation: None,
+                    },
+                    PropertyFilter {
+                        key: "age".to_string(),
+                        value: json!(25),
+                        operator: None,
+                        prop_type: "person".to_string(),
+                        group_type_index: None,
+                        negation: None,
+                    },
+                ],
+            }],
+        };
+
+        let result = cohort_property.to_property_filters();
+        assert_eq!(result.len(), 2);
+        assert_eq!(result[0].key, "email");
+        assert_eq!(result[0].value, json!("test@example.com"));
+        assert_eq!(result[1].key, "age");
+        assert_eq!(result[1].value, json!(25));
+    }
+
+    #[tokio::test]
+    async fn test_extract_dependencies() {
+        let postgres_reader = setup_pg_reader_client(None).await;
+        let postgres_writer = setup_pg_writer_client(None).await;
+
+        let team = insert_new_team_in_pg(postgres_reader.clone(), None)
+            .await
+            .expect("Failed to insert team");
+
+        // Insert a single cohort that is dependent on another cohort
+        let dependent_cohort = insert_cohort_for_team_in_pg(
+            postgres_writer.clone(),
+            team.id,
+            Some("Dependent Cohort".to_string()),
+            json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$browser", "type": "person", "value": ["Safari"], "negation": false, "operator": "exact"}]}]}}),
+            false,
+        )
+        .await
+        .expect("Failed to insert dependent_cohort");
+
+        // Insert main cohort with a single dependency
+        let main_cohort = insert_cohort_for_team_in_pg(
+            postgres_writer.clone(),
+            team.id,
+            Some("Main Cohort".to_string()),
+            json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "id", "type": "cohort", "value": dependent_cohort.id, "negation": false}]}]}}),
+            false,
+        )
+        .await
+        .expect("Failed to insert main_cohort");
+
+        let fetched_main_cohort = Cohort::from_pg(postgres_reader.clone(), main_cohort.id, team.id)
+            .await
+            .expect("Failed to fetch main cohort");
+
+        println!("fetched_main_cohort: {:?}", fetched_main_cohort);
+
+        let dependencies = fetched_main_cohort.extract_dependencies().unwrap();
+        let expected_dependencies: HashSet<CohortId> =
+            [dependent_cohort.id].iter().cloned().collect();
+
+        assert_eq!(dependencies, expected_dependencies);
+    }
+}
diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs
index baebaa04da3..d62ecc9e0e0 100644
--- a/rust/feature-flags/src/flag_definitions.rs
+++ b/rust/feature-flags/src/flag_definitions.rs
@@ -1,4 +1,7 @@
-use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient};
+use crate::{
+    api::FlagError, cohort_models::CohortId, database::Client as DatabaseClient,
+    redis::Client as RedisClient,
+};
 use serde::{Deserialize, Serialize};
 use std::sync::Arc;
 use tracing::instrument;
@@ -7,7 +10,7 @@
 // TODO: Add integration tests across repos to ensure this doesn't happen.
 pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_";
 
-#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
 #[serde(rename_all = "snake_case")]
 pub enum OperatorType {
     Exact,
@@ -25,6 +28,8 @@ pub enum OperatorType {
     IsDateExact,
     IsDateAfter,
     IsDateBefore,
+    In,
+    NotIn,
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -36,10 +41,28 @@ pub struct PropertyFilter {
     pub value: serde_json::Value,
     pub operator: Option<OperatorType>,
     #[serde(rename = "type")]
+    // TODO: worth making an enum here to differentiate between cohort and person filters?
     pub prop_type: String,
+    pub negation: Option<bool>,
     pub group_type_index: Option<i32>,
 }
 
+impl PropertyFilter {
+    /// Checks if the filter is a cohort filter
+    pub fn is_cohort(&self) -> bool {
+        self.key == "id" && self.prop_type == "cohort"
+    }
+
+    /// Returns the cohort id if the filter is a cohort filter, or None if it's not a cohort filter
+    /// or if the value cannot be parsed as a cohort id
+    pub fn get_cohort_id(&self) -> Option<CohortId> {
+        if !self.is_cohort() {
+            return None;
+        }
+        self.value.as_i64().map(|id| id as CohortId)
+    }
+}
+
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct FlagGroupType {
     pub properties: Option<Vec<PropertyFilter>>,
@@ -68,6 +91,9 @@ pub struct FlagFilters {
     pub super_groups: Option<Vec<FlagGroupType>>,
 }
 
+// TODO: see if you can combine these two structs, like we do with cohort models
+// this will require not deserializing on read and instead doing it lazily, on-demand
+// (which, tbh, is probably a better idea)
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct FeatureFlag {
     pub id: i32,
@@ -142,7 +168,7 @@ impl FeatureFlagList {
             tracing::error!("failed to parse data to flags list: {}", e);
             println!("failed to parse data: {}", e);
 
-            FlagError::DataParsingError
+            FlagError::RedisDataParsingError
         })?;
 
         Ok(FeatureFlagList { flags: flags_list })
@@ -174,7 +200,7 @@ impl FeatureFlagList {
             .map(|row| {
                 let filters = serde_json::from_value(row.filters).map_err(|e| {
                     tracing::error!("Failed to deserialize filters for flag {}: {}", row.key, e);
-                    FlagError::DataParsingError
+                    FlagError::RedisDataParsingError
                 })?;
 
                 Ok(FeatureFlag {
@@ -200,7 +226,7 @@ impl FeatureFlagList {
     ) -> Result<(), FlagError> {
         let payload = serde_json::to_string(&flags.flags).map_err(|e| {
             tracing::error!("Failed to serialize flags: {}", e);
-            FlagError::DataParsingError
+            FlagError::RedisDataParsingError
         })?;
 
         client
@@ -1095,7 +1121,7 @@ mod tests {
             .expect("Failed to set malformed JSON in Redis");
 
         let result = FeatureFlagList::from_redis(redis_client, team.id).await;
-        assert!(matches!(result, Err(FlagError::DataParsingError)));
+        assert!(matches!(result, Err(FlagError::RedisDataParsingError)));
 
         // Test database query error (using a non-existent table)
         let result = sqlx::query("SELECT * FROM non_existent_table")
diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs
index bdcd542f098..d9332fce4e4 100644
--- a/rust/feature-flags/src/flag_matching.rs
+++ b/rust/feature-flags/src/flag_matching.rs
@@ -1,30 +1,35 @@
 use crate::{
     api::{FlagError, FlagValue, FlagsResponse},
+    cohort_cache::CohortCacheManager,
+    cohort_models::{Cohort, CohortId},
     database::Client as DatabaseClient,
     feature_flag_match_reason::FeatureFlagMatchReason,
-    flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, PropertyFilter},
+    flag_definitions::{FeatureFlag, FeatureFlagList, FlagGroupType, OperatorType, PropertyFilter},
     metrics_consts::{FLAG_EVALUATION_ERROR_COUNTER, FLAG_HASH_KEY_WRITES_COUNTER},
+    metrics_utils::parse_exception_for_prometheus_label,
     property_matching::match_property,
-    utils::parse_exception_for_prometheus_label,
 };
 use anyhow::Result;
 use common_metrics::inc;
+use petgraph::algo::{is_cyclic_directed, toposort};
+use petgraph::graph::DiGraph;
 use serde_json::Value;
 use sha1::{Digest, Sha1};
-use sqlx::{postgres::PgQueryResult, Acquire, FromRow};
+use sqlx::{postgres::PgQueryResult, Acquire, FromRow, Row};
 use std::fmt::Write;
 use std::sync::Arc;
 use std::{
-    collections::{HashMap, HashSet},
+    collections::{HashMap, HashSet, VecDeque},
     time::Duration,
 };
 use tokio::time::{sleep, timeout};
 use tracing::{error, info};
 
-type TeamId = i32;
-type GroupTypeIndex = i32;
-type PostgresReader = Arc<dyn DatabaseClient + Send + Sync>;
-type PostgresWriter = Arc<dyn DatabaseClient + Send + Sync>;
+pub type TeamId = i32;
+pub type PersonId = i32;
+pub type GroupTypeIndex = i32;
+pub type PostgresReader = Arc<dyn DatabaseClient + Send + Sync>;
+pub type PostgresWriter = Arc<dyn DatabaseClient + Send + Sync>;
 
 #[derive(Debug)]
 struct SuperConditionEvaluation {
@@ -172,6 +177,7 @@ impl GroupTypeMappingCache {
 /// to fetch the properties from the DB each time.
 #[derive(Clone, Default, Debug)]
 pub struct PropertiesCache {
+    person_id: Option<PersonId>,
     person_properties: Option<HashMap<String, Value>>,
     group_properties: HashMap<GroupTypeIndex, HashMap<String, Value>>,
 }
@@ -182,6 +188,7 @@ pub struct FeatureFlagMatcher {
     pub team_id: TeamId,
     pub postgres_reader: PostgresReader,
     pub postgres_writer: PostgresWriter,
+    pub cohort_cache: Arc<CohortCacheManager>,
     group_type_mapping_cache: GroupTypeMappingCache,
     properties_cache: PropertiesCache,
     groups: HashMap<String, Value>,
@@ -195,8 +202,8 @@ impl FeatureFlagMatcher {
         team_id: TeamId,
         postgres_reader: PostgresReader,
         postgres_writer: PostgresWriter,
+        cohort_cache: Arc<CohortCacheManager>,
         group_type_mapping_cache: Option<GroupTypeMappingCache>,
-        properties_cache: Option<PropertiesCache>,
         groups: Option<HashMap<String, Value>>,
     ) -> Self {
         FeatureFlagMatcher {
@@ -204,16 +211,26 @@ impl FeatureFlagMatcher {
             team_id,
             postgres_reader: postgres_reader.clone(),
             postgres_writer: postgres_writer.clone(),
+            cohort_cache,
             group_type_mapping_cache: group_type_mapping_cache
                 .unwrap_or_else(|| GroupTypeMappingCache::new(team_id, postgres_reader.clone())),
-            properties_cache: properties_cache.unwrap_or_default(),
             groups: groups.unwrap_or_default(),
+            properties_cache: PropertiesCache::default(),
         }
     }
 
-    /// Evaluate feature flags for a given distinct_id
-    /// - Returns a map of feature flag keys to their values
-    /// - If an error occurs while evaluating a flag, it will be logged and the flag will be omitted from the result
+    /// Evaluates all feature flags for the current matcher context.
+    ///
+    /// ## Arguments
+    ///
+    /// * `feature_flags` - The list of feature flags to evaluate.
+    /// * `person_property_overrides` - Any overrides for person properties.
+    /// * `group_property_overrides` - Any overrides for group properties.
+    /// * `hash_key_override` - Optional hash key overrides for experience continuity.
+    ///
+    /// ## Returns
+    ///
+    /// * `FlagsResponse` - The result containing flag evaluations and any errors.
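+    ///
+    /// A hypothetical call, for illustration only (the matcher setup and the
+    /// override map are assumed to exist):
+    ///
+    /// ```ignore
+    /// let response = matcher
+    ///     .evaluate_all_feature_flags(flags, Some(person_overrides), None, None)
+    ///     .await;
+    /// ```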
     pub async fn evaluate_all_feature_flags(
         &mut self,
         feature_flags: FeatureFlagList,
@@ -732,14 +749,38 @@ impl FeatureFlagMatcher {
                 .await;
         }
 
-        // NB: we can only evaluate group or person properties, not both
-        let properties_to_check = self
-            .get_properties_to_check(feature_flag, property_overrides, flag_property_filters)
+        // Separate cohort and non-cohort filters
+        let (cohort_filters, non_cohort_filters): (Vec<PropertyFilter>, Vec<PropertyFilter>) =
+            flag_property_filters
+                .iter()
+                .cloned()
+                .partition(|prop| prop.is_cohort());
+
+        // Get the properties we need to check for in this condition match from the flag + any overrides
+        let person_or_group_properties = self
+            .get_properties_to_check(feature_flag, property_overrides, &non_cohort_filters)
             .await?;
 
-        if !all_properties_match(flag_property_filters, &properties_to_check) {
+        // Evaluate non-cohort filters first, since they're cheaper to evaluate and we can return early if they don't match
+        if !all_properties_match(&non_cohort_filters, &person_or_group_properties) {
             return Ok((false, FeatureFlagMatchReason::NoConditionMatch));
         }
+
+        // Evaluate cohort filters, if any.
+        if !cohort_filters.is_empty() {
+            // Get the person ID for the current distinct ID – this value should be cached at this point, but as a fallback we fetch from the database
+            let person_id = self.get_person_id().await?;
+            if !self
+                .evaluate_cohort_filters(
+                    &cohort_filters,
+                    &person_or_group_properties,
+                    person_id,
+                )
+                .await?
+            {
+                return Ok((false, FeatureFlagMatchReason::NoConditionMatch));
+            }
+        }
     }
 
     self.check_rollout(feature_flag, rollout_percentage, hash_key_overrides)
@@ -786,6 +827,31 @@ impl FeatureFlagMatcher {
         }
     }
 
+    /// Retrieves the `PersonId` from the properties cache.
+    /// If the cache does not contain a `PersonId`, it fetches it from the database
+    /// and updates the cache accordingly.
+    async fn get_person_id(&mut self) -> Result<PersonId, FlagError> {
+        match self.properties_cache.person_id {
+            Some(id) => Ok(id),
+            None => {
+                let id = self.get_person_id_from_db().await?;
+                self.properties_cache.person_id = Some(id);
+                Ok(id)
+            }
+        }
+    }
+
+    /// Fetches the `PersonId` from the database based on the current `distinct_id` and `team_id`.
+    /// This method is called when the `PersonId` is not present in the properties cache.
+    async fn get_person_id_from_db(&mut self) -> Result<PersonId, FlagError> {
+        let postgres_reader = self.postgres_reader.clone();
+        let distinct_id = self.distinct_id.clone();
+        let team_id = self.team_id;
+        fetch_person_properties_from_db(postgres_reader, distinct_id, team_id)
+            .await
+            .map(|(_, person_id)| person_id)
+    }
+
     /// Get person properties from cache or database.
     ///
     /// This function attempts to retrieve person properties either from a cache or directly from the database.
@@ -805,6 +871,56 @@ impl FeatureFlagMatcher {
         }
     }
 
+    /// Evaluates dynamic cohort property filters
+    ///
+    /// NB: This method first caches all of the cohorts associated with the team, which allows us to avoid
+    /// hitting the database for each cohort filter.
+    pub async fn evaluate_cohort_filters(
+        &self,
+        cohort_property_filters: &[PropertyFilter],
+        target_properties: &HashMap<String, Value>,
+        person_id: PersonId,
+    ) -> Result<bool, FlagError> {
+        // At the start of the request, fetch all of the cohorts for the team from the cache.
+        // This method also caches any cohorts for a given team in memory for the duration of the application, so we don't need to fetch from
+        // the database again until we restart the application. See the CohortCacheManager for more details.
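+        // (In practice an entry can also be refreshed earlier: the cohort cache supports an
+        // optional TTL and max-capacity eviction, as exercised by the cache tests above.)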
+        let cohorts = self.cohort_cache.get_cohorts_for_team(self.team_id).await?;
+
+        // Split the cohorts into static and dynamic, since the dynamic ones have property filters
+        // and we need to evaluate them based on the target properties, whereas the static ones are
+        // purely based on person properties and are membership-based.
+        let (static_cohorts, dynamic_cohorts): (Vec<_>, Vec<_>) =
+            cohorts.iter().partition(|c| c.is_static);
+
+        // Store all cohort match results in a HashMap to avoid re-evaluating the same cohort multiple times,
+        // since the same cohort could appear in multiple property filters.
+        let mut cohort_matches = HashMap::new();
+
+        if !static_cohorts.is_empty() {
+            let results = evaluate_static_cohorts(
+                self.postgres_reader.clone(),
+                person_id,
+                static_cohorts.iter().map(|c| c.id).collect(),
+            )
+            .await?;
+            cohort_matches.extend(results);
+        }
+
+        if !dynamic_cohorts.is_empty() {
+            for filter in cohort_property_filters {
+                let cohort_id = filter
+                    .get_cohort_id()
+                    .ok_or(FlagError::CohortFiltersParsingError)?;
+                let match_result =
+                    evaluate_dynamic_cohorts(cohort_id, target_properties, cohorts.clone())?;
+                cohort_matches.insert(cohort_id, match_result);
+            }
+        }
+
+        // Apply cohort membership logic (IN|NOT_IN) to the cohort match results
+        apply_cohort_membership_logic(cohort_property_filters, &cohort_matches)
+    }
+
     /// Check if a super condition matches for a feature flag.
     ///
     /// This function evaluates the super conditions of a feature flag to determine if any of them should be enabled.
@@ -917,11 +1033,12 @@ impl FeatureFlagMatcher {
         let postgres_reader = self.postgres_reader.clone();
         let distinct_id = self.distinct_id.clone();
         let team_id = self.team_id;
-        let db_properties =
+        let (db_properties, person_id) =
             fetch_person_properties_from_db(postgres_reader, distinct_id, team_id).await?;
 
-        // once the properties are fetched, cache them so we don't need to fetch again in a given request
+        // once the properties and person ID are fetched, cache them so we don't need to fetch again in a given request
         self.properties_cache.person_properties = Some(db_properties.clone());
+        self.properties_cache.person_id = Some(person_id);
 
         Ok(db_properties)
     }
@@ -1048,6 +1165,221 @@ impl FeatureFlagMatcher {
     }
 }
 
+/// Evaluate static cohort filters by checking if the person is in each cohort.
+async fn evaluate_static_cohorts(
+    postgres_reader: PostgresReader,
+    person_id: i32,
+    cohort_ids: Vec<CohortId>,
+) -> Result<Vec<(CohortId, bool)>, FlagError> {
+    let mut conn = postgres_reader.get_connection().await?;
+
+    let query = r#"
+        WITH cohort_membership AS (
+            SELECT c.cohort_id,
+                   CASE WHEN pc.cohort_id IS NOT NULL THEN true ELSE false END AS is_member
+            FROM unnest($1::integer[]) AS c(cohort_id)
+            LEFT JOIN posthog_cohortpeople AS pc
+              ON pc.person_id = $2
+             AND pc.cohort_id = c.cohort_id
+        )
+        SELECT cohort_id, is_member
+        FROM cohort_membership
+    "#;
+
+    let rows = sqlx::query(query)
+        .bind(&cohort_ids)
+        .bind(person_id)
+        .fetch_all(&mut *conn)
+        .await?;
+
+    let result = rows
+        .into_iter()
+        .map(|row| {
+            let cohort_id: CohortId = row.get("cohort_id");
+            let is_member: bool = row.get("is_member");
+            (cohort_id, is_member)
+        })
+        .collect();
+
+    Ok(result)
+}
+
+/// Evaluates a dynamic cohort and its dependencies.
+/// This uses a topological sort to evaluate dependencies first, which is necessary
+/// because a cohort can depend on another cohort, and we need to respect the dependency order.
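+/// For example (cohort IDs are illustrative): if cohort 2 filters on membership of
+/// cohort 1, the graph gets the edge 2 -> 1, `toposort` yields [2, 1], and the
+/// reversed iteration below evaluates cohort 1 before cohort 2.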
+fn evaluate_dynamic_cohorts(
+    initial_cohort_id: CohortId,
+    target_properties: &HashMap<String, Value>,
+    cohorts: Vec<Cohort>,
+) -> Result<bool, FlagError> {
+    let cohort_dependency_graph =
+        build_cohort_dependency_graph(initial_cohort_id, cohorts.clone())?;
+
+    // We need to sort cohorts topologically to ensure we evaluate dependencies before the cohorts that depend on them.
+    // For example, if cohort A depends on cohort B, we need to evaluate B first to know if A matches.
+    // This also helps detect cycles - if cohort A depends on B which depends on A, toposort will fail.
+    let sorted_cohort_ids_as_graph_nodes =
+        toposort(&cohort_dependency_graph, None).map_err(|e| {
+            FlagError::CohortDependencyCycle(format!("Cyclic dependency detected: {:?}", e))
+        })?;
+
+    // Store evaluation results for each cohort in a map, so we can look up whether a cohort matched
+    // when evaluating cohorts that depend on it, and also return the final result for the initial cohort
+    let mut evaluation_results = HashMap::new();
+
+    // Iterate through the sorted nodes in reverse order (so that we can evaluate dependencies first)
+    for node in sorted_cohort_ids_as_graph_nodes.into_iter().rev() {
+        let cohort_id = cohort_dependency_graph[node];
+        let cohort = cohorts
+            .iter()
+            .find(|c| c.id == cohort_id)
+            .ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?;
+        let property_filters = cohort.parse_filters()?;
+        let dependencies = cohort.extract_dependencies()?;
+
+        // Check if all dependencies have been met (i.e., previous cohorts matched)
+        let dependencies_met = dependencies
+            .iter()
+            .all(|dep_id| evaluation_results.get(dep_id).copied().unwrap_or(false));
+
+        // If dependencies are not met, mark the current cohort as not matched and continue
+        // NB: We don't want to _exit_ here, since the non-matching cohort could be wrapped in a `not_in` operator
+        // and we want to evaluate all cohorts to determine if the initial cohort matches.
+        if !dependencies_met {
+            evaluation_results.insert(cohort_id, false);
+            continue;
+        }
+
+        // Evaluate all property filters for the current cohort
+        let all_filters_match = property_filters
+            .iter()
+            .all(|filter| match_property(filter, target_properties, false).unwrap_or(false));
+
+        // Store the evaluation result for the current cohort
+        evaluation_results.insert(cohort_id, all_filters_match);
+    }
+
+    // Retrieve and return the evaluation result for the initial cohort
+    evaluation_results
+        .get(&initial_cohort_id)
+        .copied()
+        .ok_or_else(|| FlagError::CohortNotFound(initial_cohort_id.to_string()))
+}
+
+/// Apply cohort membership logic (i.e., IN|NOT_IN)
+fn apply_cohort_membership_logic(
+    cohort_filters: &[PropertyFilter],
+    cohort_matches: &HashMap<CohortId, bool>,
+) -> Result<bool, FlagError> {
+    for filter in cohort_filters {
+        let cohort_id = filter
+            .get_cohort_id()
+            .ok_or(FlagError::CohortFiltersParsingError)?;
+        let matches = cohort_matches.get(&cohort_id).copied().unwrap_or(false);
+        let operator = filter.operator.unwrap_or(OperatorType::In);
+
+        // Combine the operator logic directly within this method
+        let membership_match = match operator {
+            OperatorType::In => matches,
+            OperatorType::NotIn => !matches,
+            // Currently supported operators are IN and NOT IN
+            // Any other operator defaults to false
+            _ => false,
+        };
+
+        // If any filter does not match, return false early
+        if !membership_match {
+            return Ok(false);
+        }
+    }
+    // All filters matched
+    Ok(true)
+}
+
+/// Constructs a dependency graph for cohorts.
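+/// Edges point from a cohort to the cohorts it depends on.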
+///
+/// Example dependency graph:
+/// ```text
+///   A    B
+///   |   /|
+///   |  / |
+///   | /  |
+///   C    D
+///    \  /
+///     \/
+///     E
+/// ```
+/// In this example:
+/// - Cohorts A and B are root nodes (no dependencies)
+/// - C depends on A and B
+/// - D depends on B
+/// - E depends on C and D
+///
+/// The graph is acyclic, which is required for valid cohort dependencies.
+fn build_cohort_dependency_graph(
+    initial_cohort_id: CohortId,
+    cohorts: Vec<Cohort>,
+) -> Result<DiGraph<CohortId, ()>, FlagError> {
+    let mut graph = DiGraph::new();
+    let mut node_map = HashMap::new();
+    let mut queue = VecDeque::new();
+
+    let initial_cohort = cohorts
+        .iter()
+        .find(|c| c.id == initial_cohort_id)
+        .ok_or(FlagError::CohortNotFound(initial_cohort_id.to_string()))?;
+
+    if initial_cohort.is_static {
+        return Ok(graph);
+    }
+
+    // This implements a breadth-first search (BFS) traversal to build a directed graph of cohort dependencies.
+    // Starting from the initial cohort, we:
+    // 1. Add each cohort as a node in the graph
+    // 2. Track visited nodes in a map to avoid duplicates
+    // 3. For each cohort, get its dependencies and add directed edges from the cohort to its dependencies
+    // 4. Queue up any unvisited dependencies to process their dependencies later
+    // This builds up the full dependency graph level by level, which we can later check for cycles
+    queue.push_back(initial_cohort_id);
+    node_map.insert(initial_cohort_id, graph.add_node(initial_cohort_id));
+
+    while let Some(cohort_id) = queue.pop_front() {
+        let cohort = cohorts
+            .iter()
+            .find(|c| c.id == cohort_id)
+            .ok_or(FlagError::CohortNotFound(cohort_id.to_string()))?;
+        let dependencies = cohort.extract_dependencies()?;
+        for dep_id in dependencies {
+            let current_node = node_map[&cohort_id];
+            // Add the dependency node and queue it for traversal only if this is the
+            // first time we've seen this cohort ID; checking *before* inserting into
+            // node_map is what lets the BFS reach transitive dependencies.
+            if !node_map.contains_key(&dep_id) {
+                node_map.insert(dep_id, graph.add_node(dep_id));
+                queue.push_back(dep_id);
+            }
+            let dep_node = node_map[&dep_id];
+
+            graph.add_edge(current_node, dep_node, ());
+        }
+    }
+
+    // Check for cycles; the dependency graph must be a directed acyclic graph, so we use is_cyclic_directed
+    if is_cyclic_directed(&graph) {
+        return Err(FlagError::CohortDependencyCycle(format!(
+            "Cyclic dependency detected starting at cohort {}",
+            initial_cohort_id
+        )));
+    }
+
+    Ok(graph)
+}
+
 /// Fetch and locally cache all properties for a given distinct ID and team ID.
 ///
 /// This function fetches both person and group properties for a specified distinct ID and team ID.
@@ -1063,32 +1395,52 @@ async fn fetch_and_locally_cache_all_properties(
 
     let query = r#"
         SELECT
-            (SELECT "posthog_person"."properties"
-             FROM "posthog_person"
-             INNER JOIN "posthog_persondistinctid"
-             ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id")
-             WHERE ("posthog_persondistinctid"."distinct_id" = $1
-                    AND "posthog_persondistinctid"."team_id" = $2
-                    AND "posthog_person"."team_id" = $2)
-             LIMIT 1) as person_properties,
-
-            (SELECT json_object_agg("posthog_group"."group_type_index", "posthog_group"."group_properties")
-             FROM "posthog_group"
-             WHERE ("posthog_group"."team_id" = $2
-                    AND "posthog_group"."group_type_index" = ANY($3))) as group_properties
+            person.person_id,
+            person.person_properties,
+            group_properties.group_properties
+        FROM (
+            SELECT
+                "posthog_person"."id" AS person_id,
+                "posthog_person"."properties" AS person_properties
+            FROM "posthog_person"
+            INNER JOIN "posthog_persondistinctid"
+                ON "posthog_person"."id" = "posthog_persondistinctid"."person_id"
+            WHERE
+                "posthog_persondistinctid"."distinct_id" = $1
+                AND "posthog_persondistinctid"."team_id" = $2
+                AND "posthog_person"."team_id" = $2
+            LIMIT 1
+        ) AS person,
+        (
+            SELECT
+                json_object_agg(
+                    "posthog_group"."group_type_index",
+                    "posthog_group"."group_properties"
+                ) AS group_properties
+            FROM "posthog_group"
+            WHERE
+                "posthog_group"."team_id" = $2
+                AND "posthog_group"."group_type_index" = ANY($3)
+        ) AS group_properties
     "#;
 
     let group_type_indexes_vec: Vec<GroupTypeIndex> = group_type_indexes.iter().cloned().collect();
 
-    let row: (Option<Value>, Option<Value>) = sqlx::query_as(query)
+    let row: (Option<i32>, Option<Value>, Option<Value>) = sqlx::query_as(query)
         .bind(&distinct_id)
         .bind(team_id)
         .bind(&group_type_indexes_vec)
         .fetch_optional(&mut *conn)
         .await?
-        .unwrap_or((None, None));
+        .unwrap_or((None, None, None));
 
-    if let Some(person_props) = row.0 {
+    let (person_id, person_props, group_props) = row;
+
+    if let Some(person_id) = person_id {
+        properties_cache.person_id = Some(person_id);
+    }
+
+    if let Some(person_props) = person_props {
         properties_cache.person_properties = Some(
             person_props
                 .as_object()
@@ -1099,7 +1451,7 @@ async fn fetch_and_locally_cache_all_properties(
         );
     }
 
-    if let Some(group_props) = row.1 {
+    if let Some(group_props) = group_props {
         let group_props_map: HashMap<GroupTypeIndex, HashMap<String, Value>> = group_props
             .as_object()
             .unwrap_or(&serde_json::Map::new())
@@ -1122,7 +1474,7 @@ async fn fetch_and_locally_cache_all_properties(
     Ok(())
 }
 
-/// Fetch person properties from the database for a given distinct ID and team ID.
+/// Fetch person properties and person ID from the database for a given distinct ID and team ID.
 ///
 /// This function constructs and executes a SQL query to fetch the person properties for a specified distinct ID and team ID.
 /// It returns the fetched properties as a HashMap.
@@ -1130,31 +1482,37 @@ async fn fetch_person_properties_from_db(
     postgres_reader: PostgresReader,
     distinct_id: String,
     team_id: TeamId,
-) -> Result<HashMap<String, Value>, FlagError> {
+) -> Result<(HashMap<String, Value>, i32), FlagError> {
     let mut conn = postgres_reader.as_ref().get_connection().await?;
 
     let query = r#"
-        SELECT "posthog_person"."properties" as person_properties
-        FROM "posthog_person"
-        INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id")
-        WHERE ("posthog_persondistinctid"."distinct_id" = $1
-                AND "posthog_persondistinctid"."team_id" = $2
-                AND "posthog_person"."team_id" = $2)
-        LIMIT 1
-    "#;
+        SELECT "posthog_person"."id" as person_id, "posthog_person"."properties" as person_properties
+        FROM "posthog_person"
+        INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id")
+        WHERE ("posthog_persondistinctid"."distinct_id" = $1
+                AND "posthog_persondistinctid"."team_id" = $2
+                AND "posthog_person"."team_id" = $2)
+        LIMIT 1
+    "#;
 
-    let row: Option<Value> = sqlx::query_scalar(query)
+    let row: Option<(i32, Value)> = sqlx::query_as(query)
         .bind(&distinct_id)
         .bind(team_id)
         .fetch_optional(&mut *conn)
         .await?;
 
-    Ok(row
-        .and_then(|v| v.as_object().cloned())
-        .unwrap_or_default()
-        .into_iter()
-        .map(|(k, v)| (k, v.clone()))
-        .collect())
+    match row {
+        Some((person_id, person_props)) => {
+            let properties_map = person_props
+                .as_object()
+                .unwrap_or(&serde_json::Map::new())
+                .iter()
+                .map(|(k, v)| (k.clone(), v.clone()))
+                .collect();
+            Ok((properties_map, person_id))
+        }
+        None => Err(FlagError::PersonNotFound),
+    }
 }
 
 /// Fetch group properties from the database for a given team ID and group type index.
@@ -1216,11 +1574,11 @@ fn locally_computable_property_overrides(
 /// Check if all properties match the given filters
 fn all_properties_match(
     flag_condition_properties: &[PropertyFilter],
-    target_properties: &HashMap<String, Value>,
+    matching_property_values: &HashMap<String, Value>,
 ) -> bool {
     flag_condition_properties
         .iter()
-        .all(|property| match_property(property, target_properties, false).unwrap_or(false))
+        .all(|property| match_property(property, matching_property_values, false).unwrap_or(false))
 }
 
 async fn get_feature_flag_hash_key_overrides(
@@ -1443,6 +1801,7 @@ mod tests {
         OperatorType,
     },
     test_utils::{
+        add_person_to_cohort, get_person_id_by_distinct_id, insert_cohort_for_team_in_pg,
         insert_flag_for_team_in_pg, insert_new_team_in_pg, insert_person_for_team_in_pg,
         setup_pg_reader_client, setup_pg_writer_client,
    },
@@ -1485,6 +1844,7 @@ mod tests {
    async fn test_fetch_properties_from_pg_to_match() {
        let postgres_reader = setup_pg_reader_client(None).await;
        let postgres_writer = setup_pg_writer_client(None).await;
+       let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
 
        let team = insert_new_team_in_pg(postgres_reader.clone(), None)
            .await
@@ -1529,12 +1889,13 @@ mod tests {
        ))
        .unwrap();
 
+       // Matcher for a matching distinct_id
        let mut matcher = FeatureFlagMatcher::new(
            distinct_id.clone(),
            team.id,
            postgres_reader.clone(),
            postgres_writer.clone(),
-           None,
+           cohort_cache.clone(),
            None,
            None,
        );
@@ -1542,12 +1903,13 @@ mod tests {
        let match_result = matcher.get_match(&flag, None, None).await.unwrap();
        assert!(match_result.matches);
        assert_eq!(match_result.variant, None);
 
+       // Matcher for a non-matching distinct_id
        let mut matcher = FeatureFlagMatcher::new(
            not_matching_distinct_id.clone(),
            team.id,
            postgres_reader.clone(),
            postgres_writer.clone(),
-           None,
+           cohort_cache.clone(),
            None,
            None,
        );
@@ -1555,24 +1917,27 @@ mod tests {
        let match_result = matcher.get_match(&flag, None, None).await.unwrap();
        assert!(!match_result.matches);
assert_eq!(match_result.variant, None); + // Matcher for a distinct_id that does not exist let mut matcher = FeatureFlagMatcher::new( "other_distinct_id".to_string(), team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); - let match_result = matcher.get_match(&flag, None, None).await.unwrap(); - assert!(!match_result.matches); - assert_eq!(match_result.variant, None); + let match_result = matcher.get_match(&flag, None, None).await; + + // Expecting an error for non-existent distinct_id + assert!(match_result.is_err()); } #[tokio::test] async fn test_person_property_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1590,6 +1955,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -1611,7 +1977,7 @@ mod tests { team.id, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -1633,6 +1999,7 @@ mod tests { async fn test_group_property_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1650,6 +2017,7 @@ mod tests { operator: None, prop_type: "group".to_string(), group_type_index: Some(1), + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -1664,10 +2032,12 @@ mod tests { None, ); - let mut cache = GroupTypeMappingCache::new(team.id, postgres_reader.clone()); + let mut group_type_mapping_cache = + GroupTypeMappingCache::new(team.id, postgres_reader.clone()); let group_types_to_indexes = [("organization".to_string(), 1)].into_iter().collect(); - cache.group_types_to_indexes = group_types_to_indexes; - cache.group_indexes_to_types = [(1, "organization".to_string())].into_iter().collect(); + group_type_mapping_cache.group_types_to_indexes = group_types_to_indexes; + group_type_mapping_cache.group_indexes_to_types = + [(1, "organization".to_string())].into_iter().collect(); let groups = HashMap::from([("organization".to_string(), json!("org_123"))]); @@ -1684,8 +2054,8 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - Some(cache), - None, + cohort_cache.clone(), + Some(group_type_mapping_cache), Some(groups), ); @@ -1708,14 +2078,14 @@ mod tests { let flag = create_test_flag_with_variants(1); let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - - let mut cache = GroupTypeMappingCache::new(1, postgres_reader.clone()); + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let mut group_type_mapping_cache = GroupTypeMappingCache::new(1, postgres_reader.clone()); let group_types_to_indexes = [("group_type_1".to_string(), 1)].into_iter().collect(); let group_type_index_to_name = [(1, "group_type_1".to_string())].into_iter().collect(); - cache.group_types_to_indexes = group_types_to_indexes; - cache.group_indexes_to_types = group_type_index_to_name; + group_type_mapping_cache.group_types_to_indexes = group_types_to_indexes; + 
group_type_mapping_cache.group_indexes_to_types = group_type_index_to_name; let groups = HashMap::from([("group_type_1".to_string(), json!("group_key_1"))]); @@ -1724,8 +2094,8 @@ mod tests { 1, postgres_reader.clone(), postgres_writer.clone(), - Some(cache), - None, + cohort_cache.clone(), + Some(group_type_mapping_cache), Some(groups), ); let variant = matcher.get_matching_variant(&flag, None).await.unwrap(); @@ -1740,6 +2110,7 @@ mod tests { async fn test_get_matching_variant_with_db() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1751,7 +2122,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -1765,6 +2136,7 @@ mod tests { async fn test_is_condition_match_empty_properties() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -1797,7 +2169,7 @@ mod tests { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -1854,6 +2226,7 @@ mod tests { async fn test_overrides_avoid_db_lookups() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1871,6 +2244,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -1893,7 +2267,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -1923,6 +2297,7 @@ mod tests { async fn test_fallback_to_db_when_overrides_insufficient() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -1941,6 +2316,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "age".to_string(), @@ -1948,6 +2324,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]), rollout_percentage: Some(100.0), @@ -1982,7 +2359,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2006,6 +2383,7 @@ mod tests { async fn test_property_fetching_and_caching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2025,7 +2403,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2050,6 +2428,7 @@ mod tests { async fn 
test_property_caching() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2069,7 +2448,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2102,7 +2481,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2150,6 +2529,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "age".to_string(), @@ -2157,6 +2537,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]; @@ -2170,6 +2551,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "cohort".to_string(), @@ -2177,6 +2559,7 @@ mod tests { operator: None, prop_type: "cohort".to_string(), group_type_index: None, + negation: None, }, ]; @@ -2189,6 +2572,7 @@ mod tests { async fn test_concurrent_flag_evaluation() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2218,13 +2602,14 @@ mod tests { let flag_clone = flag.clone(); let postgres_reader_clone = postgres_reader.clone(); let postgres_writer_clone = postgres_writer.clone(); + let cohort_cache_clone = cohort_cache.clone(); handles.push(tokio::spawn(async move { let mut matcher = FeatureFlagMatcher::new( format!("test_user_{}", i), team.id, postgres_reader_clone, postgres_writer_clone, - None, + cohort_cache_clone, None, None, ); @@ -2246,6 +2631,7 @@ mod tests { async fn test_property_operators() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2264,6 +2650,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "email".to_string(), @@ -2271,6 +2658,7 @@ mod tests { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]), rollout_percentage: Some(100.0), @@ -2300,7 +2688,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2314,7 +2702,7 @@ mod tests { async fn test_empty_hashed_identifier() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2341,7 +2729,7 @@ mod tests { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -2355,6 +2743,7 @@ mod tests { async fn test_rollout_percentage() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = 
Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let mut flag = create_test_flag( Some(1), None, @@ -2381,7 +2770,7 @@ mod tests { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -2402,7 +2791,7 @@ mod tests { async fn test_uneven_variant_distribution() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; - + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let mut flag = create_test_flag_with_variants(1); // Adjust variant rollout percentages to be uneven @@ -2432,7 +2821,7 @@ mod tests { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ); @@ -2464,6 +2853,7 @@ mod tests { async fn test_missing_properties_in_db() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2491,6 +2881,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2510,7 +2901,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2524,6 +2915,7 @@ mod tests { async fn test_malformed_property_data() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2551,6 +2943,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2570,7 +2963,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2585,6 +2978,7 @@ mod tests { async fn test_get_match_with_insufficient_overrides() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2603,6 +2997,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }, PropertyFilter { key: "age".to_string(), @@ -2610,6 +3005,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }, ]), rollout_percentage: Some(100.0), @@ -2644,7 +3040,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2661,6 +3057,7 @@ mod tests { async fn test_evaluation_reasons() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = create_test_flag( Some(1), None, @@ -2687,7 +3084,7 @@ mod tests { 1, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2705,6 +3102,7 @@ mod tests { async fn test_complex_conditions() { let postgres_reader = setup_pg_reader_client(None).await; let 
postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2723,6 +3121,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2734,6 +3133,7 @@ mod tests { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2763,7 +3163,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache, None, None, ); @@ -2777,6 +3177,7 @@ mod tests { async fn test_super_condition_matches_boolean() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2795,6 +3196,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(0.0), variant: None, @@ -2806,6 +3208,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2826,6 +3229,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2845,12 +3249,25 @@ mod tests { .await .unwrap(); + insert_person_for_team_in_pg(postgres_reader.clone(), team.id, "lil_id".to_string(), None) + .await + .unwrap(); + + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "another_id".to_string(), + None, + ) + .await + .unwrap(); + let mut matcher_test_id = FeatureFlagMatcher::new( "test_id".to_string(), team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2860,7 +3277,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2870,7 +3287,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -2897,6 +3314,7 @@ mod tests { async fn test_super_condition_matches_string() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2924,6 +3342,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(0.0), variant: None, @@ -2935,6 +3354,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2955,6 +3375,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -2970,7 +3391,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + 
cohort_cache.clone(), None, None, ); @@ -2986,6 +3407,7 @@ mod tests { async fn test_super_condition_matches_and_false() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -2999,6 +3421,19 @@ mod tests { .await .unwrap(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "another_id".to_string(), + None, + ) + .await + .unwrap(); + + insert_person_for_team_in_pg(postgres_reader.clone(), team.id, "lil_id".to_string(), None) + .await + .unwrap(); + let flag = create_test_flag( Some(1), Some(team.id), @@ -3013,6 +3448,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(0.0), variant: None, @@ -3024,6 +3460,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3044,6 +3481,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3059,7 +3497,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -3069,7 +3507,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -3079,7 +3517,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ); @@ -3116,6 +3554,811 @@ mod tests { assert_eq!(result_another_id.condition_index, Some(2)); } + #[tokio::test] + async fn test_basic_cohort_matching() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with the condition that matches the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + 
"test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_not_in_cohort_matching() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with a condition that does not match the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "130", + "negation": false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that do not match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a NotIn cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_not_in_cohort_matching_user_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with a condition that matches the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a NotIn cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + 
group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + // The user matches the cohort, but the flag is set to NotIn, so it should evaluate to false + assert!(!result.matches); + } + + #[tokio::test] + async fn test_cohort_dependent_on_another_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a base cohort + let base_cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "125", + "negation": false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a dependent cohort that includes the base cohort + let dependent_cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + "type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "id", + "type": "cohort", + "value": base_cohort_row.id, + "negation": false, + "operator": "in" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that match the base cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 126})), + ) + .await + .unwrap(); + + // Define a flag with a cohort filter that depends on another cohort + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(dependent_cohort_row.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!(result.matches); + } + + #[tokio::test] + async fn test_in_cohort_matching_user_not_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a cohort with a condition that does not match the test user's properties + let cohort_row = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + None, + json!({ + "properties": { + 
"type": "OR", + "values": [{ + "type": "OR", + "values": [{ + "key": "$browser_version", + "type": "person", + "value": "130", + "negation": false, + "operator": "gt" + }] + }] + } + }), + false, + ) + .await + .unwrap(); + + // Insert a person with properties that do not match the cohort condition + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "test_user".to_string(), + Some(json!({"$browser_version": 125})), + ) + .await + .unwrap(); + + // Define a flag with an In cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort_row.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + "test_user".to_string(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + // The user does not match the cohort, and the flag is set to In, so it should evaluate to false + assert!(!result.matches); + } + + #[tokio::test] + async fn test_static_cohort_matching_user_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a static cohort + let cohort = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + Some("Static Cohort".to_string()), + json!({}), // Static cohorts don't have property filters + true, // is_static = true + ) + .await + .unwrap(); + + // Insert a person + let distinct_id = "static_user".to_string(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "static@user.com"})), + ) + .await + .unwrap(); + + // Retrieve the person's ID + let person_id = + get_person_id_by_distinct_id(postgres_reader.clone(), team.id, &distinct_id) + .await + .unwrap(); + + // Associate the person with the static cohort + add_person_to_cohort(postgres_reader.clone(), person_id, cohort.id) + .await + .unwrap(); + + // Define a flag with an 'In' cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!( + result.matches, + "User should match the static cohort and flag" + ); + } + + #[tokio::test] 
+ async fn test_static_cohort_matching_user_not_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a static cohort + let cohort = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + Some("Another Static Cohort".to_string()), + json!({}), // Static cohorts don't have property filters + true, + ) + .await + .unwrap(); + + // Insert a person + let distinct_id = "non_static_user".to_string(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "nonstatic@user.com"})), + ) + .await + .unwrap(); + + // Note: Do NOT associate the person with the static cohort + + // Define a flag with an 'In' cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!( + !result.matches, + "User should not match the static cohort and flag" + ); + } + + #[tokio::test] + async fn test_static_cohort_not_in_matching_user_not_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a static cohort + let cohort = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + Some("Static Cohort NotIn".to_string()), + json!({}), // Static cohorts don't have property filters + true, // is_static = true + ) + .await + .unwrap(); + + // Insert a person + let distinct_id = "not_in_static_user".to_string(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "notinstatic@user.com"})), + ) + .await + .unwrap(); + + // No association with the static cohort + + // Define a flag with a 'NotIn' cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + 
None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!( + result.matches, + "User not in the static cohort should match the 'NotIn' flag" + ); + } + + #[tokio::test] + async fn test_static_cohort_not_in_matching_user_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a static cohort + let cohort = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + Some("Static Cohort NotIn User In".to_string()), + json!({}), // Static cohorts don't have property filters + true, // is_static = true + ) + .await + .unwrap(); + + // Insert a person + let distinct_id = "in_not_in_static_user".to_string(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "innotinstatic@user.com"})), + ) + .await + .unwrap(); + + // Retrieve the person's ID + let person_id = + get_person_id_by_distinct_id(postgres_reader.clone(), team.id, &distinct_id) + .await + .unwrap(); + + // Associate the person with the static cohort + add_person_to_cohort(postgres_reader.clone(), person_id, cohort.id) + .await + .unwrap(); + + // Define a flag with a 'NotIn' cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!( + !result.matches, + "User in the static cohort should not match the 'NotIn' flag" + ); + } + #[tokio::test] async fn test_set_feature_flag_hash_key_overrides_success() { let postgres_reader = setup_pg_reader_client(None).await; @@ -3123,7 +4366,7 @@ mod tests { let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); - let distinct_id = "user1".to_string(); + let distinct_id = "user2".to_string(); // Insert person insert_person_for_team_in_pg(postgres_reader.clone(), team.id, distinct_id.clone(), None) @@ -3148,7 +4391,7 @@ mod tests { Some(true), // ensure_experience_continuity ); - // need to convert flag to FeatureFlagRow + // Convert flag to FeatureFlagRow let flag_row = FeatureFlagRow { id: flag.id, team_id: flag.team_id, @@ -3165,8 +4408,8 @@ mod tests { .await .unwrap(); - // Attempt to set hash key override - let result = set_feature_flag_hash_key_overrides( + // Set hash key override + set_feature_flag_hash_key_overrides( postgres_writer.clone(), team.id, vec![distinct_id.clone()], @@ -3175,9 +4418,7 @@ mod tests { .await .unwrap(); - assert!(result, "Hash key override should be set successfully"); - - // Retrieve the hash key overrides + // Retrieve hash key overrides let overrides = get_feature_flag_hash_key_overrides( 
postgres_reader.clone(), team.id, @@ -3186,14 +4427,10 @@ mod tests { .await .unwrap(); - assert!( - !overrides.is_empty(), - "At least one hash key override should be set" - ); assert_eq!( overrides.get("test_flag"), Some(&"hash_key_2".to_string()), - "Hash key override for 'test_flag' should match the set value" + "Hash key override should match the set value" ); } @@ -3271,10 +4508,12 @@ mod tests { "Hash key override should match the set value" ); } + #[tokio::test] async fn test_evaluate_feature_flags_with_experience_continuity() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -3304,6 +4543,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3337,7 +4577,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ) @@ -3356,12 +4596,12 @@ mod tests { async fn test_evaluate_feature_flags_with_continuity_missing_override() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); let distinct_id = "user4".to_string(); - // Insert person insert_person_for_team_in_pg( postgres_reader.clone(), team.id, @@ -3385,6 +4625,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3408,7 +4649,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ) @@ -3427,12 +4668,12 @@ mod tests { async fn test_evaluate_all_feature_flags_mixed_continuity() { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); let distinct_id = "user5".to_string(); - // Insert person insert_person_for_team_in_pg( postgres_reader.clone(), team.id, @@ -3456,6 +4697,7 @@ mod tests { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3484,6 +4726,7 @@ mod tests { operator: Some(OperatorType::Gt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -3517,7 +4760,7 @@ mod tests { team.id, postgres_reader.clone(), postgres_writer.clone(), - None, + cohort_cache.clone(), None, None, ) diff --git a/rust/feature-flags/src/flag_request.rs b/rust/feature-flags/src/flag_request.rs index 771c216834c..1cf64eb879a 100644 --- a/rust/feature-flags/src/flag_request.rs +++ b/rust/feature-flags/src/flag_request.rs @@ -158,8 +158,8 @@ impl FlagRequest { pub async fn get_flags_from_cache_or_pg( &self, team_id: i32, - redis_client: Arc, - pg_client: Arc, + redis_client: &Arc, + pg_client: &Arc, ) -> Result { let mut cache_hit = false; let flags = match FeatureFlagList::from_redis(redis_client.clone(), team_id).await { @@ -167,10 
+167,14 @@ impl FlagRequest { cache_hit = true; Ok(flags) } - Err(_) => match FeatureFlagList::from_pg(pg_client, team_id).await { + Err(_) => match FeatureFlagList::from_pg(pg_client.clone(), team_id).await { Ok(flags) => { - if let Err(e) = - FeatureFlagList::update_flags_in_redis(redis_client, team_id, &flags).await + if let Err(e) = FeatureFlagList::update_flags_in_redis( + redis_client.clone(), + team_id, + &flags, + ) + .await { tracing::warn!("Failed to update Redis cache: {}", e); // TODO add new metric category for this @@ -206,7 +210,6 @@ mod tests { TEAM_FLAGS_CACHE_PREFIX, }; use crate::flag_request::FlagRequest; - use crate::redis::Client as RedisClient; use crate::team::Team; use crate::test_utils::{insert_new_team_in_redis, setup_pg_reader_client, setup_redis_client}; use bytes::Bytes; @@ -360,6 +363,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(50.0), variant: None, @@ -402,6 +406,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -426,7 +431,7 @@ mod tests { // Test fetching from Redis let result = flag_request - .get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone()) + .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) .await; assert!(result.is_ok()); let fetched_flags = result.unwrap(); @@ -483,7 +488,7 @@ mod tests { .expect("Failed to remove flags from Redis"); let result = flag_request - .get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone()) + .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) .await; assert!(result.is_ok()); // Verify that the flags were re-added to Redis diff --git a/rust/feature-flags/src/lib.rs b/rust/feature-flags/src/lib.rs index 051b3e27697..67659bfcf9d 100644 --- a/rust/feature-flags/src/lib.rs +++ b/rust/feature-flags/src/lib.rs @@ -1,4 +1,7 @@ pub mod api; +pub mod cohort_cache; +pub mod cohort_models; +pub mod cohort_operations; pub mod config; pub mod database; pub mod feature_flag_match_reason; @@ -8,13 +11,13 @@ pub mod flag_matching; pub mod flag_request; pub mod geoip; pub mod metrics_consts; +pub mod metrics_utils; pub mod property_matching; pub mod redis; pub mod request_handler; pub mod router; pub mod server; pub mod team; -pub mod utils; pub mod v0_endpoint; // Test modules don't need to be compiled with main binary diff --git a/rust/feature-flags/src/utils.rs b/rust/feature-flags/src/metrics_utils.rs similarity index 100% rename from rust/feature-flags/src/utils.rs rename to rust/feature-flags/src/metrics_utils.rs diff --git a/rust/feature-flags/src/property_matching.rs b/rust/feature-flags/src/property_matching.rs index 8d12fe6ab5e..84479f13161 100644 --- a/rust/feature-flags/src/property_matching.rs +++ b/rust/feature-flags/src/property_matching.rs @@ -44,7 +44,7 @@ pub fn match_property( } let key = &property.key; - let operator = property.operator.clone().unwrap_or(OperatorType::Exact); + let operator = property.operator.unwrap_or(OperatorType::Exact); let value = &property.value; let match_value = matching_property_values.get(key); @@ -193,6 +193,12 @@ pub fn match_property( // Ok(false) // } } + OperatorType::In | OperatorType::NotIn => { + // TODO: we handle these in cohort matching, so we can just return false here + // because by the time we match properties, we've already decomposed the cohort + // filter 
into multiple property filters + Ok(false) + } } } @@ -260,6 +266,7 @@ mod test_match_properties { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -313,6 +320,7 @@ mod test_match_properties { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -335,6 +343,7 @@ mod test_match_properties { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -379,6 +388,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNot), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -416,6 +426,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNot), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -490,6 +501,7 @@ mod test_match_properties { operator: Some(OperatorType::IsSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -538,6 +550,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -595,6 +608,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -634,6 +648,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -674,6 +689,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( &property_b, @@ -708,6 +724,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -730,6 +747,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( &property_d, @@ -760,6 +778,7 @@ mod test_match_properties { operator: Some(OperatorType::Gt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -802,6 +821,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -848,6 +868,7 @@ mod test_match_properties { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -889,6 +910,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -935,6 +957,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1013,6 +1036,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNot), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1034,6 +1058,7 @@ mod test_match_properties { operator: Some(OperatorType::IsSet), prop_type: "person".to_string(), 
group_type_index: None, + negation: None, }; assert!(match_property( @@ -1049,6 +1074,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1070,6 +1096,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1085,6 +1112,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1118,6 +1146,7 @@ mod test_match_properties { operator: None, prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1137,6 +1166,7 @@ mod test_match_properties { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1152,6 +1182,7 @@ mod test_match_properties { operator: Some(OperatorType::IsSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1167,6 +1198,7 @@ mod test_match_properties { operator: Some(OperatorType::IsNotSet), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(match_property( @@ -1203,6 +1235,7 @@ mod test_match_properties { operator: Some(OperatorType::Icontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1218,6 +1251,7 @@ mod test_match_properties { operator: Some(OperatorType::NotIcontains), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1233,6 +1267,7 @@ mod test_match_properties { operator: Some(OperatorType::Regex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1248,6 +1283,7 @@ mod test_match_properties { operator: Some(OperatorType::NotRegex), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1263,6 +1299,7 @@ mod test_match_properties { operator: Some(OperatorType::Gt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1278,6 +1315,7 @@ mod test_match_properties { operator: Some(OperatorType::Gte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1293,6 +1331,7 @@ mod test_match_properties { operator: Some(OperatorType::Lt), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1308,6 +1347,7 @@ mod test_match_properties { operator: Some(OperatorType::Lte), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( @@ -1324,6 +1364,7 @@ mod test_match_properties { operator: Some(OperatorType::IsDateBefore), prop_type: "person".to_string(), group_type_index: None, + negation: None, }; assert!(!match_property( diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index 5e0be8faacc..5ef43896e64 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -1,5 +1,6 @@ use crate::{ api::{FlagError, FlagsResponse}, + cohort_cache::CohortCacheManager, database::Client, flag_definitions::FeatureFlagList, flag_matching::{FeatureFlagMatcher, GroupTypeMappingCache}, @@ 
-69,6 +70,7 @@ pub struct FeatureFlagEvaluationContext { feature_flags: FeatureFlagList, postgres_reader: Arc, postgres_writer: Arc, + cohort_cache: Arc, #[builder(default)] person_property_overrides: Option>, #[builder(default)] @@ -95,6 +97,7 @@ pub async fn process_request(context: RequestContext) -> Result Result = state.postgres_reader.clone(); - let postgres_writer_dyn: Arc = state.postgres_writer.clone(); - let evaluation_context = FeatureFlagEvaluationContextBuilder::default() .team_id(team_id) .distinct_id(distinct_id) .feature_flags(feature_flags_from_cache_or_pg) - .postgres_reader(postgres_reader_dyn) - .postgres_writer(postgres_writer_dyn) + .postgres_reader(state.postgres_reader.clone()) + .postgres_writer(state.postgres_writer.clone()) + .cohort_cache(state.cohort_cache.clone()) .person_property_overrides(person_property_overrides) .group_property_overrides(group_property_overrides) .groups(groups) @@ -224,8 +225,8 @@ pub async fn evaluate_feature_flags(context: FeatureFlagEvaluationContext) -> Fl context.team_id, context.postgres_reader, context.postgres_writer, + context.cohort_cache, Some(group_type_mapping_cache), - None, // TODO maybe remove this from the matcher struct, since it's used internally but not passed around context.groups, ); feature_flag_matcher @@ -359,6 +360,7 @@ mod tests { async fn test_evaluate_feature_flags() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, @@ -374,6 +376,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "person".to_string(), group_type_index: None, + negation: None, }]), rollout_percentage: Some(100.0), // Set to 100% to ensure it's always on variant: None, @@ -397,6 +400,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .person_property_overrides(Some(person_properties)) .build() .expect("Failed to build FeatureFlagEvaluationContext"); @@ -505,6 +509,7 @@ mod tests { async fn test_evaluate_feature_flags_multiple_flags() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flags = vec![ FeatureFlag { name: Some("Flag 1".to_string()), @@ -556,6 +561,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .build() .expect("Failed to build FeatureFlagEvaluationContext"); @@ -608,6 +614,7 @@ mod tests { async fn test_evaluate_feature_flags_with_overrides() { let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let team = insert_new_team_in_pg(postgres_reader.clone(), None) .await .unwrap(); @@ -627,6 +634,7 @@ mod tests { operator: Some(OperatorType::Exact), prop_type: "group".to_string(), group_type_index: Some(0), + negation: None, }]), rollout_percentage: Some(100.0), variant: None, @@ -655,6 +663,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) 
.group_property_overrides(Some(group_property_overrides)) .groups(Some(groups)) .build() @@ -688,6 +697,7 @@ mod tests { let long_id = "a".repeat(1000); let postgres_reader: Arc = setup_pg_reader_client(None).await; let postgres_writer: Arc = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let flag = FeatureFlag { name: Some("Test Flag".to_string()), id: 1, @@ -717,6 +727,7 @@ mod tests { .feature_flags(feature_flag_list) .postgres_reader(postgres_reader) .postgres_writer(postgres_writer) + .cohort_cache(cohort_cache) .build() .expect("Failed to build FeatureFlagEvaluationContext"); diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 505f18adfb0..e34ea31a3c6 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -9,11 +9,12 @@ use health::HealthRegistry; use tower::limit::ConcurrencyLimitLayer; use crate::{ + cohort_cache::CohortCacheManager, config::{Config, TeamIdsToTrack}, database::Client as DatabaseClient, geoip::GeoIpClient, + metrics_utils::team_id_label_filter, redis::Client as RedisClient, - utils::team_id_label_filter, v0_endpoint, }; @@ -22,6 +23,7 @@ pub struct State { pub redis: Arc, pub postgres_reader: Arc, pub postgres_writer: Arc, + pub cohort_cache: Arc, // TODO does this need a better name than just `cohort_cache`? pub geoip: Arc, pub team_ids_to_track: TeamIdsToTrack, } @@ -30,6 +32,7 @@ pub fn router( redis: Arc, postgres_reader: Arc, postgres_writer: Arc, + cohort_cache: Arc, geoip: Arc, liveness: HealthRegistry, config: Config, @@ -42,6 +45,7 @@ where redis, postgres_reader, postgres_writer, + cohort_cache, geoip, team_ids_to_track: config.team_ids_to_track.clone(), }; diff --git a/rust/feature-flags/src/server.rs b/rust/feature-flags/src/server.rs index c9e238fa8fd..69ff759ddfc 100644 --- a/rust/feature-flags/src/server.rs +++ b/rust/feature-flags/src/server.rs @@ -6,6 +6,7 @@ use std::time::Duration; use health::{HealthHandle, HealthRegistry}; use tokio::net::TcpListener; +use crate::cohort_cache::CohortCacheManager; use crate::config::Config; use crate::database::get_pool; use crate::geoip::GeoIpClient; @@ -54,6 +55,8 @@ where } }; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let health = HealthRegistry::new("liveness"); // TODO - we don't have a more complex health check yet, but we should add e.g. 
some around DB operations @@ -67,6 +70,7 @@ where redis_client, postgres_reader, postgres_writer, + cohort_cache, geoip_service, health, config, diff --git a/rust/feature-flags/src/team.rs b/rust/feature-flags/src/team.rs index 0fa75f0bd3d..f13cf29094b 100644 --- a/rust/feature-flags/src/team.rs +++ b/rust/feature-flags/src/team.rs @@ -42,7 +42,7 @@ impl Team { // TODO: Consider an LRU cache for teams as well, with small TTL to skip redis/pg lookups let team: Team = serde_json::from_str(&serialized_team).map_err(|e| { tracing::error!("failed to parse data to team: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; Ok(team) @@ -55,7 +55,7 @@ impl Team { ) -> Result<(), FlagError> { let serialized_team = serde_json::to_string(&team).map_err(|e| { tracing::error!("Failed to serialize team: {}", e); - FlagError::DataParsingError + FlagError::RedisDataParsingError })?; client @@ -173,7 +173,7 @@ mod tests { let client = setup_redis_client(None); match Team::from_redis(client.clone(), team.api_token.clone()).await { - Err(FlagError::DataParsingError) => (), + Err(FlagError::RedisDataParsingError) => (), Err(other) => panic!("Expected RedisDataParsingError, got {:?}", other), Ok(_) => panic!("Expected RedisDataParsingError"), }; diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index 32a2016bf75..346ed106ea6 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -1,11 +1,12 @@ use anyhow::Error; use axum::async_trait; use serde_json::{json, Value}; -use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, PgPool, Postgres}; +use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, Postgres, Row}; use std::sync::Arc; use uuid::Uuid; use crate::{ + cohort_models::Cohort, config::{Config, DEFAULT_TEST_CONFIG}, database::{get_pool, Client, CustomDatabaseError}, flag_definitions::{self, FeatureFlag, FeatureFlagRow}, @@ -23,7 +24,9 @@ pub fn random_string(prefix: &str, length: usize) -> String { format!("{}{}", prefix, suffix) } -pub async fn insert_new_team_in_redis(client: Arc) -> Result { +pub async fn insert_new_team_in_redis( + client: Arc, +) -> Result { let id = rand::thread_rng().gen_range(0..10_000_000); let token = random_string("phc_", 12); let team = Team { @@ -48,7 +51,7 @@ pub async fn insert_new_team_in_redis(client: Arc) -> Result, + client: Arc, team_id: i32, json_value: Option, ) -> Result<(), Error> { @@ -88,7 +91,7 @@ pub async fn insert_flags_for_team_in_redis( Ok(()) } -pub fn setup_redis_client(url: Option) -> Arc { +pub fn setup_redis_client(url: Option) -> Arc { let redis_url = match url { Some(value) => value, None => "redis://localhost:6379/".to_string(), @@ -130,7 +133,7 @@ pub fn create_flag_from_json(json_value: Option) -> Vec { flags } -pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { +pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { let config = config.unwrap_or(&DEFAULT_TEST_CONFIG); Arc::new( get_pool(&config.read_database_url, config.max_pg_connections) @@ -139,7 +142,7 @@ pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc { ) } -pub async fn setup_pg_writer_client(config: Option<&Config>) -> Arc { +pub async fn setup_pg_writer_client(config: Option<&Config>) -> Arc { let config = config.unwrap_or(&DEFAULT_TEST_CONFIG); Arc::new( get_pool(&config.write_database_url, config.max_pg_connections) @@ -261,7 +264,7 @@ pub async fn insert_new_team_in_pg( } pub async fn
insert_flag_for_team_in_pg( - client: Arc, + client: Arc, team_id: i32, flag: Option<FeatureFlagRow>, ) -> Result<FeatureFlagRow, Error> { @@ -310,11 +313,12 @@ pub async fn insert_flag_for_team_in_pg( } pub async fn insert_person_for_team_in_pg( - client: Arc, + client: Arc, team_id: i32, distinct_id: String, properties: Option<Value>, -) -> Result<(), Error> { +) -> Result<i32, Error> { let payload = match properties { Some(value) => value, None => json!({ @@ -326,7 +330,7 @@ pub async fn insert_person_for_team_in_pg( let uuid = Uuid::now_v7(); let mut conn = client.get_connection().await?; - let res = sqlx::query( + let row = sqlx::query( r#" WITH inserted_person AS ( INSERT INTO posthog_person ( created_at, properties, properties_last_updated_at, properties_last_operation, team_id, is_user_id, is_identified, uuid, version ) VALUES ('2023-04-05', $1, '{}', '{}', $2, NULL, true, $3, 0) - RETURNING * + RETURNING id ) INSERT INTO posthog_persondistinctid (distinct_id, person_id, team_id, version) VALUES ($4, (SELECT id FROM inserted_person), $5, 0) + RETURNING person_id "#, ) .bind(&payload) @@ -345,10 +350,109 @@ .bind(uuid) .bind(&distinct_id) .bind(team_id) + .fetch_one(&mut *conn) + .await?; + + let person_id: i32 = row.get::<i32, _>("person_id"); + Ok(person_id) +} + +pub async fn insert_cohort_for_team_in_pg( + client: Arc, + team_id: i32, + name: Option<String>, + filters: serde_json::Value, + is_static: bool, +) -> Result<Cohort, Error> { + let cohort = Cohort { + id: 0, // Placeholder, will be updated after insertion + name: name.unwrap_or("Test Cohort".to_string()), + description: Some("Description for cohort".to_string()), + team_id, + deleted: false, + filters, + query: None, + version: Some(1), + pending_version: None, + count: None, + is_calculating: false, + is_static, + errors_calculating: 0, + groups: serde_json::json!([]), + created_by_id: None, + }; + + let mut conn = client.get_connection().await?; + let row: (i32,) = sqlx::query_as( + r#"INSERT INTO posthog_cohort + (name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id) VALUES + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + RETURNING id"#, + ) + .bind(&cohort.name) + .bind(&cohort.description) + .bind(cohort.team_id) + .bind(cohort.deleted) + .bind(&cohort.filters) + .bind(&cohort.query) + .bind(cohort.version) + .bind(cohort.pending_version) + .bind(cohort.count) + .bind(cohort.is_calculating) + .bind(cohort.is_static) + .bind(cohort.errors_calculating) + .bind(&cohort.groups) + .bind(cohort.created_by_id) + .fetch_one(&mut *conn) + .await?; + + // Update the cohort with the actual id generated by sqlx + let id = row.0; + + Ok(Cohort { id, ..cohort }) +} + +pub async fn get_person_id_by_distinct_id( + client: Arc, + team_id: i32, + distinct_id: &str, +) -> Result<i32, Error> { + let mut conn = client.get_connection().await?; + let row: (i32,) = sqlx::query_as( + r#"SELECT id FROM posthog_person + WHERE team_id = $1 AND id = ( + SELECT person_id FROM posthog_persondistinctid + WHERE team_id = $1 AND distinct_id = $2 + LIMIT 1 + ) + LIMIT 1"#, + ) + .bind(team_id) + .bind(distinct_id) + .fetch_one(&mut *conn) + .await + .map_err(|_| anyhow::anyhow!("Person not found"))?; + + Ok(row.0) +} + +pub async fn add_person_to_cohort( + client: Arc, + person_id: i32, + cohort_id: i32, +) -> Result<(), Error> { + let mut conn = client.get_connection().await?; + let res = sqlx::query( + r#"INSERT INTO
posthog_cohortpeople (cohort_id, person_id) + VALUES ($1, $2) + ON CONFLICT DO NOTHING"#, + ) + .bind(cohort_id) + .bind(person_id) .execute(&mut *conn) .await?; - assert_eq!(res.rows_affected(), 1); + assert!(res.rows_affected() > 0, "Failed to add person to cohort"); Ok(()) } diff --git a/rust/feature-flags/tests/test_flag_matching_consistency.rs b/rust/feature-flags/tests/test_flag_matching_consistency.rs index 94f4f67dcdc..c632d28bc15 100644 --- a/rust/feature-flags/tests/test_flag_matching_consistency.rs +++ b/rust/feature-flags/tests/test_flag_matching_consistency.rs @@ -1,3 +1,6 @@ +use std::sync::Arc; + +use feature_flags::cohort_cache::CohortCacheManager; use feature_flags::feature_flag_match_reason::FeatureFlagMatchReason; /// These tests are common between all libraries doing local evaluation of feature flags. /// This ensures there are no mismatches between implementations. @@ -110,6 +113,7 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { for (i, result) in results.iter().enumerate().take(1000) { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let distinct_id = format!("distinct_id_{}", i); @@ -118,7 +122,7 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, ) @@ -1209,6 +1213,7 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { for (i, result) in results.iter().enumerate().take(1000) { let postgres_reader = setup_pg_reader_client(None).await; let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); let distinct_id = format!("distinct_id_{}", i); let feature_flag_match = FeatureFlagMatcher::new( @@ -1216,7 +1221,7 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() { 1, postgres_reader, postgres_writer, - None, + cohort_cache, None, None, )
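The static-cohort tests in this patch reduce to a small truth table: an In cohort filter matches exactly when the person belongs to the cohort, and a NotIn filter matches exactly when they do not. A minimal, self-contained Rust sketch of that table — the CohortOperator enum and helper below are illustrative stand-ins, not types from this crate:

// Illustrative only: models the In/NotIn semantics asserted by the
// test_static_cohort_* tests above, independent of any PostHog types.
#[derive(Clone, Copy)]
enum CohortOperator {
    In,
    NotIn,
}

fn cohort_condition_matches(op: CohortOperator, person_in_cohort: bool) -> bool {
    match op {
        // An In filter passes only for cohort members.
        CohortOperator::In => person_in_cohort,
        // A NotIn filter passes only for non-members.
        CohortOperator::NotIn => !person_in_cohort,
    }
}

fn main() {
    assert!(cohort_condition_matches(CohortOperator::In, true)); // test_static_cohort_matching_user_in_cohort
    assert!(!cohort_condition_matches(CohortOperator::In, false)); // test_static_cohort_matching_user_not_in_cohort
    assert!(cohort_condition_matches(CohortOperator::NotIn, false)); // test_static_cohort_not_in_matching_user_not_in_cohort
    assert!(!cohort_condition_matches(CohortOperator::NotIn, true)); // test_static_cohort_not_in_matching_user_in_cohort
}

The dynamic-cohort tests exercise the same table; the extra step there is deciding membership by evaluating the cohort's property filters (or, for dependent cohorts, the referenced cohort's filters in turn) rather than by reading the posthog_cohortpeople join table.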