From 9a09d9c9f0c841f42abf0853b38b5922b67f7d23 Mon Sep 17 00:00:00 2001 From: David Newell Date: Tue, 13 Feb 2024 16:36:38 +0000 Subject: [PATCH] update embedding input to remove duplicates --- .../ai/generate_embeddings.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/ee/session_recordings/ai/generate_embeddings.py b/ee/session_recordings/ai/generate_embeddings.py index 5837b9d5f7c..68a4f6c6217 100644 --- a/ee/session_recordings/ai/generate_embeddings.py +++ b/ee/session_recordings/ai/generate_embeddings.py @@ -16,8 +16,9 @@ from ee.session_recordings.ai.utils import ( ) from structlog import get_logger from posthog.clickhouse.client import sync_execute -import datetime -import pytz +from datetime import datetime +from pytz import UTC +from numpy import transpose GENERATE_RECORDING_EMBEDDING_TIMING = Histogram( "posthog_session_recordings_generate_recording_embedding", @@ -154,24 +155,25 @@ def generate_recording_embeddings(session_id: str, team: Team | int) -> List[flo reduce_elements_chain( simplify_window_id(SessionSummaryPromptData(columns=session_events[0], results=session_events[1])) ), - start=datetime.datetime(1970, 1, 1, tzinfo=pytz.UTC), # epoch timestamp + start=datetime(1970, 1, 1, tzinfo=UTC), # epoch timestamp ) ) - processed_sessions_index = processed_sessions.column_index("event") + event_name_index = processed_sessions.column_index("event") current_url_index = processed_sessions.column_index("$current_url") elements_chain_index = processed_sessions.column_index("elements_chain") + processed_sessions = transpose(processed_sessions) + input = ( str(session_metadata) + "\n" + + "\n".join(set(processed_sessions[event_name_index])) + + "\n" + + "\n".join(set(processed_sessions[current_url_index])) + + "\n" + "\n".join( - compact_result( - event_name=result[processed_sessions_index] if processed_sessions_index is not None else "", - current_url=result[current_url_index] if current_url_index is not None else "", - elements_chain=result[elements_chain_index] if elements_chain_index is not None else "", - ) - for result in processed_sessions.results + set(compact_chain(result) for result in processed_sessions[elements_chain_index] if result is not None) ) ) @@ -187,6 +189,5 @@ def generate_recording_embeddings(session_id: str, team: Team | int) -> List[flo return embeddings -def compact_result(event_name: str, current_url: int, elements_chain: Dict[str, str] | str) -> str: - elements_string = elements_chain if isinstance(elements_chain, str) else ", ".join(str(e) for e in elements_chain) - return f"{event_name} {current_url} {elements_string}" +def compact_chain(elements_chain: Dict[str, str] | str) -> str: + return elements_chain if isinstance(elements_chain, str) else ", ".join(str(e) for e in elements_chain)