0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-21 13:39:22 +01:00

Merge branch 'master' into ch-24.3

This commit is contained in:
James Greenhill 2024-11-16 08:14:32 -08:00 committed by GitHub
commit 7b59b2dff6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
49 changed files with 2953 additions and 338 deletions

View File

@ -77,6 +77,9 @@ describe('Signup', () => {
cy.get('[data-attr=password]').type(VALID_PASSWORD).should('have.value', VALID_PASSWORD)
cy.get('[data-attr=signup-start]').click()
cy.get('[data-attr=signup-name]').type('Alice Bob').should('have.value', 'Alice Bob')
cy.get('[data-attr=signup-role-at-organization]').click()
cy.get('.Popover li:first-child').click()
cy.get('[data-attr=signup-role-at-organization]').contains('Engineering')
cy.get('[data-attr=signup-submit]').click()
cy.wait('@signupRequest').then((interception) => {
@ -93,6 +96,9 @@ describe('Signup', () => {
cy.get('[data-attr=password]').type(VALID_PASSWORD).should('have.value', VALID_PASSWORD)
cy.get('[data-attr=signup-start]').click()
cy.get('[data-attr=signup-name]').type('Alice Bob').should('have.value', 'Alice Bob')
cy.get('[data-attr=signup-role-at-organization]').click()
cy.get('.Popover li:first-child').click()
cy.get('[data-attr=signup-role-at-organization]').contains('Engineering')
cy.get('[data-attr=signup-submit]').click()
cy.wait('@signupRequest').then(() => {
@ -105,6 +111,9 @@ describe('Signup', () => {
const newEmail = `new_user+${Math.floor(Math.random() * 10000)}@posthog.com`
cy.get('[data-attr=signup-email]').clear().type(newEmail).should('have.value', newEmail)
cy.get('[data-attr=signup-start]').click()
cy.get('[data-attr=signup-role-at-organization]').click()
cy.get('.Popover li:first-child').click()
cy.get('[data-attr=signup-role-at-organization]').contains('Engineering')
cy.get('[data-attr=signup-submit]').click()
cy.wait('@signupRequest').then((interception) => {

View File

@ -328,6 +328,7 @@ class ExperimentSerializer(serializers.ModelSerializer):
"name": f'Feature Flag for Experiment {validated_data["name"]}',
"filters": filters,
"active": not is_draft,
"creation_context": "experiments",
},
context=self.context,
)

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 159 KiB

After

Width:  |  Height:  |  Size: 159 KiB

View File

@ -3,7 +3,7 @@ import { LemonSelect } from 'lib/lemon-ui/LemonSelect'
export default function SignupRoleSelect({ className }: { className?: string }): JSX.Element {
return (
<LemonField name="role_at_organization" label="What is your role?" className={className} showOptional>
<LemonField name="role_at_organization" label="What is your role?" className={className}>
<LemonSelect
fullWidth
data-attr="signup-role-at-organization"

View File

@ -245,9 +245,7 @@ export const codeEditorLogic = kea<codeEditorLogicType>([
}
if (props.monaco) {
const defaultQuery = values.featureFlags[FEATURE_FLAGS.SQL_EDITOR]
? ''
: 'SELECT event FROM events LIMIT 100'
const defaultQuery = 'SELECT event FROM events LIMIT 100'
const uri = props.monaco.Uri.parse(currentModelCount.toString())
const model = props.monaco.editor.createModel(defaultQuery, props.language, uri)
props.editor?.setModel(model)

View File

@ -67,7 +67,7 @@ export const signupLogic = kea<signupLogicType>([
password: !values.preflight?.demo
? !password
? 'Please enter your password to continue'
: values.validatedPassword.feedback
: values.validatedPassword.feedback || undefined
: undefined,
}),
submit: async () => {
@ -83,8 +83,9 @@ export const signupLogic = kea<signupLogicType>([
role_at_organization: '',
referral_source: '',
} as SignupForm,
errors: ({ name }) => ({
errors: ({ name, role_at_organization }) => ({
name: !name ? 'Please enter your name' : undefined,
role_at_organization: !role_at_organization ? 'Please select your role in the organization' : undefined,
}),
submit: async (payload, breakpoint) => {
breakpoint()

View File

@ -21,6 +21,7 @@ import { BillingCTAHero } from './BillingCTAHero'
import { billingLogic } from './billingLogic'
import { BillingProduct } from './BillingProduct'
import { CreditCTAHero } from './CreditCTAHero'
import { PaymentEntryModal } from './PaymentEntryModal'
import { UnsubscribeCard } from './UnsubscribeCard'
export const scene: SceneExport = {
@ -82,6 +83,8 @@ export function Billing(): JSX.Element {
const platformAndSupportProduct = products?.find((product) => product.type === 'platform_and_support')
return (
<div ref={ref}>
<PaymentEntryModal />
{showLicenseDirectInput && (
<>
<Form logic={billingLogic} formKey="activateLicense" enableFormOnSubmit className="space-y-4">

View File

@ -1,12 +1,13 @@
import { LemonButton, LemonModal, Spinner } from '@posthog/lemon-ui'
import { Elements, PaymentElement, useElements, useStripe } from '@stripe/react-stripe-js'
import { loadStripe } from '@stripe/stripe-js'
import { useActions, useValues } from 'kea'
import { useEffect } from 'react'
import { WavingHog } from 'lib/components/hedgehogs'
import { useEffect, useState } from 'react'
import { urls } from 'scenes/urls'
import { paymentEntryLogic } from './paymentEntryLogic'
const stripePromise = loadStripe(window.STRIPE_PUBLIC_KEY!)
const stripeJs = async (): Promise<typeof import('@stripe/stripe-js')> => await import('@stripe/stripe-js')
export const PaymentForm = (): JSX.Element => {
const { error, isLoading } = useValues(paymentEntryLogic)
@ -34,13 +35,17 @@ export const PaymentForm = (): JSX.Element => {
setLoading(false)
setError(result.error.message)
} else {
pollAuthorizationStatus()
pollAuthorizationStatus(result.paymentIntent.id)
}
}
return (
<div>
<PaymentElement />
<p className="text-xs text-muted mt-0.5">
Your card will not be charged but we place a $0.50 hold on it to verify your card that will be released
in 7 days.
</p>
{error && <div className="error">{error}</div>}
<div className="flex justify-end space-x-2 mt-2">
<LemonButton disabled={isLoading} type="secondary" onClick={hidePaymentEntryModal}>
@ -58,21 +63,34 @@ interface PaymentEntryModalProps {
redirectPath?: string | null
}
export const PaymentEntryModal = ({ redirectPath = null }: PaymentEntryModalProps): JSX.Element | null => {
export const PaymentEntryModal = ({
redirectPath = urls.organizationBilling(),
}: PaymentEntryModalProps): JSX.Element => {
const { clientSecret, paymentEntryModalOpen } = useValues(paymentEntryLogic)
const { hidePaymentEntryModal, initiateAuthorization } = useActions(paymentEntryLogic)
const [stripePromise, setStripePromise] = useState<any>(null)
useEffect(() => {
// Load Stripe.js asynchronously
const loadStripeJs = async (): Promise<void> => {
const { loadStripe } = await stripeJs()
const publicKey = window.STRIPE_PUBLIC_KEY!
setStripePromise(await loadStripe(publicKey))
}
void loadStripeJs()
}, [])
useEffect(() => {
initiateAuthorization(redirectPath)
}, [redirectPath])
}, [initiateAuthorization, redirectPath])
return (
<LemonModal
onClose={hidePaymentEntryModal}
width="max(44vw)"
isOpen={paymentEntryModalOpen}
title="Add your payment details"
description="Your card will not be charged."
title="Add your payment details to subscribe"
description=""
>
<div>
{clientSecret ? (
@ -80,9 +98,13 @@ export const PaymentEntryModal = ({ redirectPath = null }: PaymentEntryModalProp
<PaymentForm />
</Elements>
) : (
<div className="min-h-40 flex justify-center items-center">
<div className="text-4xl">
<Spinner />
<div className="min-h-80 flex flex-col justify-center items-center">
<p className="text-muted text-md mt-4">We're contacting the Hedgehogs for approval.</p>
<div className="flex items-center space-x-2">
<div className="text-4xl">
<Spinner />
</div>
<WavingHog className="w-18 h-18" />
</div>
</div>
)}

View File

@ -12,7 +12,7 @@ export const paymentEntryLogic = kea<paymentEntryLogicType>({
setLoading: (loading) => ({ loading }),
setError: (error) => ({ error }),
initiateAuthorization: (redirectPath: string | null) => ({ redirectPath }),
pollAuthorizationStatus: true,
pollAuthorizationStatus: (paymentIntentId?: string) => ({ paymentIntentId }),
setAuthorizationStatus: (status: string | null) => ({ status }),
showPaymentEntryModal: true,
hidePaymentEntryModal: true,
@ -73,7 +73,7 @@ export const paymentEntryLogic = kea<paymentEntryLogicType>({
}
},
pollAuthorizationStatus: async () => {
pollAuthorizationStatus: async ({ paymentIntentId }) => {
const pollInterval = 2000 // Poll every 2 seconds
const maxAttempts = 30 // Max 1 minute of polling (30 * 2 seconds)
let attempts = 0
@ -81,9 +81,9 @@ export const paymentEntryLogic = kea<paymentEntryLogicType>({
const poll = async (): Promise<void> => {
try {
const urlParams = new URLSearchParams(window.location.search)
const paymentIntentId = urlParams.get('payment_intent')
const searchPaymentIntentId = urlParams.get('payment_intent')
const response = await api.create('api/billing/activate/authorize/status', {
payment_intent_id: paymentIntentId,
payment_intent_id: paymentIntentId || searchPaymentIntentId,
})
const status = response.status

View File

@ -1,19 +1,10 @@
import { Monaco } from '@monaco-editor/react'
import { BindLogic, useActions, useValues } from 'kea'
import { useActions, useValues } from 'kea'
import { router } from 'kea-router'
import {
activemodelStateKey,
codeEditorLogic,
CodeEditorLogicProps,
editorModelsStateKey,
} from 'lib/monaco/codeEditorLogic'
import type { editor as importedEditor, Uri } from 'monaco-editor'
import { useCallback, useEffect, useState } from 'react'
import { dataNodeLogic } from '~/queries/nodes/DataNode/dataNodeLogic'
import { hogQLQueryEditorLogic } from '~/queries/nodes/HogQLQuery/hogQLQueryEditorLogic'
import { HogQLQuery, NodeKind } from '~/queries/schema'
import type { editor as importedEditor } from 'monaco-editor'
import { useState } from 'react'
import { multitabEditorLogic } from './multitabEditorLogic'
import { QueryPane } from './QueryPane'
import { QueryTabs } from './QueryTabs'
import { ResultPane } from './ResultPane'
@ -24,152 +15,54 @@ export function QueryWindow(): JSX.Element {
)
const [monaco, editor] = monacoAndEditor ?? []
const key = router.values.location.pathname
const [query, setActiveQueryInput] = useState<HogQLQuery>({
kind: NodeKind.HogQLQuery,
query: '',
})
const hogQLQueryEditorLogicProps = {
query,
setQuery: (query: HogQLQuery) => {
setActiveQueryInput(query)
},
onChange: () => {},
key,
}
const logic = hogQLQueryEditorLogic(hogQLQueryEditorLogicProps)
const { queryInput, promptError } = useValues(logic)
const { setQueryInput, saveQuery, saveAsView } = useActions(logic)
const codeEditorKey = `hogQLQueryEditor/${router.values.location.pathname}`
const codeEditorLogicProps: CodeEditorLogicProps = {
const logic = multitabEditorLogic({
key: codeEditorKey,
sourceQuery: query,
query: queryInput,
language: 'hogQL',
metadataFilters: query.filters,
monaco,
editor,
multitab: true,
}
const { activeModelUri, allModels, hasErrors, error, isValidView } = useValues(
codeEditorLogic(codeEditorLogicProps)
)
const { createModel, setModel, deleteModel, setModels, addModel, updateState } = useActions(
codeEditorLogic(codeEditorLogicProps)
)
const modelKey = `hogQLQueryEditor/${activeModelUri?.path}`
useEffect(() => {
if (monaco && activeModelUri) {
const _model = monaco.editor.getModel(activeModelUri)
const val = _model?.getValue()
setQueryInput(val ?? '')
saveQuery()
}
}, [activeModelUri])
const onAdd = useCallback(() => {
createModel()
}, [createModel])
})
const { allTabs, activeModelUri, queryInput, activeQuery, activeTabKey, hasErrors, error, isValidView } =
useValues(logic)
const { selectTab, deleteTab, createTab, setQueryInput, runQuery, saveAsView } = useActions(logic)
return (
<div className="flex flex-1 flex-col h-full">
<QueryTabs
models={allModels}
onClick={setModel}
onClear={deleteModel}
onAdd={onAdd}
models={allTabs}
onClick={selectTab}
onClear={deleteTab}
onAdd={createTab}
activeModelUri={activeModelUri}
/>
<QueryPane
queryInput={queryInput}
promptError={promptError}
promptError={null}
codeEditorProps={{
onChange: (v) => {
setQueryInput(v ?? '')
updateState()
},
onMount: (editor, monaco) => {
setMonacoAndEditor([monaco, editor])
const allModelQueries = localStorage.getItem(editorModelsStateKey(codeEditorKey))
const activeModelUri = localStorage.getItem(activemodelStateKey(codeEditorKey))
if (allModelQueries) {
// clear existing models
monaco.editor.getModels().forEach((model) => {
model.dispose()
})
const models = JSON.parse(allModelQueries || '[]')
const newModels: Uri[] = []
models.forEach((model: Record<string, any>) => {
if (monaco) {
const uri = monaco.Uri.parse(model.path)
const newModel = monaco.editor.createModel(model.query, 'hogQL', uri)
editor?.setModel(newModel)
newModels.push(uri)
}
})
setModels(newModels)
if (activeModelUri) {
const uri = monaco.Uri.parse(activeModelUri)
const activeModel = monaco.editor
.getModels()
.find((model) => model.uri.path === uri.path)
activeModel && editor?.setModel(activeModel)
const val = activeModel?.getValue()
if (val) {
setQueryInput(val)
saveQuery()
}
setModel(uri)
} else if (newModels.length) {
setModel(newModels[0])
}
} else {
const model = editor.getModel()
if (model) {
addModel(model.uri)
setModel(model.uri)
}
}
},
onPressCmdEnter: (value, selectionType) => {
if (value && selectionType === 'selection') {
saveQuery(value)
runQuery(value)
} else {
saveQuery()
runQuery()
}
},
}}
/>
<BindLogic
logic={dataNodeLogic}
props={{
key: modelKey,
query: query,
doNotLoad: !query.query,
}}
>
<ResultPane
onQueryInputChange={saveQuery}
onSave={saveAsView}
saveDisabledReason={
hasErrors ? error ?? 'Query has errors' : !isValidView ? 'All fields must have an alias' : ''
}
/>
</BindLogic>
<ResultPane
logicKey={activeTabKey}
query={activeQuery ?? ''}
onQueryInputChange={runQuery}
onSave={saveAsView}
saveDisabledReason={
hasErrors ? error ?? 'Query has errors' : !isValidView ? 'All fields must have an alias' : ''
}
/>
</div>
)
}

View File

@ -7,6 +7,7 @@ import DataGrid from 'react-data-grid'
import { themeLogic } from '~/layout/navigation-3000/themeLogic'
import { dataNodeLogic } from '~/queries/nodes/DataNode/dataNodeLogic'
import { NodeKind } from '~/queries/schema'
enum ResultsTab {
Results = 'results',
@ -17,11 +18,28 @@ interface ResultPaneProps {
onSave: () => void
saveDisabledReason?: string
onQueryInputChange: () => void
logicKey: string
query: string
}
export function ResultPane({ onQueryInputChange, onSave, saveDisabledReason }: ResultPaneProps): JSX.Element {
export function ResultPane({
onQueryInputChange,
onSave,
saveDisabledReason,
logicKey,
query,
}: ResultPaneProps): JSX.Element {
const { isDarkModeOn } = useValues(themeLogic)
const { response, responseLoading } = useValues(dataNodeLogic)
const { response, responseLoading } = useValues(
dataNodeLogic({
key: logicKey,
query: {
kind: NodeKind.HogQLQuery,
query,
},
doNotLoad: !query,
})
)
const columns = useMemo(() => {
return (

View File

@ -0,0 +1,332 @@
import { Monaco } from '@monaco-editor/react'
import { LemonDialog, LemonInput } from '@posthog/lemon-ui'
import { actions, kea, listeners, path, props, propsChanged, reducers, selectors } from 'kea'
import { subscriptions } from 'kea-subscriptions'
import { LemonField } from 'lib/lemon-ui/LemonField'
import { ModelMarker } from 'lib/monaco/codeEditorLogic'
import { editor, MarkerSeverity, Uri } from 'monaco-editor'
import { dataNodeLogic } from '~/queries/nodes/DataNode/dataNodeLogic'
import { performQuery } from '~/queries/query'
import { HogLanguage, HogQLMetadata, HogQLMetadataResponse, HogQLNotice, HogQLQuery, NodeKind } from '~/queries/schema'
import { dataWarehouseViewsLogic } from '../saved_queries/dataWarehouseViewsLogic'
import type { multitabEditorLogicType } from './multitabEditorLogicType'
export interface MultitabEditorLogicProps {
key: string
monaco?: Monaco | null
editor?: editor.IStandaloneCodeEditor | null
}
export const editorModelsStateKey = (key: string | number): string => `${key}/editorModelQueries`
export const activemodelStateKey = (key: string | number): string => `${key}/activeModelUri`
export const multitabEditorLogic = kea<multitabEditorLogicType>([
path(['data-warehouse', 'editor', 'multitabEditorLogic']),
props({} as MultitabEditorLogicProps),
actions({
setQueryInput: (queryInput: string) => ({ queryInput }),
updateState: true,
runQuery: (queryOverride?: string) => ({ queryOverride }),
setActiveQuery: (query: string) => ({ query }),
setTabs: (tabs: Uri[]) => ({ tabs }),
addTab: (tab: Uri) => ({ tab }),
createTab: () => null,
deleteTab: (tab: Uri) => ({ tab }),
removeTab: (tab: Uri) => ({ tab }),
selectTab: (tab: Uri) => ({ tab }),
setLocalState: (key: string, value: any) => ({ key, value }),
initialize: true,
saveAsView: true,
saveAsViewSuccess: (name: string) => ({ name }),
reloadMetadata: true,
setMetadata: (query: string, metadata: HogQLMetadataResponse) => ({ query, metadata }),
}),
propsChanged(({ actions }, oldProps) => {
if (!oldProps.monaco && !oldProps.editor) {
actions.initialize()
}
}),
reducers(({ props }) => ({
queryInput: [
'',
{
setQueryInput: (_, { queryInput }) => queryInput,
},
],
activeQuery: [
null as string | null,
{
setActiveQuery: (_, { query }) => query,
},
],
activeModelUri: [
null as Uri | null,
{
selectTab: (_, { tab }) => tab,
},
],
allTabs: [
[] as Uri[],
{
addTab: (state, { tab }) => {
const newTabs = [...state, tab]
return newTabs
},
removeTab: (state, { tab: tabToRemove }) => {
const newModels = state.filter((tab) => tab.toString() !== tabToRemove.toString())
return newModels
},
setTabs: (_, { tabs }) => tabs,
},
],
metadata: [
null as null | [string, HogQLMetadataResponse],
{
setMetadata: (_, { query, metadata }) => [query, metadata],
},
],
modelMarkers: [
[] as ModelMarker[],
{
setMetadata: (_, { query, metadata }) => {
const model = props.editor?.getModel()
if (!model || !metadata) {
return []
}
const markers: ModelMarker[] = []
const metadataResponse = metadata
function noticeToMarker(error: HogQLNotice, severity: MarkerSeverity): ModelMarker {
const start = model!.getPositionAt(error.start ?? 0)
const end = model!.getPositionAt(error.end ?? query.length)
return {
start: error.start ?? 0,
startLineNumber: start.lineNumber,
startColumn: start.column,
end: error.end ?? query.length,
endLineNumber: end.lineNumber,
endColumn: end.column,
message: error.message ?? 'Unknown error',
severity: severity,
hogQLFix: error.fix,
}
}
for (const notice of metadataResponse?.errors ?? []) {
markers.push(noticeToMarker(notice, 8 /* MarkerSeverity.Error */))
}
for (const notice of metadataResponse?.warnings ?? []) {
markers.push(noticeToMarker(notice, 4 /* MarkerSeverity.Warning */))
}
for (const notice of metadataResponse?.notices ?? []) {
markers.push(noticeToMarker(notice, 1 /* MarkerSeverity.Hint */))
}
props.monaco?.editor.setModelMarkers(model, 'hogql', markers)
return markers
},
},
],
})),
listeners(({ values, props, actions }) => ({
createTab: () => {
let currentModelCount = 1
const allNumbers = values.allTabs.map((tab) => parseInt(tab.path.split('/').pop() || '0'))
while (allNumbers.includes(currentModelCount)) {
currentModelCount++
}
if (props.monaco) {
const uri = props.monaco.Uri.parse(currentModelCount.toString())
const model = props.monaco.editor.createModel('', 'hogQL', uri)
props.editor?.setModel(model)
actions.addTab(uri)
actions.selectTab(uri)
const queries = values.allTabs.map((tab) => {
return {
query: props.monaco?.editor.getModel(tab)?.getValue() || '',
path: tab.path.split('/').pop(),
}
})
actions.setLocalState(editorModelsStateKey(props.key), JSON.stringify(queries))
}
},
selectTab: ({ tab }) => {
if (props.monaco) {
const model = props.monaco.editor.getModel(tab)
props.editor?.setModel(model)
}
const path = tab.path.split('/').pop()
path && actions.setLocalState(activemodelStateKey(props.key), path)
},
deleteTab: ({ tab: tabToRemove }) => {
if (props.monaco) {
const model = props.monaco.editor.getModel(tabToRemove)
if (tabToRemove == values.activeModelUri) {
const indexOfModel = values.allTabs.findIndex((tab) => tab.toString() === tabToRemove.toString())
const nextModel =
values.allTabs[indexOfModel + 1] || values.allTabs[indexOfModel - 1] || values.allTabs[0] // there will always be one
actions.selectTab(nextModel)
}
model?.dispose()
actions.removeTab(tabToRemove)
const queries = values.allTabs.map((tab) => {
return {
query: props.monaco?.editor.getModel(tab)?.getValue() || '',
path: tab.path.split('/').pop(),
}
})
actions.setLocalState(editorModelsStateKey(props.key), JSON.stringify(queries))
}
},
setLocalState: ({ key, value }) => {
localStorage.setItem(key, value)
},
initialize: () => {
const allModelQueries = localStorage.getItem(editorModelsStateKey(props.key))
const activeModelUri = localStorage.getItem(activemodelStateKey(props.key))
if (allModelQueries) {
// clear existing models
props.monaco?.editor.getModels().forEach((model: editor.ITextModel) => {
model.dispose()
})
const models = JSON.parse(allModelQueries || '[]')
const newModels: Uri[] = []
models.forEach((model: Record<string, any>) => {
if (props.monaco) {
const uri = props.monaco.Uri.parse(model.path)
const newModel = props.monaco.editor.createModel(model.query, 'hogQL', uri)
props.editor?.setModel(newModel)
newModels.push(uri)
}
})
actions.setTabs(newModels)
if (activeModelUri) {
const uri = props.monaco?.Uri.parse(activeModelUri)
const activeModel = props.monaco?.editor
.getModels()
.find((model: editor.ITextModel) => model.uri.path === uri?.path)
activeModel && props.editor?.setModel(activeModel)
const val = activeModel?.getValue()
if (val) {
actions.setQueryInput(val)
actions.runQuery()
}
uri && actions.selectTab(uri)
} else if (newModels.length) {
actions.selectTab(newModels[0])
}
} else {
const model = props.editor?.getModel()
if (model) {
actions.createTab()
}
}
},
setQueryInput: () => {
actions.updateState()
},
updateState: async (_, breakpoint) => {
await breakpoint(100)
const queries = values.allTabs.map((model) => {
return {
query: props.monaco?.editor.getModel(model)?.getValue() || '',
path: model.path.split('/').pop(),
}
})
localStorage.setItem(editorModelsStateKey(props.key), JSON.stringify(queries))
},
runQuery: ({ queryOverride }) => {
actions.setActiveQuery(queryOverride || values.queryInput)
},
saveAsView: async () => {
LemonDialog.openForm({
title: 'Save as view',
initialValues: { viewName: '' },
content: (
<LemonField name="viewName">
<LemonInput placeholder="Please enter the name of the view" autoFocus />
</LemonField>
),
errors: {
viewName: (name) => (!name ? 'You must enter a name' : undefined),
},
onSubmit: ({ viewName }) => actions.saveAsViewSuccess(viewName),
})
},
saveAsViewSuccess: async ({ name }) => {
const query: HogQLQuery = {
kind: NodeKind.HogQLQuery,
query: values.queryInput,
}
await dataWarehouseViewsLogic.asyncActions.createDataWarehouseSavedQuery({ name, query })
},
reloadMetadata: async (_, breakpoint) => {
const model = props.editor?.getModel()
if (!model || !props.monaco) {
return
}
await breakpoint(300)
const query = values.queryInput
if (query === '') {
return
}
const response = await performQuery<HogQLMetadata>({
kind: NodeKind.HogQLMetadata,
language: HogLanguage.hogQL,
query: query,
})
breakpoint()
actions.setMetadata(query, response)
},
})),
subscriptions(({ props, actions, values }) => ({
activeModelUri: (activeModelUri) => {
if (props.monaco) {
const _model = props.monaco.editor.getModel(activeModelUri)
const val = _model?.getValue()
actions.setQueryInput(val ?? '')
actions.runQuery()
dataNodeLogic({
key: values.activeTabKey,
query: {
kind: NodeKind.HogQLQuery,
query: val ?? '',
},
doNotLoad: !val,
}).mount()
}
},
queryInput: () => {
actions.reloadMetadata()
},
})),
selectors({
activeTabKey: [(s) => [s.activeModelUri], (activeModelUri) => `hogQLQueryEditor/${activeModelUri?.path}`],
isValidView: [(s) => [s.metadata], (metadata) => !!(metadata && metadata[1]?.isValidView)],
hasErrors: [
(s) => [s.modelMarkers],
(modelMarkers) => !!(modelMarkers ?? []).filter((e) => e.severity === 8 /* MarkerSeverity.Error */).length,
],
error: [
(s) => [s.hasErrors, s.modelMarkers],
(hasErrors, modelMarkers) => {
const firstError = modelMarkers.find((e) => e.severity === 8 /* MarkerSeverity.Error */)
return hasErrors && firstError
? `Error on line ${firstError.startLineNumber}, column ${firstError.startColumn}`
: null
},
],
}),
])

View File

@ -302,7 +302,7 @@ export const featureFlagLogic = kea<featureFlagLogicType>([
}),
forms(({ actions, values }) => ({
featureFlag: {
defaults: { ...NEW_FLAG } as FeatureFlagType,
defaults: { ...NEW_FLAG },
errors: ({ key, filters }) => {
return {
key: validateFeatureFlagKey(key),

View File

@ -26,6 +26,7 @@ import { ExperimentsSDKInstructions } from './sdks/experiments/ExperimentsSDKIns
import { FeatureFlagsSDKInstructions } from './sdks/feature-flags/FeatureFlagsSDKInstructions'
import { ProductAnalyticsSDKInstructions } from './sdks/product-analytics/ProductAnalyticsSDKInstructions'
import { SDKs } from './sdks/SDKs'
import { sdksLogic } from './sdks/sdksLogic'
import { SessionReplaySDKInstructions } from './sdks/session-replay/SessionReplaySDKInstructions'
import { SurveysSDKInstructions } from './sdks/surveys/SurveysSDKInstructions'
@ -105,12 +106,16 @@ const OnboardingWrapper = ({ children }: { children: React.ReactNode }): JSX.Ele
const ProductAnalyticsOnboarding = (): JSX.Element => {
const { currentTeam } = useValues(teamLogic)
const { featureFlags } = useValues(featureFlagLogic)
const { combinedSnippetAndLiveEventsHosts } = useValues(sdksLogic)
// mount the logic here so that it stays mounted for the entire onboarding flow
// not sure if there is a better way to do this
useValues(newDashboardLogic)
const showTemplateSteps =
featureFlags[FEATURE_FLAGS.ONBOARDING_DASHBOARD_TEMPLATES] == 'test' && window.innerWidth > 1000
featureFlags[FEATURE_FLAGS.ONBOARDING_DASHBOARD_TEMPLATES] == 'test' &&
window.innerWidth > 1000 &&
combinedSnippetAndLiveEventsHosts.length > 0
const options: ProductConfigOption[] = [
{

View File

@ -52,8 +52,8 @@ const UrlInput = ({ iframeRef }: { iframeRef: React.RefObject<HTMLIFrameElement>
return (
<div className="w-full flex gap-x-2 border-b border-1 border-border-bold p-2">
<LemonInput
size="small"
className="grow font-mono text-sm"
size="medium"
className="grow font-mono text-sm pl-0.5"
defaultValue={currentPath}
value={inputValue}
onChange={(v) => setInputValue(v)}

View File

@ -0,0 +1,88 @@
import { Meta, StoryFn, StoryObj } from '@storybook/react'
import { BindLogic, useActions, useValues } from 'kea'
import { useEffect } from 'react'
import recordingEventsJson from 'scenes/session-recordings/__mocks__/recording_events_query'
import recordingMetaJson from 'scenes/session-recordings/__mocks__/recording_meta.json'
import { snapshotsAsJSONLines } from 'scenes/session-recordings/__mocks__/recording_snapshots'
import { PlayerInspector } from 'scenes/session-recordings/player/inspector/PlayerInspector'
import { sessionRecordingDataLogic } from 'scenes/session-recordings/player/sessionRecordingDataLogic'
import { sessionRecordingPlayerLogic } from 'scenes/session-recordings/player/sessionRecordingPlayerLogic'
import { mswDecorator } from '~/mocks/browser'
type Story = StoryObj<typeof PlayerInspector>
const meta: Meta<typeof PlayerInspector> = {
title: 'Components/PlayerInspector',
component: PlayerInspector,
decorators: [
mswDecorator({
get: {
'/api/environments/:team_id/session_recordings/:id': recordingMetaJson,
'/api/environments/:team_id/session_recordings/:id/snapshots': (req, res, ctx) => {
// with no sources, returns sources...
if (req.url.searchParams.get('source') === 'blob') {
return res(ctx.text(snapshotsAsJSONLines()))
}
// with no source requested should return sources
return [
200,
{
sources: [
{
source: 'blob',
start_timestamp: '2023-08-11T12:03:36.097000Z',
end_timestamp: '2023-08-11T12:04:52.268000Z',
blob_key: '1691755416097-1691755492268',
},
],
},
]
},
},
post: {
'/api/environments/:team_id/query': (req, res, ctx) => {
const body = req.body as Record<string, any>
if (body.query.kind === 'EventsQuery' && body.query.properties.length === 1) {
return res(ctx.json(recordingEventsJson))
}
// default to an empty response or we duplicate information
return res(ctx.json({ results: [] }))
},
},
}),
],
}
export default meta
const BasicTemplate: StoryFn<typeof PlayerInspector> = () => {
const dataLogic = sessionRecordingDataLogic({ sessionRecordingId: '12345', playerKey: 'story-template' })
const { sessionPlayerMetaData } = useValues(dataLogic)
const { loadSnapshots, loadEvents } = useActions(dataLogic)
loadSnapshots()
// TODO you have to call actions in a particular order
// and only when some other data has already been loaded
// 🫠
useEffect(() => {
loadEvents()
}, [sessionPlayerMetaData])
return (
<div className="flex flex-col gap-2 min-w-96 min-h-120">
<BindLogic
logic={sessionRecordingPlayerLogic}
props={{
sessionRecordingId: '12345',
playerKey: 'story-template',
}}
>
<PlayerInspector />
</BindLogic>
</div>
)
}
export const Default: Story = BasicTemplate.bind({})
Default.args = {}

View File

@ -0,0 +1,11 @@
import { PlayerInspectorControls } from 'scenes/session-recordings/player/inspector/PlayerInspectorControls'
import { PlayerInspectorList } from 'scenes/session-recordings/player/inspector/PlayerInspectorList'
export function PlayerInspector(): JSX.Element {
return (
<>
<PlayerInspectorControls />
<PlayerInspectorList />
</>
)
}

View File

@ -1,9 +1,8 @@
import { useValues } from 'kea'
import { PlayerInspector } from 'scenes/session-recordings/player/inspector/PlayerInspector'
import { SessionRecordingSidebarTab } from '~/types'
import { PlayerInspectorControls } from '../inspector/PlayerInspectorControls'
import { PlayerInspectorList } from '../inspector/PlayerInspectorList'
import { PlayerSidebarDebuggerTab } from './PlayerSidebarDebuggerTab'
import { playerSidebarLogic } from './playerSidebarLogic'
import { PlayerSidebarOverviewTab } from './PlayerSidebarOverviewTab'
@ -15,12 +14,7 @@ export function PlayerSidebarTab(): JSX.Element | null {
case SessionRecordingSidebarTab.OVERVIEW:
return <PlayerSidebarOverviewTab />
case SessionRecordingSidebarTab.INSPECTOR:
return (
<>
<PlayerInspectorControls />
<PlayerInspectorList />
</>
)
return <PlayerInspector />
case SessionRecordingSidebarTab.DEBUGGER:
return <PlayerSidebarDebuggerTab />
default:

View File

@ -10,7 +10,7 @@ import { Group } from 'kea-forms'
import { SortableDragIcon } from 'lib/lemon-ui/icons'
import { LemonField } from 'lib/lemon-ui/LemonField'
import { Survey, SurveyQuestionType } from '~/types'
import { Survey, SurveyQuestionType, SurveyType } from '~/types'
import { defaultSurveyFieldValues, NewSurvey, SurveyQuestionLabel } from './constants'
import { QuestionBranchingInput } from './QuestionBranchingInput'
@ -315,7 +315,7 @@ export function SurveyEditQuestionGroup({ index, question }: { index: number; qu
)
})}
<div className="w-fit flex flex-row flex-wrap gap-2">
{(value || []).length < 6 && (
{((value || []).length < 6 || survey.type != SurveyType.Popover) && (
<>
<LemonButton
icon={<IconPlusSmall />}

View File

@ -55,7 +55,7 @@ const DEFAULT_ENTERED_TIMESTAMP: EnteredTimestamp = EnteredTimestamp {
};
pub fn process_line(line: &str) -> Value {
let args = parse_args(&line);
let args = parse_args(line);
let mut aggregate_funnel_row = AggregateFunnelRow {
results: Vec::with_capacity(args.prop_vals.len()),
breakdown_step: Option::None,
@ -112,7 +112,7 @@ impl AggregateFunnelRow {
self.process_event(
args,
&mut vars,
&events_with_same_timestamp[0],
events_with_same_timestamp[0],
prop_val,
false
);
@ -147,7 +147,7 @@ impl AggregateFunnelRow {
args,
&mut vars,
&event,
&prop_val,
prop_val,
true
);
}
@ -261,4 +261,4 @@ impl AggregateFunnelRow {
}
}
}
}
}

View File

@ -81,7 +81,7 @@ const DEFAULT_ENTERED_TIMESTAMP: EnteredTimestamp = EnteredTimestamp {
};
pub fn process_line(line: &str) -> Value {
let args = parse_args(&line);
let args = parse_args(line);
let mut aggregate_funnel_row = AggregateFunnelRow {
results: HashMap::new(),
breakdown_step: Option::None,
@ -128,7 +128,7 @@ impl AggregateFunnelRow {
self.process_event(
args,
&mut vars,
&event,
event,
prop_val,
);
}
@ -242,4 +242,4 @@ impl AggregateFunnelRow {
}
}
}
}
}

View File

@ -203,6 +203,7 @@ class EarlyAccessFeatureSerializerCreateOnly(EarlyAccessFeatureSerializer):
"key": feature_flag_key,
"name": f"Feature Flag for Feature {validated_data['name']}",
"filters": filters,
"creation_context": "early_access_features",
},
context=self.context,
)

View File

@ -115,6 +115,14 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo
)
can_edit = serializers.SerializerMethodField()
CREATION_CONTEXT_CHOICES = ("feature_flags", "experiments", "surveys", "early_access_features", "web_experiments")
creation_context = serializers.ChoiceField(
choices=CREATION_CONTEXT_CHOICES,
write_only=True,
required=False,
help_text="Indicates the origin product of the feature flag. Choices: 'feature_flags', 'experiments', 'surveys', 'early_access_features', 'web_experiments'.",
)
class Meta:
model = FeatureFlag
fields = [
@ -139,6 +147,7 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo
"usage_dashboard",
"analytics_dashboards",
"has_enriched_analytics",
"creation_context",
]
def get_can_edit(self, feature_flag: FeatureFlag) -> bool:
@ -317,6 +326,9 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo
validated_data["created_by"] = request.user
validated_data["team_id"] = self.context["team_id"]
tags = validated_data.pop("tags", None) # tags are created separately below as global tag relationships
creation_context = validated_data.pop(
"creation_context", "feature_flags"
) # default to "feature_flags" if an alternative value is not provided
self._update_filters(validated_data)
@ -347,7 +359,9 @@ class FeatureFlagSerializer(TaggedItemSerializerMixin, serializers.HyperlinkedMo
_create_usage_dashboard(instance, request.user)
report_user_action(request.user, "feature flag created", instance.get_analytics_metadata())
analytics_metadata = instance.get_analytics_metadata()
analytics_metadata["creation_context"] = creation_context
report_user_action(request.user, "feature flag created", analytics_metadata)
return instance

View File

@ -640,6 +640,7 @@ class SurveySerializerCreateUpdateOnly(serializers.ModelSerializer):
"name": f"Targeting flag for survey {name}",
"filters": filters,
"active": active,
"creation_context": "surveys",
},
context=self.context,
)

View File

@ -3,6 +3,7 @@ from unittest.mock import ANY
from rest_framework import status
from django.core.cache import cache
from django.test.client import Client
from unittest.mock import patch
from posthog.models.early_access_feature import EarlyAccessFeature
from posthog.models import FeatureFlag, Person
@ -520,6 +521,36 @@ class TestEarlyAccessFeature(APIBaseTest):
],
}
@patch("posthog.api.feature_flag.report_user_action")
def test_creation_context_is_set_to_early_access_features(self, mock_capture):
response = self.client.post(
f"/api/projects/{self.team.id}/early_access_feature/",
data={
"name": "Hick bondoogling",
"description": 'Boondoogle your hicks with one click. Just click "bazinga"!',
"stage": "concept",
},
format="json",
)
response_data = response.json()
ff_instance = FeatureFlag.objects.get(id=response_data["feature_flag"]["id"])
mock_capture.assert_called_once_with(
ANY,
"feature flag created",
{
"groups_count": 1,
"has_variants": False,
"variants_count": 0,
"has_rollout_percentage": False,
"has_filters": False,
"filter_count": 0,
"created_at": ff_instance.created_at,
"aggregating_by_groups": False,
"payload_count": 0,
"creation_context": "early_access_features",
},
)
class TestPreviewList(BaseTest, QueryMatchingTest):
def setUp(self):

View File

@ -300,6 +300,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin):
"created_at": instance.created_at,
"aggregating_by_groups": True,
"payload_count": 0,
"creation_context": "feature_flags",
},
)
@ -334,6 +335,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin):
"created_at": instance.created_at,
"aggregating_by_groups": False,
"payload_count": 0,
"creation_context": "feature_flags",
},
)
@ -385,6 +387,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin):
"created_at": instance.created_at,
"aggregating_by_groups": False,
"payload_count": 0,
"creation_context": "feature_flags",
},
)
@ -438,6 +441,7 @@ class TestFeatureFlag(APIBaseTest, ClickhouseTestMixin):
"created_at": instance.created_at,
"aggregating_by_groups": False,
"payload_count": 0,
"creation_context": "feature_flags",
},
)

View File

@ -60,6 +60,59 @@ class TestSurvey(APIBaseTest):
]
assert response_data["created_by"]["id"] == self.user.id
@patch("posthog.api.feature_flag.report_user_action")
def test_creation_context_is_set_to_surveys(self, mock_capture):
response = self.client.post(
f"/api/projects/{self.team.id}/surveys/",
data={
"name": "survey with targeting",
"type": "popover",
"targeting_flag_filters": {
"groups": [
{
"variant": None,
"rollout_percentage": None,
"properties": [
{
"key": "billing_plan",
"value": ["cloud"],
"operator": "exact",
"type": "person",
}
],
}
]
},
"conditions": {"url": "https://app.posthog.com/notebooks"},
},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
response_data = response.json()
# Ensure that a FeatureFlag has been created
ff_instance = FeatureFlag.objects.get(id=response_data["internal_targeting_flag"]["id"])
self.assertIsNotNone(ff_instance)
# Verify that report_user_action was called for the feature flag creation
mock_capture.assert_any_call(
ANY,
"feature flag created",
{
"groups_count": 1,
"has_variants": False,
"variants_count": 0,
"has_rollout_percentage": True,
"has_filters": True,
"filter_count": 2,
"created_at": ff_instance.created_at,
"aggregating_by_groups": False,
"payload_count": 0,
"creation_context": "surveys",
},
)
def test_create_adds_user_interactivity_filters(self):
response = self.client.post(
f"/api/projects/{self.team.id}/surveys/",

View File

@ -1,6 +1,7 @@
from datetime import datetime, timedelta
from rest_framework import status
from unittest.mock import ANY, patch
from posthog.models import WebExperiment
from posthog.test.base import APIBaseTest
@ -30,7 +31,8 @@ class TestWebExperiment(APIBaseTest):
format="json",
)
def test_can_create_basic_web_experiment(self):
@patch("posthog.api.feature_flag.report_user_action")
def test_can_create_basic_web_experiment(self, mock_capture):
response = self._create_web_experiment()
response_data = response.json()
assert response.status_code == status.HTTP_201_CREATED, response_data
@ -53,6 +55,22 @@ class TestWebExperiment(APIBaseTest):
assert web_experiment.type == "web"
assert web_experiment.variants.get("control") is not None
assert web_experiment.variants.get("test") is not None
mock_capture.assert_called_once_with(
ANY,
"feature flag created",
{
"groups_count": 1,
"has_variants": True,
"variants_count": 2,
"has_rollout_percentage": True,
"has_filters": False,
"filter_count": 0,
"created_at": linked_flag.created_at,
"aggregating_by_groups": False,
"payload_count": 0,
"creation_context": "web_experiments",
},
)
def test_can_list_active_web_experiments(self):
response = self._create_web_experiment("active_web_experiment")

View File

@ -98,6 +98,7 @@ class WebExperimentsAPISerializer(serializers.ModelSerializer):
"name": f'Feature Flag for Experiment {validated_data["name"]}',
"filters": filters,
"active": False,
"creation_context": "web_experiments",
},
context=self.context,
)

6
rust/Cargo.lock generated
View File

@ -1808,7 +1808,9 @@ dependencies = [
"futures",
"health",
"maxminddb",
"moka",
"once_cell",
"petgraph",
"rand",
"redis",
"regex",
@ -3046,9 +3048,13 @@ version = "0.12.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cf62eb4dd975d2dde76432fb1075c49e3ee2331cf36f1f8fd4b66550d32b6f"
dependencies = [
"async-lock 3.4.0",
"async-trait",
"crossbeam-channel",
"crossbeam-epoch",
"crossbeam-utils",
"event-listener 5.3.1",
"futures-util",
"once_cell",
"parking_lot",
"quanta 0.12.2",

View File

@ -39,6 +39,8 @@ health = { path = "../common/health" }
common-metrics = { path = "../common/metrics" }
tower = { workspace = true }
derive_builder = "0.20.1"
petgraph = "0.6.5"
moka = { version = "0.12.8", features = ["future"] }
[lints]
workspace = true

View File

@ -89,7 +89,7 @@ pub enum FlagError {
#[error("Row not found in postgres")]
RowNotFound,
#[error("failed to parse redis cache data")]
DataParsingError,
RedisDataParsingError,
#[error("failed to update redis cache")]
CacheUpdateError,
#[error("redis unavailable")]
@ -102,6 +102,14 @@ pub enum FlagError {
TimeoutError,
#[error("No group type mappings")]
NoGroupTypeMappings,
#[error("Cohort not found")]
CohortNotFound(String),
#[error("Failed to parse cohort filters")]
CohortFiltersParsingError,
#[error("Cohort dependency cycle")]
CohortDependencyCycle(String),
#[error("Person not found")]
PersonNotFound,
}
impl IntoResponse for FlagError {
@ -138,7 +146,7 @@ impl IntoResponse for FlagError {
FlagError::TokenValidationError => {
(StatusCode::UNAUTHORIZED, "The provided API key is invalid or has expired. Please check your API key and try again.".to_string())
}
FlagError::DataParsingError => {
FlagError::RedisDataParsingError => {
tracing::error!("Data parsing error: {:?}", self);
(
StatusCode::SERVICE_UNAVAILABLE,
@ -194,6 +202,21 @@ impl IntoResponse for FlagError {
"The requested row was not found in the database. Please try again later or contact support if the problem persists.".to_string(),
)
}
FlagError::CohortNotFound(msg) => {
tracing::error!("Cohort not found: {}", msg);
(StatusCode::NOT_FOUND, msg)
}
FlagError::CohortFiltersParsingError => {
tracing::error!("Failed to parse cohort filters: {:?}", self);
(StatusCode::BAD_REQUEST, "Failed to parse cohort filters. Please try again later or contact support if the problem persists.".to_string())
}
FlagError::CohortDependencyCycle(msg) => {
tracing::error!("Cohort dependency cycle: {}", msg);
(StatusCode::BAD_REQUEST, msg)
}
FlagError::PersonNotFound => {
(StatusCode::BAD_REQUEST, "Person not found. Please check your distinct_id and try again.".to_string())
}
}
.into_response()
}
@ -205,7 +228,7 @@ impl From<CustomRedisError> for FlagError {
CustomRedisError::NotFound => FlagError::TokenValidationError,
CustomRedisError::PickleError(e) => {
tracing::error!("failed to fetch data: {}", e);
FlagError::DataParsingError
FlagError::RedisDataParsingError
}
CustomRedisError::Timeout(_) => FlagError::TimeoutError,
CustomRedisError::Other(e) => {

View File

@ -0,0 +1,221 @@
use crate::api::FlagError;
use crate::cohort_models::Cohort;
use crate::flag_matching::{PostgresReader, TeamId};
use moka::future::Cache;
use std::time::Duration;
/// CohortCacheManager manages the in-memory cache of cohorts using `moka` for caching.
///
/// Features:
/// - **TTL**: Each cache entry expires after 5 minutes.
/// - **Size-based eviction**: The cache evicts least recently used entries when the maximum capacity is reached.
///
/// ```text
/// CohortCacheManager {
/// postgres_reader: PostgresReader,
/// per_team_cohorts: Cache<TeamId, Vec<Cohort>> {
/// // Example:
/// 2: [
/// Cohort { id: 1, name: "Power Users", filters: {...} },
/// Cohort { id: 2, name: "Churned", filters: {...} }
/// ],
/// 5: [
/// Cohort { id: 3, name: "Beta Users", filters: {...} }
/// ]
/// }
/// }
/// ```
///
#[derive(Clone)]
pub struct CohortCacheManager {
postgres_reader: PostgresReader,
per_team_cohort_cache: Cache<TeamId, Vec<Cohort>>,
}
impl CohortCacheManager {
pub fn new(
postgres_reader: PostgresReader,
max_capacity: Option<u64>,
ttl_seconds: Option<u64>,
) -> Self {
// We use the size of the cohort list (i.e., the number of cohorts for a given team)as the weight of the entry
let weigher =
|_: &TeamId, value: &Vec<Cohort>| -> u32 { value.len().try_into().unwrap_or(u32::MAX) };
let cache = Cache::builder()
.time_to_live(Duration::from_secs(ttl_seconds.unwrap_or(300))) // Default to 5 minutes
.weigher(weigher)
.max_capacity(max_capacity.unwrap_or(10_000)) // Default to 10,000 cohorts
.build();
Self {
postgres_reader,
per_team_cohort_cache: cache,
}
}
/// Retrieves cohorts for a given team.
///
/// If the cohorts are not present in the cache or have expired, it fetches them from the database,
/// caches the result upon successful retrieval, and then returns it.
pub async fn get_cohorts_for_team(&self, team_id: TeamId) -> Result<Vec<Cohort>, FlagError> {
if let Some(cached_cohorts) = self.per_team_cohort_cache.get(&team_id).await {
return Ok(cached_cohorts.clone());
}
let fetched_cohorts = Cohort::list_from_pg(self.postgres_reader.clone(), team_id).await?;
self.per_team_cohort_cache
.insert(team_id, fetched_cohorts.clone())
.await;
Ok(fetched_cohorts)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::cohort_models::Cohort;
use crate::test_utils::{
insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client,
setup_pg_writer_client,
};
use std::sync::Arc;
use tokio::time::{sleep, Duration};
/// Helper function to setup a new team for testing.
async fn setup_test_team(
writer_client: Arc<dyn crate::database::Client + Send + Sync>,
) -> Result<TeamId, anyhow::Error> {
let team = crate::test_utils::insert_new_team_in_pg(writer_client, None).await?;
Ok(team.id)
}
/// Helper function to insert a cohort for a team.
async fn setup_test_cohort(
writer_client: Arc<dyn crate::database::Client + Send + Sync>,
team_id: TeamId,
name: Option<String>,
) -> Result<Cohort, anyhow::Error> {
let filters = serde_json::json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$active", "type": "person", "value": [true], "negation": false, "operator": "exact"}]}]}});
insert_cohort_for_team_in_pg(writer_client, team_id, name, filters, false).await
}
/// Tests that cache entries expire after the specified TTL.
#[tokio::test]
async fn test_cache_expiry() -> Result<(), anyhow::Error> {
let writer_client = setup_pg_writer_client(None).await;
let reader_client = setup_pg_reader_client(None).await;
let team_id = setup_test_team(writer_client.clone()).await?;
let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?;
// Initialize CohortCacheManager with a short TTL for testing
let cohort_cache = CohortCacheManager::new(
reader_client.clone(),
Some(100),
Some(1), // 1-second TTL
);
let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?;
assert_eq!(cohorts.len(), 1);
assert_eq!(cohorts[0].team_id, team_id);
let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
assert!(cached_cohorts.is_some());
// Wait for TTL to expire
sleep(Duration::from_secs(2)).await;
// Attempt to retrieve from cache again
let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
assert!(cached_cohorts.is_none(), "Cache entry should have expired");
Ok(())
}
/// Tests that the cache correctly evicts least recently used entries based on the weigher.
#[tokio::test]
async fn test_cache_weigher() -> Result<(), anyhow::Error> {
let writer_client = setup_pg_writer_client(None).await;
let reader_client = setup_pg_reader_client(None).await;
// Define a smaller max_capacity for testing
let max_capacity: u64 = 3;
let cohort_cache = CohortCacheManager::new(reader_client.clone(), Some(max_capacity), None);
let mut inserted_team_ids = Vec::new();
// Insert multiple teams and their cohorts
for _ in 0..max_capacity {
let team = insert_new_team_in_pg(writer_client.clone(), None).await?;
let team_id = team.id;
inserted_team_ids.push(team_id);
setup_test_cohort(writer_client.clone(), team_id, None).await?;
cohort_cache.get_cohorts_for_team(team_id).await?;
}
cohort_cache.per_team_cohort_cache.run_pending_tasks().await;
let cache_size = cohort_cache.per_team_cohort_cache.entry_count();
assert_eq!(
cache_size, max_capacity,
"Cache size should be equal to max_capacity"
);
let new_team = insert_new_team_in_pg(writer_client.clone(), None).await?;
let new_team_id = new_team.id;
setup_test_cohort(writer_client.clone(), new_team_id, None).await?;
cohort_cache.get_cohorts_for_team(new_team_id).await?;
cohort_cache.per_team_cohort_cache.run_pending_tasks().await;
let cache_size_after = cohort_cache.per_team_cohort_cache.entry_count();
assert_eq!(
cache_size_after, max_capacity,
"Cache size should remain equal to max_capacity after eviction"
);
let evicted_team_id = &inserted_team_ids[0];
let cached_cohorts = cohort_cache
.per_team_cohort_cache
.get(evicted_team_id)
.await;
assert!(
cached_cohorts.is_none(),
"Least recently used cache entry should have been evicted"
);
let cached_new_team = cohort_cache.per_team_cohort_cache.get(&new_team_id).await;
assert!(
cached_new_team.is_some(),
"Newly added cache entry should be present"
);
Ok(())
}
#[tokio::test]
async fn test_get_cohorts_for_team() -> Result<(), anyhow::Error> {
let writer_client = setup_pg_writer_client(None).await;
let reader_client = setup_pg_reader_client(None).await;
let team_id = setup_test_team(writer_client.clone()).await?;
let _cohort = setup_test_cohort(writer_client.clone(), team_id, None).await?;
let cohort_cache = CohortCacheManager::new(reader_client.clone(), None, None);
let cached_cohorts = cohort_cache.per_team_cohort_cache.get(&team_id).await;
assert!(cached_cohorts.is_none(), "Cache should initially be empty");
let cohorts = cohort_cache.get_cohorts_for_team(team_id).await?;
assert_eq!(cohorts.len(), 1);
assert_eq!(cohorts[0].team_id, team_id);
let cached_cohorts = cohort_cache
.per_team_cohort_cache
.get(&team_id)
.await
.unwrap();
assert_eq!(cached_cohorts.len(), 1);
assert_eq!(cached_cohorts[0].team_id, team_id);
Ok(())
}
}

View File

@ -0,0 +1,50 @@
use crate::flag_definitions::PropertyFilter;
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Cohort {
pub id: i32,
pub name: String,
pub description: Option<String>,
pub team_id: i32,
pub deleted: bool,
pub filters: serde_json::Value,
pub query: Option<serde_json::Value>,
pub version: Option<i32>,
pub pending_version: Option<i32>,
pub count: Option<i32>,
pub is_calculating: bool,
pub is_static: bool,
pub errors_calculating: i32,
pub groups: serde_json::Value,
pub created_by_id: Option<i32>,
}
pub type CohortId = i32;
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
pub enum CohortPropertyType {
AND,
OR,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CohortProperty {
pub properties: InnerCohortProperty,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InnerCohortProperty {
#[serde(rename = "type")]
pub prop_type: CohortPropertyType,
pub values: Vec<CohortValues>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CohortValues {
#[serde(rename = "type")]
pub prop_type: String,
pub values: Vec<PropertyFilter>,
}

View File

@ -0,0 +1,369 @@
use std::collections::HashSet;
use std::sync::Arc;
use tracing::instrument;
use crate::cohort_models::{Cohort, CohortId, CohortProperty, InnerCohortProperty};
use crate::{api::FlagError, database::Client as DatabaseClient, flag_definitions::PropertyFilter};
impl Cohort {
/// Returns a cohort from postgres given a cohort_id and team_id
#[instrument(skip_all)]
pub async fn from_pg(
client: Arc<dyn DatabaseClient + Send + Sync>,
cohort_id: i32,
team_id: i32,
) -> Result<Cohort, FlagError> {
let mut conn = client.get_connection().await.map_err(|e| {
tracing::error!("Failed to get database connection: {}", e);
// TODO should I model my errors more generally? Like, yes, everything behind this API is technically a FlagError,
// but I'm not sure if accessing Cohort definitions should be a FlagError (vs idk, a CohortError? A more general API error?)
FlagError::DatabaseUnavailable
})?;
let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE id = $1 AND team_id = $2";
let cohort = sqlx::query_as::<_, Cohort>(query)
.bind(cohort_id)
.bind(team_id)
.fetch_optional(&mut *conn)
.await
.map_err(|e| {
tracing::error!("Failed to fetch cohort from database: {}", e);
FlagError::Internal(format!("Database query error: {}", e))
})?;
cohort.ok_or_else(|| {
FlagError::CohortNotFound(format!(
"Cohort with id {} not found for team {}",
cohort_id, team_id
))
})
}
/// Returns all cohorts for a given team
#[instrument(skip_all)]
pub async fn list_from_pg(
client: Arc<dyn DatabaseClient + Send + Sync>,
team_id: i32,
) -> Result<Vec<Cohort>, FlagError> {
let mut conn = client.get_connection().await.map_err(|e| {
tracing::error!("Failed to get database connection: {}", e);
FlagError::DatabaseUnavailable
})?;
let query = "SELECT id, name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id FROM posthog_cohort WHERE team_id = $1";
let cohorts = sqlx::query_as::<_, Cohort>(query)
.bind(team_id)
.fetch_all(&mut *conn)
.await
.map_err(|e| {
tracing::error!("Failed to fetch cohorts from database: {}", e);
FlagError::Internal(format!("Database query error: {}", e))
})?;
Ok(cohorts)
}
/// Parses the filters JSON into a CohortProperty structure
// TODO: this doesn't handle the deprecated "groups" field, see
// https://github.com/PostHog/posthog/blob/feat/dynamic-cohorts-rust/posthog/models/cohort/cohort.py#L114-L169
// I'll handle that in a separate PR.
pub fn parse_filters(&self) -> Result<Vec<PropertyFilter>, FlagError> {
let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone())
.map_err(|e| {
tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e);
FlagError::CohortFiltersParsingError
})?;
Ok(cohort_property
.properties
.to_property_filters()
.into_iter()
.filter(|f| !(f.key == "id" && f.prop_type == "cohort"))
.collect())
}
/// Extracts dependent CohortIds from the cohort's filters
pub fn extract_dependencies(&self) -> Result<HashSet<CohortId>, FlagError> {
let cohort_property: CohortProperty = serde_json::from_value(self.filters.clone())
.map_err(|e| {
tracing::error!("Failed to parse filters for cohort {}: {}", self.id, e);
FlagError::CohortFiltersParsingError
})?;
let mut dependencies = HashSet::new();
Self::traverse_filters(&cohort_property.properties, &mut dependencies)?;
Ok(dependencies)
}
/// Recursively traverses the filter tree to find cohort dependencies
///
/// Example filter tree structure:
/// ```json
/// {
/// "properties": {
/// "type": "OR",
/// "values": [
/// {
/// "type": "OR",
/// "values": [
/// {
/// "key": "id",
/// "value": 123,
/// "type": "cohort",
/// "operator": "exact"
/// },
/// {
/// "key": "email",
/// "value": "@posthog.com",
/// "type": "person",
/// "operator": "icontains"
/// }
/// ]
/// }
/// ]
/// }
/// }
/// ```
fn traverse_filters(
inner: &InnerCohortProperty,
dependencies: &mut HashSet<CohortId>,
) -> Result<(), FlagError> {
for cohort_values in &inner.values {
for filter in &cohort_values.values {
if filter.is_cohort() {
// Assuming the value is a single integer CohortId
if let Some(cohort_id) = filter.value.as_i64() {
dependencies.insert(cohort_id as CohortId);
} else {
return Err(FlagError::CohortFiltersParsingError);
}
}
// NB: we don't support nested cohort properties, so we don't need to traverse further
}
}
Ok(())
}
}
impl InnerCohortProperty {
/// Flattens the nested cohort property structure into a list of property filters.
///
/// The cohort property structure in Postgres looks like:
/// ```json
/// {
/// "type": "OR",
/// "values": [
/// {
/// "type": "OR",
/// "values": [
/// {
/// "key": "email",
/// "value": "@posthog.com",
/// "type": "person",
/// "operator": "icontains"
/// },
/// {
/// "key": "age",
/// "value": 25,
/// "type": "person",
/// "operator": "gt"
/// }
/// ]
/// }
/// ]
/// }
/// ```
pub fn to_property_filters(&self) -> Vec<PropertyFilter> {
self.values
.iter()
.flat_map(|value| &value.values)
.cloned()
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
cohort_models::{CohortPropertyType, CohortValues},
test_utils::{
insert_cohort_for_team_in_pg, insert_new_team_in_pg, setup_pg_reader_client,
setup_pg_writer_client,
},
};
use serde_json::json;
#[tokio::test]
async fn test_cohort_from_pg() {
let postgres_reader = setup_pg_reader_client(None).await;
let postgres_writer = setup_pg_writer_client(None).await;
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
.await
.expect("Failed to insert team");
let cohort = insert_cohort_for_team_in_pg(
postgres_writer.clone(),
team.id,
None,
json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$initial_browser_version", "type": "person", "value": ["125"], "negation": false, "operator": "exact"}]}]}}),
false,
)
.await
.expect("Failed to insert cohort");
let fetched_cohort = Cohort::from_pg(postgres_reader, cohort.id, team.id)
.await
.expect("Failed to fetch cohort");
assert_eq!(fetched_cohort.id, cohort.id);
assert_eq!(fetched_cohort.name, "Test Cohort");
assert_eq!(fetched_cohort.team_id, team.id);
}
#[tokio::test]
async fn test_list_from_pg() {
let postgres_reader = setup_pg_reader_client(None).await;
let postgres_writer = setup_pg_writer_client(None).await;
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
.await
.expect("Failed to insert team");
// Insert multiple cohorts for the team
insert_cohort_for_team_in_pg(
postgres_writer.clone(),
team.id,
Some("Cohort 1".to_string()),
json!({"properties": {"type": "AND", "values": [{"type": "property", "values": [{"key": "age", "type": "person", "value": [30], "negation": false, "operator": "gt"}]}]}}),
false,
)
.await
.expect("Failed to insert cohort1");
insert_cohort_for_team_in_pg(
postgres_writer.clone(),
team.id,
Some("Cohort 2".to_string()),
json!({"properties": {"type": "OR", "values": [{"type": "property", "values": [{"key": "country", "type": "person", "value": ["USA"], "negation": false, "operator": "exact"}]}]}}),
false,
)
.await
.expect("Failed to insert cohort2");
let cohorts = Cohort::list_from_pg(postgres_reader, team.id)
.await
.expect("Failed to list cohorts");
assert_eq!(cohorts.len(), 2);
let names: HashSet<String> = cohorts.into_iter().map(|c| c.name).collect();
assert!(names.contains("Cohort 1"));
assert!(names.contains("Cohort 2"));
}
#[test]
fn test_cohort_parse_filters() {
let cohort = Cohort {
id: 1,
name: "Test Cohort".to_string(),
description: None,
team_id: 1,
deleted: false,
filters: json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$initial_browser_version", "type": "person", "value": ["125"], "negation": false, "operator": "exact"}]}]}}),
query: None,
version: None,
pending_version: None,
count: None,
is_calculating: false,
is_static: false,
errors_calculating: 0,
groups: json!({}),
created_by_id: None,
};
let result = cohort.parse_filters().unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].key, "$initial_browser_version");
assert_eq!(result[0].value, json!(["125"]));
assert_eq!(result[0].prop_type, "person");
}
#[test]
fn test_cohort_property_to_property_filters() {
let cohort_property = InnerCohortProperty {
prop_type: CohortPropertyType::AND,
values: vec![CohortValues {
prop_type: "property".to_string(),
values: vec![
PropertyFilter {
key: "email".to_string(),
value: json!("test@example.com"),
operator: None,
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
},
PropertyFilter {
key: "age".to_string(),
value: json!(25),
operator: None,
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
},
],
}],
};
let result = cohort_property.to_property_filters();
assert_eq!(result.len(), 2);
assert_eq!(result[0].key, "email");
assert_eq!(result[0].value, json!("test@example.com"));
assert_eq!(result[1].key, "age");
assert_eq!(result[1].value, json!(25));
}
#[tokio::test]
async fn test_extract_dependencies() {
let postgres_reader = setup_pg_reader_client(None).await;
let postgres_writer = setup_pg_writer_client(None).await;
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
.await
.expect("Failed to insert team");
// Insert a single cohort that is dependent on another cohort
let dependent_cohort = insert_cohort_for_team_in_pg(
postgres_writer.clone(),
team.id,
Some("Dependent Cohort".to_string()),
json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "$browser", "type": "person", "value": ["Safari"], "negation": false, "operator": "exact"}]}]}}),
false,
)
.await
.expect("Failed to insert dependent_cohort");
// Insert main cohort with a single dependency
let main_cohort = insert_cohort_for_team_in_pg(
postgres_writer.clone(),
team.id,
Some("Main Cohort".to_string()),
json!({"properties": {"type": "OR", "values": [{"type": "OR", "values": [{"key": "id", "type": "cohort", "value": dependent_cohort.id, "negation": false}]}]}}),
false,
)
.await
.expect("Failed to insert main_cohort");
let fetched_main_cohort = Cohort::from_pg(postgres_reader.clone(), main_cohort.id, team.id)
.await
.expect("Failed to fetch main cohort");
println!("fetched_main_cohort: {:?}", fetched_main_cohort);
let dependencies = fetched_main_cohort.extract_dependencies().unwrap();
let expected_dependencies: HashSet<CohortId> =
[dependent_cohort.id].iter().cloned().collect();
assert_eq!(dependencies, expected_dependencies);
}
}

View File

@ -1,4 +1,7 @@
use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient};
use crate::{
api::FlagError, cohort_models::CohortId, database::Client as DatabaseClient,
redis::Client as RedisClient,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::instrument;
@ -7,7 +10,7 @@ use tracing::instrument;
// TODO: Add integration tests across repos to ensure this doesn't happen.
pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_";
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum OperatorType {
Exact,
@ -25,6 +28,8 @@ pub enum OperatorType {
IsDateExact,
IsDateAfter,
IsDateBefore,
In,
NotIn,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@ -36,10 +41,28 @@ pub struct PropertyFilter {
pub value: serde_json::Value,
pub operator: Option<OperatorType>,
#[serde(rename = "type")]
// TODO: worth making a enum here to differentiate between cohort and person filters?
pub prop_type: String,
pub negation: Option<bool>,
pub group_type_index: Option<i32>,
}
impl PropertyFilter {
/// Checks if the filter is a cohort filter
pub fn is_cohort(&self) -> bool {
self.key == "id" && self.prop_type == "cohort"
}
/// Returns the cohort id if the filter is a cohort filter, or None if it's not a cohort filter
/// or if the value cannot be parsed as a cohort id
pub fn get_cohort_id(&self) -> Option<CohortId> {
if !self.is_cohort() {
return None;
}
self.value.as_i64().map(|id| id as CohortId)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FlagGroupType {
pub properties: Option<Vec<PropertyFilter>>,
@ -68,6 +91,9 @@ pub struct FlagFilters {
pub super_groups: Option<Vec<FlagGroupType>>,
}
// TODO: see if you can combine these two structs, like we do with cohort models
// this will require not deserializing on read and instead doing it lazily, on-demand
// (which, tbh, is probably a better idea)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FeatureFlag {
pub id: i32,
@ -142,7 +168,7 @@ impl FeatureFlagList {
tracing::error!("failed to parse data to flags list: {}", e);
println!("failed to parse data: {}", e);
FlagError::DataParsingError
FlagError::RedisDataParsingError
})?;
Ok(FeatureFlagList { flags: flags_list })
@ -174,7 +200,7 @@ impl FeatureFlagList {
.map(|row| {
let filters = serde_json::from_value(row.filters).map_err(|e| {
tracing::error!("Failed to deserialize filters for flag {}: {}", row.key, e);
FlagError::DataParsingError
FlagError::RedisDataParsingError
})?;
Ok(FeatureFlag {
@ -200,7 +226,7 @@ impl FeatureFlagList {
) -> Result<(), FlagError> {
let payload = serde_json::to_string(&flags.flags).map_err(|e| {
tracing::error!("Failed to serialize flags: {}", e);
FlagError::DataParsingError
FlagError::RedisDataParsingError
})?;
client
@ -1095,7 +1121,7 @@ mod tests {
.expect("Failed to set malformed JSON in Redis");
let result = FeatureFlagList::from_redis(redis_client, team.id).await;
assert!(matches!(result, Err(FlagError::DataParsingError)));
assert!(matches!(result, Err(FlagError::RedisDataParsingError)));
// Test database query error (using a non-existent table)
let result = sqlx::query("SELECT * FROM non_existent_table")

File diff suppressed because it is too large Load Diff

View File

@ -158,8 +158,8 @@ impl FlagRequest {
pub async fn get_flags_from_cache_or_pg(
&self,
team_id: i32,
redis_client: Arc<dyn RedisClient + Send + Sync>,
pg_client: Arc<dyn DatabaseClient + Send + Sync>,
redis_client: &Arc<dyn RedisClient + Send + Sync>,
pg_client: &Arc<dyn DatabaseClient + Send + Sync>,
) -> Result<FeatureFlagList, FlagError> {
let mut cache_hit = false;
let flags = match FeatureFlagList::from_redis(redis_client.clone(), team_id).await {
@ -167,10 +167,14 @@ impl FlagRequest {
cache_hit = true;
Ok(flags)
}
Err(_) => match FeatureFlagList::from_pg(pg_client, team_id).await {
Err(_) => match FeatureFlagList::from_pg(pg_client.clone(), team_id).await {
Ok(flags) => {
if let Err(e) =
FeatureFlagList::update_flags_in_redis(redis_client, team_id, &flags).await
if let Err(e) = FeatureFlagList::update_flags_in_redis(
redis_client.clone(),
team_id,
&flags,
)
.await
{
tracing::warn!("Failed to update Redis cache: {}", e);
// TODO add new metric category for this
@ -206,7 +210,6 @@ mod tests {
TEAM_FLAGS_CACHE_PREFIX,
};
use crate::flag_request::FlagRequest;
use crate::redis::Client as RedisClient;
use crate::team::Team;
use crate::test_utils::{insert_new_team_in_redis, setup_pg_reader_client, setup_redis_client};
use bytes::Bytes;
@ -360,6 +363,7 @@ mod tests {
operator: Some(OperatorType::Exact),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
}]),
rollout_percentage: Some(50.0),
variant: None,
@ -402,6 +406,7 @@ mod tests {
operator: Some(OperatorType::Exact),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
}]),
rollout_percentage: Some(100.0),
variant: None,
@ -426,7 +431,7 @@ mod tests {
// Test fetching from Redis
let result = flag_request
.get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone())
.get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client)
.await;
assert!(result.is_ok());
let fetched_flags = result.unwrap();
@ -483,7 +488,7 @@ mod tests {
.expect("Failed to remove flags from Redis");
let result = flag_request
.get_flags_from_cache_or_pg(team.id, redis_client.clone(), pg_client.clone())
.get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client)
.await;
assert!(result.is_ok());
// Verify that the flags were re-added to Redis

View File

@ -1,4 +1,7 @@
pub mod api;
pub mod cohort_cache;
pub mod cohort_models;
pub mod cohort_operations;
pub mod config;
pub mod database;
pub mod feature_flag_match_reason;
@ -8,13 +11,13 @@ pub mod flag_matching;
pub mod flag_request;
pub mod geoip;
pub mod metrics_consts;
pub mod metrics_utils;
pub mod property_matching;
pub mod redis;
pub mod request_handler;
pub mod router;
pub mod server;
pub mod team;
pub mod utils;
pub mod v0_endpoint;
// Test modules don't need to be compiled with main binary

View File

@ -44,7 +44,7 @@ pub fn match_property(
}
let key = &property.key;
let operator = property.operator.clone().unwrap_or(OperatorType::Exact);
let operator = property.operator.unwrap_or(OperatorType::Exact);
let value = &property.value;
let match_value = matching_property_values.get(key);
@ -193,6 +193,12 @@ pub fn match_property(
// Ok(false)
// }
}
OperatorType::In | OperatorType::NotIn => {
// TODO: we handle these in cohort matching, so we can just return false here
// because by the time we match properties, we've already decomposed the cohort
// filter into multiple property filters
Ok(false)
}
}
}
@ -260,6 +266,7 @@ mod test_match_properties {
operator: None,
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -313,6 +320,7 @@ mod test_match_properties {
operator: Some(OperatorType::Exact),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -335,6 +343,7 @@ mod test_match_properties {
operator: Some(OperatorType::Exact),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -379,6 +388,7 @@ mod test_match_properties {
operator: Some(OperatorType::IsNot),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -416,6 +426,7 @@ mod test_match_properties {
operator: Some(OperatorType::IsNot),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -490,6 +501,7 @@ mod test_match_properties {
operator: Some(OperatorType::IsSet),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -538,6 +550,7 @@ mod test_match_properties {
operator: Some(OperatorType::Icontains),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -595,6 +608,7 @@ mod test_match_properties {
operator: Some(OperatorType::Icontains),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -634,6 +648,7 @@ mod test_match_properties {
operator: Some(OperatorType::Regex),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -674,6 +689,7 @@ mod test_match_properties {
operator: Some(OperatorType::Regex),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
&property_b,
@ -708,6 +724,7 @@ mod test_match_properties {
operator: Some(OperatorType::Regex),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -730,6 +747,7 @@ mod test_match_properties {
operator: Some(OperatorType::Regex),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
&property_d,
@ -760,6 +778,7 @@ mod test_match_properties {
operator: Some(OperatorType::Gt),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -802,6 +821,7 @@ mod test_match_properties {
operator: Some(OperatorType::Lt),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -848,6 +868,7 @@ mod test_match_properties {
operator: Some(OperatorType::Gte),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -889,6 +910,7 @@ mod test_match_properties {
operator: Some(OperatorType::Lt),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -935,6 +957,7 @@ mod test_match_properties {
operator: Some(OperatorType::Lt),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -1013,6 +1036,7 @@ mod test_match_properties {
operator: Some(OperatorType::IsNot),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1034,6 +1058,7 @@ mod test_match_properties {
operator: Some(OperatorType::IsSet),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -1049,6 +1074,7 @@ mod test_match_properties {
operator: Some(OperatorType::Icontains),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -1070,6 +1096,7 @@ mod test_match_properties {
operator: Some(OperatorType::Regex),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1085,6 +1112,7 @@ mod test_match_properties {
operator: Some(OperatorType::Regex),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1118,6 +1146,7 @@ mod test_match_properties {
operator: None,
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1137,6 +1166,7 @@ mod test_match_properties {
operator: Some(OperatorType::Exact),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1152,6 +1182,7 @@ mod test_match_properties {
operator: Some(OperatorType::IsSet),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1167,6 +1198,7 @@ mod test_match_properties {
operator: Some(OperatorType::IsNotSet),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(match_property(
@ -1203,6 +1235,7 @@ mod test_match_properties {
operator: Some(OperatorType::Icontains),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1218,6 +1251,7 @@ mod test_match_properties {
operator: Some(OperatorType::NotIcontains),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1233,6 +1267,7 @@ mod test_match_properties {
operator: Some(OperatorType::Regex),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1248,6 +1283,7 @@ mod test_match_properties {
operator: Some(OperatorType::NotRegex),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1263,6 +1299,7 @@ mod test_match_properties {
operator: Some(OperatorType::Gt),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1278,6 +1315,7 @@ mod test_match_properties {
operator: Some(OperatorType::Gte),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1293,6 +1331,7 @@ mod test_match_properties {
operator: Some(OperatorType::Lt),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1308,6 +1347,7 @@ mod test_match_properties {
operator: Some(OperatorType::Lte),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(
@ -1324,6 +1364,7 @@ mod test_match_properties {
operator: Some(OperatorType::IsDateBefore),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
};
assert!(!match_property(

View File

@ -1,5 +1,6 @@
use crate::{
api::{FlagError, FlagsResponse},
cohort_cache::CohortCacheManager,
database::Client,
flag_definitions::FeatureFlagList,
flag_matching::{FeatureFlagMatcher, GroupTypeMappingCache},
@ -69,6 +70,7 @@ pub struct FeatureFlagEvaluationContext {
feature_flags: FeatureFlagList,
postgres_reader: Arc<dyn Client + Send + Sync>,
postgres_writer: Arc<dyn Client + Send + Sync>,
cohort_cache: Arc<CohortCacheManager>,
#[builder(default)]
person_property_overrides: Option<HashMap<String, Value>>,
#[builder(default)]
@ -95,6 +97,7 @@ pub async fn process_request(context: RequestContext) -> Result<FlagsResponse, F
let team = request
.get_team_from_cache_or_pg(&token, state.redis.clone(), state.postgres_reader.clone())
.await?;
let distinct_id = request.extract_distinct_id()?;
let groups = request.groups.clone();
let team_id = team.id;
@ -108,18 +111,16 @@ pub async fn process_request(context: RequestContext) -> Result<FlagsResponse, F
let hash_key_override = request.anon_distinct_id.clone();
let feature_flags_from_cache_or_pg = request
.get_flags_from_cache_or_pg(team_id, state.redis.clone(), state.postgres_reader.clone())
.get_flags_from_cache_or_pg(team_id, &state.redis, &state.postgres_reader)
.await?;
let postgres_reader_dyn: Arc<dyn Client + Send + Sync> = state.postgres_reader.clone();
let postgres_writer_dyn: Arc<dyn Client + Send + Sync> = state.postgres_writer.clone();
let evaluation_context = FeatureFlagEvaluationContextBuilder::default()
.team_id(team_id)
.distinct_id(distinct_id)
.feature_flags(feature_flags_from_cache_or_pg)
.postgres_reader(postgres_reader_dyn)
.postgres_writer(postgres_writer_dyn)
.postgres_reader(state.postgres_reader.clone())
.postgres_writer(state.postgres_writer.clone())
.cohort_cache(state.cohort_cache.clone())
.person_property_overrides(person_property_overrides)
.group_property_overrides(group_property_overrides)
.groups(groups)
@ -224,8 +225,8 @@ pub async fn evaluate_feature_flags(context: FeatureFlagEvaluationContext) -> Fl
context.team_id,
context.postgres_reader,
context.postgres_writer,
context.cohort_cache,
Some(group_type_mapping_cache),
None, // TODO maybe remove this from the matcher struct, since it's used internally but not passed around
context.groups,
);
feature_flag_matcher
@ -359,6 +360,7 @@ mod tests {
async fn test_evaluate_feature_flags() {
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let flag = FeatureFlag {
name: Some("Test Flag".to_string()),
id: 1,
@ -374,6 +376,7 @@ mod tests {
operator: Some(OperatorType::Exact),
prop_type: "person".to_string(),
group_type_index: None,
negation: None,
}]),
rollout_percentage: Some(100.0), // Set to 100% to ensure it's always on
variant: None,
@ -397,6 +400,7 @@ mod tests {
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.cohort_cache(cohort_cache)
.person_property_overrides(Some(person_properties))
.build()
.expect("Failed to build FeatureFlagEvaluationContext");
@ -505,6 +509,7 @@ mod tests {
async fn test_evaluate_feature_flags_multiple_flags() {
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let flags = vec![
FeatureFlag {
name: Some("Flag 1".to_string()),
@ -556,6 +561,7 @@ mod tests {
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.cohort_cache(cohort_cache)
.build()
.expect("Failed to build FeatureFlagEvaluationContext");
@ -608,6 +614,7 @@ mod tests {
async fn test_evaluate_feature_flags_with_overrides() {
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let team = insert_new_team_in_pg(postgres_reader.clone(), None)
.await
.unwrap();
@ -627,6 +634,7 @@ mod tests {
operator: Some(OperatorType::Exact),
prop_type: "group".to_string(),
group_type_index: Some(0),
negation: None,
}]),
rollout_percentage: Some(100.0),
variant: None,
@ -655,6 +663,7 @@ mod tests {
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.cohort_cache(cohort_cache)
.group_property_overrides(Some(group_property_overrides))
.groups(Some(groups))
.build()
@ -688,6 +697,7 @@ mod tests {
let long_id = "a".repeat(1000);
let postgres_reader: Arc<dyn Client + Send + Sync> = setup_pg_reader_client(None).await;
let postgres_writer: Arc<dyn Client + Send + Sync> = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let flag = FeatureFlag {
name: Some("Test Flag".to_string()),
id: 1,
@ -717,6 +727,7 @@ mod tests {
.feature_flags(feature_flag_list)
.postgres_reader(postgres_reader)
.postgres_writer(postgres_writer)
.cohort_cache(cohort_cache)
.build()
.expect("Failed to build FeatureFlagEvaluationContext");

View File

@ -9,11 +9,12 @@ use health::HealthRegistry;
use tower::limit::ConcurrencyLimitLayer;
use crate::{
cohort_cache::CohortCacheManager,
config::{Config, TeamIdsToTrack},
database::Client as DatabaseClient,
geoip::GeoIpClient,
metrics_utils::team_id_label_filter,
redis::Client as RedisClient,
utils::team_id_label_filter,
v0_endpoint,
};
@ -22,6 +23,7 @@ pub struct State {
pub redis: Arc<dyn RedisClient + Send + Sync>,
pub postgres_reader: Arc<dyn DatabaseClient + Send + Sync>,
pub postgres_writer: Arc<dyn DatabaseClient + Send + Sync>,
pub cohort_cache: Arc<CohortCacheManager>, // TODO does this need a better name than just `cohort_cache`?
pub geoip: Arc<GeoIpClient>,
pub team_ids_to_track: TeamIdsToTrack,
}
@ -30,6 +32,7 @@ pub fn router<R, D>(
redis: Arc<R>,
postgres_reader: Arc<D>,
postgres_writer: Arc<D>,
cohort_cache: Arc<CohortCacheManager>,
geoip: Arc<GeoIpClient>,
liveness: HealthRegistry,
config: Config,
@ -42,6 +45,7 @@ where
redis,
postgres_reader,
postgres_writer,
cohort_cache,
geoip,
team_ids_to_track: config.team_ids_to_track.clone(),
};

View File

@ -6,6 +6,7 @@ use std::time::Duration;
use health::{HealthHandle, HealthRegistry};
use tokio::net::TcpListener;
use crate::cohort_cache::CohortCacheManager;
use crate::config::Config;
use crate::database::get_pool;
use crate::geoip::GeoIpClient;
@ -54,6 +55,8 @@ where
}
};
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let health = HealthRegistry::new("liveness");
// TODO - we don't have a more complex health check yet, but we should add e.g. some around DB operations
@ -67,6 +70,7 @@ where
redis_client,
postgres_reader,
postgres_writer,
cohort_cache,
geoip_service,
health,
config,

View File

@ -42,7 +42,7 @@ impl Team {
// TODO: Consider an LRU cache for teams as well, with small TTL to skip redis/pg lookups
let team: Team = serde_json::from_str(&serialized_team).map_err(|e| {
tracing::error!("failed to parse data to team: {}", e);
FlagError::DataParsingError
FlagError::RedisDataParsingError
})?;
Ok(team)
@ -55,7 +55,7 @@ impl Team {
) -> Result<(), FlagError> {
let serialized_team = serde_json::to_string(&team).map_err(|e| {
tracing::error!("Failed to serialize team: {}", e);
FlagError::DataParsingError
FlagError::RedisDataParsingError
})?;
client
@ -173,7 +173,7 @@ mod tests {
let client = setup_redis_client(None);
match Team::from_redis(client.clone(), team.api_token.clone()).await {
Err(FlagError::DataParsingError) => (),
Err(FlagError::RedisDataParsingError) => (),
Err(other) => panic!("Expected DataParsingError, got {:?}", other),
Ok(_) => panic!("Expected DataParsingError"),
};

View File

@ -1,11 +1,12 @@
use anyhow::Error;
use axum::async_trait;
use serde_json::{json, Value};
use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, PgPool, Postgres};
use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, Postgres, Row};
use std::sync::Arc;
use uuid::Uuid;
use crate::{
cohort_models::Cohort,
config::{Config, DEFAULT_TEST_CONFIG},
database::{get_pool, Client, CustomDatabaseError},
flag_definitions::{self, FeatureFlag, FeatureFlagRow},
@ -23,7 +24,9 @@ pub fn random_string(prefix: &str, length: usize) -> String {
format!("{}{}", prefix, suffix)
}
pub async fn insert_new_team_in_redis(client: Arc<RedisClient>) -> Result<Team, Error> {
pub async fn insert_new_team_in_redis(
client: Arc<dyn RedisClientTrait + Send + Sync>,
) -> Result<Team, Error> {
let id = rand::thread_rng().gen_range(0..10_000_000);
let token = random_string("phc_", 12);
let team = Team {
@ -48,7 +51,7 @@ pub async fn insert_new_team_in_redis(client: Arc<RedisClient>) -> Result<Team,
}
pub async fn insert_flags_for_team_in_redis(
client: Arc<RedisClient>,
client: Arc<dyn RedisClientTrait + Send + Sync>,
team_id: i32,
json_value: Option<String>,
) -> Result<(), Error> {
@ -88,7 +91,7 @@ pub async fn insert_flags_for_team_in_redis(
Ok(())
}
pub fn setup_redis_client(url: Option<String>) -> Arc<RedisClient> {
pub fn setup_redis_client(url: Option<String>) -> Arc<dyn RedisClientTrait + Send + Sync> {
let redis_url = match url {
Some(value) => value,
None => "redis://localhost:6379/".to_string(),
@ -130,7 +133,7 @@ pub fn create_flag_from_json(json_value: Option<String>) -> Vec<FeatureFlag> {
flags
}
pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc<PgPool> {
pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc<dyn Client + Send + Sync> {
let config = config.unwrap_or(&DEFAULT_TEST_CONFIG);
Arc::new(
get_pool(&config.read_database_url, config.max_pg_connections)
@ -139,7 +142,7 @@ pub async fn setup_pg_reader_client(config: Option<&Config>) -> Arc<PgPool> {
)
}
pub async fn setup_pg_writer_client(config: Option<&Config>) -> Arc<PgPool> {
pub async fn setup_pg_writer_client(config: Option<&Config>) -> Arc<dyn Client + Send + Sync> {
let config = config.unwrap_or(&DEFAULT_TEST_CONFIG);
Arc::new(
get_pool(&config.write_database_url, config.max_pg_connections)
@ -261,7 +264,7 @@ pub async fn insert_new_team_in_pg(
}
pub async fn insert_flag_for_team_in_pg(
client: Arc<PgPool>,
client: Arc<dyn Client + Send + Sync>,
team_id: i32,
flag: Option<FeatureFlagRow>,
) -> Result<FeatureFlagRow, Error> {
@ -310,11 +313,12 @@ pub async fn insert_flag_for_team_in_pg(
}
pub async fn insert_person_for_team_in_pg(
client: Arc<PgPool>,
client: Arc<dyn Client + Send + Sync>,
team_id: i32,
distinct_id: String,
properties: Option<Value>,
) -> Result<(), Error> {
) -> Result<i32, Error> {
// Changed return type to Result<i32, Error>
let payload = match properties {
Some(value) => value,
None => json!({
@ -326,7 +330,7 @@ pub async fn insert_person_for_team_in_pg(
let uuid = Uuid::now_v7();
let mut conn = client.get_connection().await?;
let res = sqlx::query(
let row = sqlx::query(
r#"
WITH inserted_person AS (
INSERT INTO posthog_person (
@ -334,10 +338,11 @@ pub async fn insert_person_for_team_in_pg(
properties_last_operation, team_id, is_user_id, is_identified, uuid, version
)
VALUES ('2023-04-05', $1, '{}', '{}', $2, NULL, true, $3, 0)
RETURNING *
RETURNING id
)
INSERT INTO posthog_persondistinctid (distinct_id, person_id, team_id, version)
VALUES ($4, (SELECT id FROM inserted_person), $5, 0)
RETURNING person_id
"#,
)
.bind(&payload)
@ -345,10 +350,109 @@ pub async fn insert_person_for_team_in_pg(
.bind(uuid)
.bind(&distinct_id)
.bind(team_id)
.fetch_one(&mut *conn)
.await?;
let person_id: i32 = row.get::<i32, _>("person_id");
Ok(person_id)
}
pub async fn insert_cohort_for_team_in_pg(
client: Arc<dyn Client + Send + Sync>,
team_id: i32,
name: Option<String>,
filters: serde_json::Value,
is_static: bool,
) -> Result<Cohort, Error> {
let cohort = Cohort {
id: 0, // Placeholder, will be updated after insertion
name: name.unwrap_or("Test Cohort".to_string()),
description: Some("Description for cohort".to_string()),
team_id,
deleted: false,
filters,
query: None,
version: Some(1),
pending_version: None,
count: None,
is_calculating: false,
is_static,
errors_calculating: 0,
groups: serde_json::json!([]),
created_by_id: None,
};
let mut conn = client.get_connection().await?;
let row: (i32,) = sqlx::query_as(
r#"INSERT INTO posthog_cohort
(name, description, team_id, deleted, filters, query, version, pending_version, count, is_calculating, is_static, errors_calculating, groups, created_by_id) VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
RETURNING id"#,
)
.bind(&cohort.name)
.bind(&cohort.description)
.bind(cohort.team_id)
.bind(cohort.deleted)
.bind(&cohort.filters)
.bind(&cohort.query)
.bind(cohort.version)
.bind(cohort.pending_version)
.bind(cohort.count)
.bind(cohort.is_calculating)
.bind(cohort.is_static)
.bind(cohort.errors_calculating)
.bind(&cohort.groups)
.bind(cohort.created_by_id)
.fetch_one(&mut *conn)
.await?;
// Update the cohort_row with the actual id generated by sqlx
let id = row.0;
Ok(Cohort { id, ..cohort })
}
pub async fn get_person_id_by_distinct_id(
client: Arc<dyn Client + Send + Sync>,
team_id: i32,
distinct_id: &str,
) -> Result<i32, Error> {
let mut conn = client.get_connection().await?;
let row: (i32,) = sqlx::query_as(
r#"SELECT id FROM posthog_person
WHERE team_id = $1 AND id = (
SELECT person_id FROM posthog_persondistinctid
WHERE team_id = $1 AND distinct_id = $2
LIMIT 1
)
LIMIT 1"#,
)
.bind(team_id)
.bind(distinct_id)
.fetch_one(&mut *conn)
.await
.map_err(|_| anyhow::anyhow!("Person not found"))?;
Ok(row.0)
}
pub async fn add_person_to_cohort(
client: Arc<dyn Client + Send + Sync>,
person_id: i32,
cohort_id: i32,
) -> Result<(), Error> {
let mut conn = client.get_connection().await?;
let res = sqlx::query(
r#"INSERT INTO posthog_cohortpeople (cohort_id, person_id)
VALUES ($1, $2)
ON CONFLICT DO NOTHING"#,
)
.bind(cohort_id)
.bind(person_id)
.execute(&mut *conn)
.await?;
assert_eq!(res.rows_affected(), 1);
assert!(res.rows_affected() > 0, "Failed to add person to cohort");
Ok(())
}

View File

@ -1,3 +1,6 @@
use std::sync::Arc;
use feature_flags::cohort_cache::CohortCacheManager;
use feature_flags::feature_flag_match_reason::FeatureFlagMatchReason;
/// These tests are common between all libraries doing local evaluation of feature flags.
/// This ensures there are no mismatches between implementations.
@ -110,6 +113,7 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() {
for (i, result) in results.iter().enumerate().take(1000) {
let postgres_reader = setup_pg_reader_client(None).await;
let postgres_writer = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let distinct_id = format!("distinct_id_{}", i);
@ -118,7 +122,7 @@ async fn it_is_consistent_with_rollout_calculation_for_simple_flags() {
1,
postgres_reader,
postgres_writer,
None,
cohort_cache,
None,
None,
)
@ -1209,6 +1213,7 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() {
for (i, result) in results.iter().enumerate().take(1000) {
let postgres_reader = setup_pg_reader_client(None).await;
let postgres_writer = setup_pg_writer_client(None).await;
let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None));
let distinct_id = format!("distinct_id_{}", i);
let feature_flag_match = FeatureFlagMatcher::new(
@ -1216,7 +1221,7 @@ async fn it_is_consistent_with_rollout_calculation_for_multivariate_flags() {
1,
postgres_reader,
postgres_writer,
None,
cohort_cache,
None,
None,
)