diff --git a/ee/hogai/assistant.py b/ee/hogai/assistant.py
new file mode 100644
index 00000000000..d1aa9656257
--- /dev/null
+++ b/ee/hogai/assistant.py
@@ -0,0 +1,101 @@
+from collections.abc import Generator
+from typing import Any, Literal, TypedDict, TypeGuard, Union, cast
+
+from langchain_core.messages import AIMessageChunk
+from langfuse.callback import CallbackHandler
+from langgraph.graph.state import StateGraph
+
+from ee import settings
+from ee.hogai.trends.nodes import CreateTrendsPlanNode, CreateTrendsPlanToolsNode, GenerateTrendsNode
+from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation
+from posthog.models.team.team import Team
+from posthog.schema import VisualizationMessage
+
+if settings.LANGFUSE_PUBLIC_KEY:
+    langfuse_handler = CallbackHandler(
+        public_key=settings.LANGFUSE_PUBLIC_KEY, secret_key=settings.LANGFUSE_SECRET_KEY, host=settings.LANGFUSE_HOST
+    )
+else:
+    langfuse_handler = None
+
+
+def is_value_update(update: list[Any]) -> TypeGuard[tuple[Literal["updates"], dict[AssistantNodeName, Any]]]:
+    """
+    Transition between nodes.
+    """
+    return len(update) == 2 and update[0] == "updates"
+
+
+class LangGraphState(TypedDict):
+    langgraph_node: AssistantNodeName
+
+
+def is_message_update(
+    update: list[Any],
+) -> TypeGuard[tuple[Literal["messages"], tuple[Union[AIMessageChunk, Any], LangGraphState]]]:
+    """
+    Streaming of messages. Returns a partial state.
+    """
+    return len(update) == 2 and update[0] == "messages"
+
+
+class Assistant:
+    _team: Team
+    _graph: StateGraph
+
+    def __init__(self, team: Team):
+        self._team = team
+        self._graph = StateGraph(AssistantState)
+
+    def _compile_graph(self):
+        builder = self._graph
+
+        create_trends_plan_node = CreateTrendsPlanNode(self._team)
+        builder.add_node(CreateTrendsPlanNode.name, create_trends_plan_node.run)
+
+        create_trends_plan_tools_node = CreateTrendsPlanToolsNode(self._team)
+        builder.add_node(CreateTrendsPlanToolsNode.name, create_trends_plan_tools_node.run)
+
+        generate_trends_node = GenerateTrendsNode(self._team)
+        builder.add_node(GenerateTrendsNode.name, generate_trends_node.run)
+
+        builder.add_edge(AssistantNodeName.START, create_trends_plan_node.name)
+        builder.add_conditional_edges(create_trends_plan_node.name, create_trends_plan_node.router)
+        builder.add_conditional_edges(create_trends_plan_tools_node.name, create_trends_plan_tools_node.router)
+        builder.add_conditional_edges(GenerateTrendsNode.name, generate_trends_node.router)
+
+        return builder.compile()
+
+    def stream(self, conversation: Conversation) -> Generator[str, None, None]:
+        assistant_graph = self._compile_graph()
+        callbacks = [langfuse_handler] if langfuse_handler else []
+        messages = [message.root for message in conversation.messages]
+
+        generator = assistant_graph.stream(
+            {"messages": messages},
+            config={"recursion_limit": 24, "callbacks": callbacks},
+            stream_mode=["messages", "updates"],
+        )
+
+        chunks = AIMessageChunk(content="")
+
+        for update in generator:
+            if is_value_update(update):
+                _, state_update = update
+                if (
+                    AssistantNodeName.GENERATE_TRENDS in state_update
+                    and "messages" in state_update[AssistantNodeName.GENERATE_TRENDS]
+                ):
+                    message = cast(VisualizationMessage, state_update[AssistantNodeName.GENERATE_TRENDS]["messages"][0])
+                    yield message.model_dump_json()
+            elif is_message_update(update):
+                langchain_message, langgraph_state = update[1]
+                if langgraph_state["langgraph_node"] == AssistantNodeName.GENERATE_TRENDS and isinstance(
+                    langchain_message, AIMessageChunk
+                ):
+                    chunks += 
langchain_message # type: ignore + parsed_message = GenerateTrendsNode.parse_output(chunks.tool_calls[0]["args"]) + if parsed_message: + yield VisualizationMessage( + reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer + ).model_dump_json() diff --git a/ee/hogai/generate_trends_agent.py b/ee/hogai/generate_trends_agent.py deleted file mode 100644 index 9980ff82dbe..00000000000 --- a/ee/hogai/generate_trends_agent.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Literal, Optional - -from langchain_core.output_parsers.openai_tools import PydanticToolsParser -from langchain_core.prompts import ChatPromptTemplate -from langchain_openai import ChatOpenAI -from pydantic import BaseModel, Field - -from ee.hogai.system_prompt import trends_system_prompt -from ee.hogai.team_prompt import TeamPrompt -from ee.hogai.trends_function import TrendsFunction -from posthog.models.team.team import Team -from posthog.schema import ExperimentalAITrendsQuery - - -class output_insight_schema(BaseModel): - reasoning_steps: Optional[list[str]] = None - answer: ExperimentalAITrendsQuery - - -class ChatMessage(BaseModel): - role: Literal["user", "assistant"] - content: str = Field(..., max_length=2500) - - -class Conversation(BaseModel): - messages: list[ChatMessage] = Field(..., max_length=20) - session_id: str - - -class GenerateTrendsAgent: - _team: Team - - def __init__(self, team: Team): - self._team = team - - def bootstrap(self, messages: list[ChatMessage], user_prompt: str | None = None): - llm = ChatOpenAI(model="gpt-4o-2024-08-06", stream_usage=True).bind_tools( - [TrendsFunction().generate_function()], tool_choice="output_insight_schema" - ) - user_prompt = ( - user_prompt - or "Answer to my question:\n{{question}}\n" + TeamPrompt(self._team).generate_prompt() - ) - - prompts = ChatPromptTemplate.from_messages( - [ - ("system", trends_system_prompt), - ("user", user_prompt), - *[(message.role, message.content) for message in messages[1:]], - ], - template_format="mustache", - ) - - chain = prompts | llm | PydanticToolsParser(tools=[output_insight_schema]) # type: ignore - return chain diff --git a/ee/hogai/hardcoded_definitions.py b/ee/hogai/hardcoded_definitions.py index ee13c49c3ca..166c53bf87c 100644 --- a/ee/hogai/hardcoded_definitions.py +++ b/ee/hogai/hardcoded_definitions.py @@ -54,7 +54,7 @@ hardcoded_prop_defs: dict = { }, "$identify": { "label": "Identify", - "description": "A user has been identified with properties", + "description": "Identifies an anonymous user. This event doesn't show how many users you have but rather how many users used an account.", }, "$create_alias": { "label": "Alias", @@ -915,8 +915,8 @@ hardcoded_prop_defs: dict = { "session_properties": { "$session_duration": { "label": "Session duration", - "description": "The duration of the session being tracked. 
Learn more about how PostHog tracks sessions in our documentation.\n\nNote, if the duration is formatted as a single number (not 'HH:MM:SS'), it's in seconds.", - "examples": ["01:04:12"], + "description": "The duration of the session being tracked in seconds.", + "examples": ["30", "146", "2"], "type": "Numeric", }, "$start_timestamp": { diff --git a/ee/hogai/system_prompt.py b/ee/hogai/system_prompt.py deleted file mode 100644 index fb00b358258..00000000000 --- a/ee/hogai/system_prompt.py +++ /dev/null @@ -1,77 +0,0 @@ -trends_system_prompt = """ -As a recognized head of product growth acting as a top-tier data engineer, your task is to write queries of trends insights for customers using a JSON schema. - -Follow these instructions to create a query: -* Identify the events or actions the user wants to analyze. -* Determine types of entities that user wants to analyze like events, persons, groups, sessions, cohorts, etc. -* Determine a vistualization type that best suits the user's needs. -* Determine if the user wants to name the series or use the default names. -* Choose the date range and the interval the user wants to analyze. -* Determine if the user wants to compare the results to a previous period or use smoothing. -* Determine if the user wants to use property filters for all series. -* Determine math types for all series. -* Determine property filters for individual series. -* Check operators of property filters for individual and all series. Make sure the operators correspond to the user's request. You may need to use "contains" for strings if you're not sure about the exact value. -* Determine if the user wants to use a breakdown filter. -* Determine if the user wants to filter out internal and test users. If the user didn't specify, filter out internal and test users by default. -* Determine if the user wants to use sampling factor. -* Determine if it's useful to show a legend, values of series, units, y-axis scale type, etc. -* Use your judgement if there are any other parameters that the user might want to adjust that aren't listed here. - -Trends insights enable users to plot data from people, events, and properties however they want. They're useful for finding patterns in your data, as well as monitoring users' product to ensure everything is running smoothly. For example, using trends, users can analyze: -- How product's most important metrics change over time. -- Long-term patterns, or cycles in product's usage. -- How a specific change affects usage. -- The usage of different features side-by-side. -- How the properties of events vary using aggregation (sum, average, etc). -- Users can also visualize the same data points in a variety of ways. - -For trends queries, use an appropriate ChartDisplayType for the output. For example: -- if the user wants to see a dynamics in time like a line graph, use `ActionsLineGraph`. -- if the user wants to see cumulative dynamics across time, use `ActionsLineGraphCumulative`. -- if the user asks a question where you can answer with a single number, use `BoldNumber`. -- if the user wants a table, use `ActionsTable`. -- if the data is categorical, use `ActionsBar`. -- if the data is easy to understand in a pie chart, use `ActionsPie`. -- if the user has only one series and they want to see data from particular countries, use `WorldMap`. - -The user might want to get insights for groups. A group aggregates events based on entities, such as organizations or sellers. The user might provide a list of group names and their numeric indexes. 
Instead of a group's name, always use its numeric index. - -Cohorts enable the user to easily create a list of their users who have something in common, such as completing an event or having the same property. The user might want to use cohorts for filtering events. Instead of a cohort's name, always use its ID. - -If you want to apply Y-Axis unit, make sure it will display data correctly. Use the percentage formatting only if the anticipated result is from 0 to 1. - -Learn on these examples: -Q: How many users do I have? -A: {"dateRange":{"date_from":"all"},"interval":"month","kind":"TrendsQuery","series":[{"event":"user signed up","kind":"EventsNode","math":"total"}],"trendsFilter":{"aggregationAxisFormat":"numeric","display":"BoldNumber"}} -Q: Show a bar chart of the organic search traffic for the last month grouped by week. -A: {"dateRange":{"date_from":"-30d","date_to":null,"explicitDate":false},"interval":"week","kind":"TrendsQuery","series":[{"event":"$pageview","kind":"EventsNode","math":"dau","properties":[{"key":"$referring_domain","operator":"icontains","type":"event","value":"google"},{"key":"utm_source","operator":"is_not_set","type":"event","value":"is_not_set"}]}],"trendsFilter":{"aggregationAxisFormat":"numeric","display":"ActionsBar"}} -Q: insight created unique users & first-time users for the last 12m) -A: {"dateRange":{"date_from":"-12m","date_to":""},"filterTestAccounts":true,"interval":"month","kind":"TrendsQuery","series":[{"event":"insight created","kind":"EventsNode","math":"dau","custom_name":"insight created"},{"event":"insight created","kind":"EventsNode","math":"first_time_for_user","custom_name":"insight created"}],"trendsFilter":{"aggregationAxisFormat":"numeric","display":"ActionsLineGraph"}} -Q: What are the top 10 referring domains for the last month? -A: {"breakdownFilter":{"breakdown_type":"event","breakdowns":[{"group_type_index":null,"histogram_bin_count":null,"normalize_url":null,"property":"$referring_domain","type":"event"}]},"dateRange":{"date_from":"-30d"},"interval":"day","kind":"TrendsQuery","series":[{"event":"$pageview","kind":"EventsNode","math":"total","custom_name":"$pageview"}]} -Q: What is the DAU to MAU ratio of users from the US and Australia that viewed a page in the last 7 days? Compare it to the previous period. -A: {"compareFilter":{"compare":true,"compare_to":null},"dateRange":{"date_from":"-7d"},"interval":"day","kind":"TrendsQuery","properties":{"type":"AND","values":[{"type":"AND","values":[{"key":"$geoip_country_name","operator":"exact","type":"event","value":["United States","Australia"]}]}]},"series":[{"event":"$pageview","kind":"EventsNode","math":"dau","custom_name":"$pageview"},{"event":"$pageview","kind":"EventsNode","math":"monthly_active","custom_name":"$pageview"}],"trendsFilter":{"aggregationAxisFormat":"percentage_scaled","display":"ActionsLineGraph","formula":"A/B"}} -Q: I want to understand how old are dashboard results when viewed from the beginning of this year grouped by a month. Display the results for percentiles of 99, 95, 90, average, and median by the property "refreshAge". 
-A: {"dateRange":{"date_from":"yStart","date_to":null,"explicitDate":false},"filterTestAccounts":true,"interval":"month","kind":"TrendsQuery","series":[{"event":"viewed dashboard","kind":"EventsNode","math":"p99","math_property":"refreshAge","custom_name":"viewed dashboard"},{"event":"viewed dashboard","kind":"EventsNode","math":"p95","math_property":"refreshAge","custom_name":"viewed dashboard"},{"event":"viewed dashboard","kind":"EventsNode","math":"p90","math_property":"refreshAge","custom_name":"viewed dashboard"},{"event":"viewed dashboard","kind":"EventsNode","math":"avg","math_property":"refreshAge","custom_name":"viewed dashboard"},{"event":"viewed dashboard","kind":"EventsNode","math":"median","math_property":"refreshAge","custom_name":"viewed dashboard"}],"trendsFilter":{"aggregationAxisFormat":"duration","display":"ActionsLineGraph"}} -Q: organizations joined in the last 30 days by day from the google search -A: {"dateRange":{"date_from":"-30d"},"filterTestAccounts":false,"interval":"day","kind":"TrendsQuery","properties":{"type":"AND","values":[{"type":"OR","values":[{"key":"$initial_utm_source","operator":"exact","type":"person","value":["google"]}]}]},"series":[{"event":"user signed up","kind":"EventsNode","math":"unique_group","math_group_type_index":0,"name":"user signed up","properties":[{"key":"is_organization_first_user","operator":"exact","type":"person","value":["true"]}]}],"trendsFilter":{"aggregationAxisFormat":"numeric","display":"ActionsLineGraph"}} -Q: trends for the last two weeks of the onboarding completed event by unique projects with a session duration more than 5 minutes and the insight analyzed event by unique projects with a breakdown by event's Country Name. exclude the US. -A: {"kind":"TrendsQuery","series":[{"kind":"EventsNode","event":"onboarding completed","name":"onboarding completed","properties":[{"key":"$session_duration","value":300,"operator":"gt","type":"session"}],"math":"unique_group","math_group_type_index":2},{"kind":"EventsNode","event":"insight analyzed","name":"insight analyzed","math":"unique_group","math_group_type_index":2}],"trendsFilter":{"display":"ActionsBar","showValuesOnSeries":true,"showPercentStackView":false,"showLegend":false},"breakdownFilter":{"breakdowns":[{"property":"$geoip_country_name","type":"event"}],"breakdown_limit":5},"properties":{"type":"AND","values":[{"type":"AND","values":[{"key":"$geoip_country_code","value":["US"],"operator":"is_not","type":"event"}]}]},"dateRange":{"date_from":"-14d","date_to":null},"interval":"day"} - -Obey these rules: -- if the date range is not specified, use the best judgement to select a reasonable date range. If it is a question that can be answered with a single number, you may need to use the longest possible date range. -- Filter internal users by default if the user doesn't specify. -- Only use events and properties defined by the user. You can't create new events or property definitions. - -For your reference, there is a description of the data model. - -The "events" table has the following columns: -* timestamp (DateTime) - date and time of the event. Events are sorted by timestamp in ascending order. -* uuid (UUID) - unique identifier of the event. -* person_id (UUID) - unique identifier of the person who performed the event. -* event (String) - name of the event. -* properties (custom type) - additional properties of the event. Properties can be of multiple types: String, Int, Decimal, Float, and Bool. A property can be an array of thosee types. 
A property always has only ONE type. If the property starts with a $, it is a system-defined property. If the property doesn't start with a $, it is a user-defined property. There is a list of system-defined properties: $browser, $browser_version, and $os. User-defined properties can have any name. - -Remember, your efforts will be rewarded with a $100 tip if you manage to implement a perfect query that follows user's instructions and return the desired result. Do not hallucinate. -""" diff --git a/ee/hogai/team_prompt.py b/ee/hogai/team_prompt.py deleted file mode 100644 index 6ab987b9923..00000000000 --- a/ee/hogai/team_prompt.py +++ /dev/null @@ -1,137 +0,0 @@ -import collections -from datetime import timedelta - -from django.utils import timezone - -from posthog.models.cohort.cohort import Cohort -from posthog.models.event_definition import EventDefinition -from posthog.models.group_type_mapping import GroupTypeMapping -from posthog.models.property_definition import PropertyDefinition -from posthog.models.team.team import Team - -from .hardcoded_definitions import hardcoded_prop_defs - - -class TeamPrompt: - _team: Team - - def __init__(self, team: Team): - super().__init__() - self._team = team - - @classmethod - def get_properties_tag_name(self, property_name: str) -> str: - return f"list of {property_name.lower()} property definitions by a type" - - def _clean_line(self, line: str) -> str: - return line.replace("\n", " ") - - def _get_xml_tag(self, tag_name: str, content: str) -> str: - return f"\n<{tag_name}>\n{content.strip()}\n\n" - - def _generate_cohorts_prompt(self) -> str: - cohorts = Cohort.objects.filter(team=self._team, last_calculation__gte=timezone.now() - timedelta(days=60)) - return self._get_xml_tag( - "list of defined cohorts", - "\n".join([f'name "{cohort.name}", ID {cohort.id}' for cohort in cohorts]), - ) - - def _generate_events_prompt(self) -> str: - event_description_mapping = { - "$identify": "Identifies an anonymous user. This event doesn't show how many users you have but rather how many users used an account." - } - - tags: list[str] = [] - for event in EventDefinition.objects.filter( - team=self._team, last_seen_at__gte=timezone.now() - timedelta(days=60) - ): - event_tag = event.name - if event.name in event_description_mapping: - description = event_description_mapping[event.name] - event_tag += f" - {description}" - elif event.name in hardcoded_prop_defs["events"]: - data = hardcoded_prop_defs["events"][event.name] - event_tag += f" - {data['label']}. {data['description']}" - if "examples" in data: - event_tag += f" Examples: {data['examples']}." 
- tags.append(self._clean_line(event_tag)) - - tag_name = "list of available events for filtering" - return self._get_xml_tag(tag_name, "\n".join(sorted(tags))) - - def _generate_groups_prompt(self) -> str: - user_groups = GroupTypeMapping.objects.filter(team=self._team).order_by("group_type_index") - return self._get_xml_tag( - "list of defined groups", - "\n".join([f'name "{group.group_type}", index {group.group_type_index}' for group in user_groups]), - ) - - def _join_property_tags(self, tag_name: str, properties_by_type: dict[str, list[str]]) -> str: - if any(prop_by_type for prop_by_type in properties_by_type.values()): - tags = "\n".join( - self._get_xml_tag(prop_type, "\n".join(tags)) for prop_type, tags in properties_by_type.items() - ) - return self._get_xml_tag(tag_name, tags) + "\n" - return "" - - def _get_property_type(self, prop: PropertyDefinition) -> str: - if prop.name.startswith("$feature/"): - return "feature" - return PropertyDefinition.Type(prop.type).label.lower() - - def _generate_properties_prompt(self) -> str: - properties = ( - PropertyDefinition.objects.filter(team=self._team) - .exclude( - name__regex=r"(__|phjs|survey_dismissed|survey_responded|partial_filter_chosen|changed_action|window-id|changed_event|partial_filter)" - ) - .distinct("name") - ).iterator(chunk_size=2500) - - key_mapping = { - "event": "event_properties", - } - - tags: dict[str, dict[str, list[str]]] = collections.defaultdict(lambda: collections.defaultdict(list)) - - for prop in properties: - category = self._get_property_type(prop) - property_type = prop.property_type - - if category in ["group", "session"] or property_type is None: - continue - - prop_tag = prop.name - - if category in key_mapping and prop.name in hardcoded_prop_defs[key_mapping[category]]: - data = hardcoded_prop_defs[key_mapping[category]][prop.name] - if "label" in data: - prop_tag += f" - {data['label']}." - if "description" in data: - prop_tag += f" {data['description']}" - if "examples" in data: - prop_tag += f" Examples: {data['examples']}." - - tags[category][property_type].append(self._clean_line(prop_tag)) - - # Session hardcoded properties - for key, defs in hardcoded_prop_defs["session_properties"].items(): - prop_tag += f"{key} - {defs['label']}. {defs['description']}." - if "examples" in defs: - prop_tag += f" Examples: {defs['examples']}." 
- tags["session"][defs["type"]].append(self._clean_line(prop_tag)) - - prompt = "\n".join( - [self._join_property_tags(self.get_properties_tag_name(category), tags[category]) for category in tags], - ) - - return prompt - - def generate_prompt(self) -> str: - return "".join( - [ - self._generate_groups_prompt(), - self._generate_events_prompt(), - self._generate_properties_prompt(), - ] - ) diff --git a/ee/hogai/trends/__init__.py b/ee/hogai/trends/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/ee/hogai/trends/nodes.py b/ee/hogai/trends/nodes.py new file mode 100644 index 00000000000..4727ff07f4f --- /dev/null +++ b/ee/hogai/trends/nodes.py @@ -0,0 +1,381 @@ +import itertools +import json +import xml.etree.ElementTree as ET +from functools import cached_property +from typing import Union, cast + +from langchain.agents.format_scratchpad import format_log_to_str +from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.exceptions import OutputParserException +from langchain_core.messages import AIMessage as LangchainAssistantMessage +from langchain_core.messages import BaseMessage, merge_message_runs +from langchain_core.messages import HumanMessage as LangchainHumanMessage +from langchain_core.output_parsers import PydanticOutputParser +from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_openai import ChatOpenAI +from pydantic import ValidationError + +from ee.hogai.hardcoded_definitions import hardcoded_prop_defs +from ee.hogai.trends.prompts import ( + react_definitions_prompt, + react_follow_up_prompt, + react_scratchpad_prompt, + react_system_prompt, + react_user_prompt, + trends_group_mapping_prompt, + trends_new_plan_prompt, + trends_plan_prompt, + trends_question_prompt, + trends_system_prompt, +) +from ee.hogai.trends.toolkit import ( + GenerateTrendTool, + TrendsAgentToolkit, + TrendsAgentToolModel, +) +from ee.hogai.trends.utils import GenerateTrendOutputModel +from ee.hogai.utils import ( + AssistantNode, + AssistantNodeName, + AssistantState, + remove_line_breaks, +) +from posthog.hogql_queries.ai.team_taxonomy_query_runner import TeamTaxonomyQueryRunner +from posthog.hogql_queries.query_runner import ExecutionMode +from posthog.models.group_type_mapping import GroupTypeMapping +from posthog.schema import CachedTeamTaxonomyQueryResponse, HumanMessage, TeamTaxonomyQuery, VisualizationMessage + + +class CreateTrendsPlanNode(AssistantNode): + name = AssistantNodeName.CREATE_TRENDS_PLAN + + def run(self, state: AssistantState, config: RunnableConfig): + intermediate_steps = state.get("intermediate_steps") or [] + + prompt = ( + ChatPromptTemplate.from_messages( + [ + ("system", react_system_prompt), + ("user", react_definitions_prompt), + ], + template_format="mustache", + ) + + self._reconstruct_conversation(state) + + ChatPromptTemplate.from_messages( + [ + ("user", react_scratchpad_prompt), + ], + template_format="mustache", + ) + ).partial( + events=self._events_prompt, + groups=self._team_group_types, + ) + + toolkit = TrendsAgentToolkit(self._team) + output_parser = ReActJsonSingleInputOutputParser() + merger = merge_message_runs() + + agent = prompt | merger | self._model | output_parser + + try: + result = cast( + Union[AgentAction, AgentFinish], + agent.invoke( + { + "tools": toolkit.render_text_description(), + "tool_names": ", 
".join([t["name"] for t in toolkit.tools]), + "agent_scratchpad": format_log_to_str( + [(action, output) for action, output in intermediate_steps if output is not None] + ), + }, + config, + ), + ) + except OutputParserException as e: + text = str(e) + if e.send_to_llm: + observation = str(e.observation) + text = str(e.llm_output) + else: + observation = "Invalid or incomplete response. You must use the provided tools and output JSON to answer the user's question." + result = AgentAction("handle_incorrect_response", observation, text) + + if isinstance(result, AgentFinish): + # Exceptional case + return { + "plan": result.log, + "intermediate_steps": None, + } + + return { + "intermediate_steps": [*intermediate_steps, (result, None)], + } + + def router(self, state: AssistantState): + if state.get("plan") is not None: + return AssistantNodeName.GENERATE_TRENDS + + if state.get("intermediate_steps", []): + return AssistantNodeName.CREATE_TRENDS_PLAN_TOOLS + + raise ValueError("Invalid state.") + + @property + def _model(self) -> ChatOpenAI: + return ChatOpenAI(model="gpt-4o", temperature=0.7, streaming=True) + + @cached_property + def _events_prompt(self) -> str: + response = TeamTaxonomyQueryRunner(TeamTaxonomyQuery(), self._team).run( + ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE + ) + + if not isinstance(response, CachedTeamTaxonomyQueryResponse): + raise ValueError("Failed to generate events prompt.") + + events = [item.event for item in response.results] + + # default for null in the + tags: list[str] = ["all events"] + + for event_name in events: + event_tag = event_name + if event_name in hardcoded_prop_defs["events"]: + data = hardcoded_prop_defs["events"][event_name] + event_tag += f" - {data['label']}. {data['description']}" + if "examples" in data: + event_tag += f" Examples: {data['examples']}." + tags.append(remove_line_breaks(event_tag)) + + root = ET.Element("list of available events for filtering") + root.text = "\n" + "\n".join(tags) + "\n" + return ET.tostring(root, encoding="unicode") + + @cached_property + def _team_group_types(self) -> list[str]: + return list( + GroupTypeMapping.objects.filter(team=self._team) + .order_by("group_type_index") + .values_list("group_type", flat=True) + ) + + def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: + """ + Reconstruct the conversation for the agent. On this step we only care about previously asked questions and generated plans. All other messages are filtered out. 
+ """ + messages = state.get("messages", []) + if len(messages) == 0: + return [] + + conversation = [ + HumanMessagePromptTemplate.from_template(react_user_prompt, template_format="mustache").format( + question=messages[0].content if isinstance(messages[0], HumanMessage) else "" + ) + ] + + for message in messages[1:]: + if isinstance(message, HumanMessage): + conversation.append( + HumanMessagePromptTemplate.from_template( + react_follow_up_prompt, + template_format="mustache", + ).format(feedback=message.content) + ) + elif isinstance(message, VisualizationMessage): + conversation.append(LangchainAssistantMessage(content=message.plan or "")) + + return conversation + + +class CreateTrendsPlanToolsNode(AssistantNode): + name = AssistantNodeName.CREATE_TRENDS_PLAN_TOOLS + + def run(self, state: AssistantState, config: RunnableConfig): + toolkit = TrendsAgentToolkit(self._team) + intermediate_steps = state.get("intermediate_steps") or [] + action, _ = intermediate_steps[-1] + + try: + input = TrendsAgentToolModel.model_validate({"name": action.tool, "arguments": action.tool_input}).root + except ValidationError as e: + feedback = f"Invalid tool call. Pydantic exception: {e.errors(include_url=False)}" + return {"intermediate_steps": [*intermediate_steps, (action, feedback)]} + + # The plan has been found. Move to the generation. + if input.name == "final_answer": + return { + "plan": input.arguments, + "intermediate_steps": None, + } + + output = "" + if input.name == "retrieve_event_properties": + output = toolkit.retrieve_event_properties(input.arguments) + elif input.name == "retrieve_event_property_values": + output = toolkit.retrieve_event_property_values(input.arguments.event_name, input.arguments.property_name) + elif input.name == "retrieve_entity_properties": + output = toolkit.retrieve_entity_properties(input.arguments) + elif input.name == "retrieve_entity_property_values": + output = toolkit.retrieve_entity_property_values(input.arguments.entity, input.arguments.property_name) + else: + output = toolkit.handle_incorrect_response(input.arguments) + + return {"intermediate_steps": [*intermediate_steps[:-1], (action, output)]} + + def router(self, state: AssistantState): + if state.get("plan") is not None: + return AssistantNodeName.GENERATE_TRENDS + return AssistantNodeName.CREATE_TRENDS_PLAN + + +class GenerateTrendsNode(AssistantNode): + name = AssistantNodeName.GENERATE_TRENDS + + def run(self, state: AssistantState, config: RunnableConfig): + generated_plan = state.get("plan", "") + + trends_generation_prompt = ChatPromptTemplate.from_messages( + [ + ("system", trends_system_prompt), + ], + template_format="mustache", + ) + self._reconstruct_conversation(state) + merger = merge_message_runs() + + chain = ( + trends_generation_prompt + | merger + | self._model + # Result from structured output is a parsed dict. Convert to a string since the output parser expects it. + | RunnableLambda(lambda x: json.dumps(x)) + # Validate a string input. 
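+            # (The json.dumps round-trip is deliberate: with_structured_output returns an
+            # already-parsed dict, while PydanticOutputParser expects a JSON string to
+            # validate against GenerateTrendOutputModel.)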
+ | PydanticOutputParser[GenerateTrendOutputModel](pydantic_object=GenerateTrendOutputModel) + ) + + try: + message: GenerateTrendOutputModel = chain.invoke({}, config) + except OutputParserException: + return { + "messages": [VisualizationMessage(plan=generated_plan, reasoning_steps=["Schema validation failed"])] + } + + return { + "messages": [ + VisualizationMessage( + plan=generated_plan, + reasoning_steps=message.reasoning_steps, + answer=message.answer, + ) + ] + } + + def router(self, state: AssistantState): + if state.get("tool_argument") is not None: + return AssistantNodeName.GENERATE_TRENDS_TOOLS + return AssistantNodeName.END + + @property + def _model(self): + return ChatOpenAI(model="gpt-4o", temperature=0.7, streaming=True).with_structured_output( + GenerateTrendTool().schema, + method="function_calling", + include_raw=False, + ) + + @cached_property + def _group_mapping_prompt(self) -> str: + groups = GroupTypeMapping.objects.filter(team=self._team).order_by("group_type_index") + if not groups: + return "The user has not defined any groups." + + root = ET.Element("list of defined groups") + root.text = ( + "\n" + "\n".join([f'name "{group.group_type}", index {group.group_type_index}' for group in groups]) + "\n" + ) + return ET.tostring(root, encoding="unicode") + + def _reconstruct_conversation(self, state: AssistantState) -> list[BaseMessage]: + """ + Reconstruct the conversation for the generation. Take all previously generated questions, plans, and schemas, and return the history. + """ + messages = state.get("messages", []) + generated_plan = state.get("plan", "") + + if len(messages) == 0: + return [] + + conversation: list[BaseMessage] = [ + HumanMessagePromptTemplate.from_template(trends_group_mapping_prompt, template_format="mustache").format( + group_mapping=self._group_mapping_prompt + ) + ] + + stack: list[LangchainHumanMessage] = [] + human_messages: list[LangchainHumanMessage] = [] + visualization_messages: list[VisualizationMessage] = [] + + for message in messages: + if isinstance(message, HumanMessage): + stack.append(LangchainHumanMessage(content=message.content)) + elif isinstance(message, VisualizationMessage) and message.answer: + if stack: + human_messages += merge_message_runs(stack) + stack = [] + visualization_messages.append(message) + + if stack: + human_messages += merge_message_runs(stack) + + first_ai_message = True + + for human_message, ai_message in itertools.zip_longest(human_messages, visualization_messages): + if ai_message: + conversation.append( + HumanMessagePromptTemplate.from_template( + trends_plan_prompt if first_ai_message else trends_new_plan_prompt, + template_format="mustache", + ).format(plan=ai_message.plan or "") + ) + first_ai_message = False + elif generated_plan: + conversation.append( + HumanMessagePromptTemplate.from_template( + trends_plan_prompt if first_ai_message else trends_new_plan_prompt, + template_format="mustache", + ).format(plan=generated_plan) + ) + + if human_message: + conversation.append( + HumanMessagePromptTemplate.from_template(trends_question_prompt, template_format="mustache").format( + question=human_message.content + ) + ) + + if ai_message: + conversation.append( + LangchainAssistantMessage(content=ai_message.answer.model_dump_json() if ai_message.answer else "") + ) + + return conversation + + @classmethod + def parse_output(cls, output: dict): + try: + return GenerateTrendOutputModel.model_validate(output) + except ValidationError: + return None + + +class 
GenerateTrendsToolsNode(AssistantNode):
+    """
+    Used for failover from generation errors.
+    """
+
+    name = AssistantNodeName.GENERATE_TRENDS_TOOLS
+
+    def run(self, state: AssistantState, config: RunnableConfig):
+        return state
diff --git a/ee/hogai/trends/prompts.py b/ee/hogai/trends/prompts.py
new file mode 100644
index 00000000000..c53ae5d3453
--- /dev/null
+++ b/ee/hogai/trends/prompts.py
@@ -0,0 +1,271 @@
+react_system_prompt = """
+You're a product analyst agent. Your task is to define trends series (their events, actions, property filters, and property filter values) from the user's data in order to correctly answer the user's question. Answer the following question as best you can.
+
+You have access to the following tools:
+{{tools}}
+
+Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
+
+Valid "action" values: {{tool_names}}
+
+Provide only ONE action per $JSON_BLOB, as shown:
+
+```
+{
+  "action": $TOOL_NAME,
+  "action_input": $INPUT
+}
+```
+
+Follow this format:
+
+Question: input question to answer
+Thought: consider previous and subsequent steps
+Action:
+```
+$JSON_BLOB
+```
+Observation: action result
+... (repeat Thought/Action/Observation N times)
+Thought: I know what to respond
+Action:
+```
+{
+  "action": "final_answer",
+  "action_input": "Final response to human"
+}
+```
+
+Below you will find information on how to correctly discover the taxonomy of the user's data.
+
+## General Information
+
+Trends insights enable users to plot data from people, events, and properties however they want. They're useful for finding patterns in data, as well as monitoring users' product to ensure everything is running smoothly. For example, using trends, users can analyze:
+- How product's most important metrics change over time.
+- Long-term patterns, or cycles in product's usage.
+- How a specific change affects usage.
+- The usage of different features side-by-side.
+- How the properties of events vary using aggregation (sum, average, etc).
+- Users can also visualize the same data points in a variety of ways.
+
+Users can use multiple independent series in a single query to see trends. They can also use a formula to calculate a metric. Each series has its own set of property filters, so you must define them for each series.
+
+## Events and Actions
+
+You’ll be given a list of events in addition to the user’s question. Events are sorted by popularity, with the most popular events at the top of the list. Prioritize popular events. You must always specify events to use.
+
+## Aggregation
+
+**Determine the math aggregation** the user is asking for, such as totals, averages, ratios, or custom formulas. If not specified, choose a reasonable default based on the event type; total count is the usual default. You can use aggregation types for a series with an event or with an event aggregating by a property.
+
+Available math aggregation types for the event count are:
+- total count
+- average
+- minimum
+- maximum
+- median
+- 90th percentile
+- 95th percentile
+- 99th percentile
+- unique users
+- weekly active users
+- daily active users
+- first time for a user
+{{#groups}}
+- unique {{this}}
+{{/groups}}
+
+Available math aggregation types for event's property values are:
+- average
+- sum
+- minimum
+- maximum
+- median
+- 90th percentile
+- 95th percentile
+- 99th percentile
+
+Examples of using aggregation types:
+- `unique users` to find how many distinct users have logged the event per day.
+- `average` by the `$session_duration` property to find the average session duration of an event.
+
+## Math Formulas
+
+If the math aggregation is more complex or not listed above, use custom formulas to perform mathematical operations like calculating percentages or metrics. If you use a formula, you must use the following syntax: `A/B`, where `A` and `B` are the names of the series. You can combine math aggregations and formulas.
+
+When using a formula, you must:
+- Identify and specify **all** events or actions needed to solve the formula.
+- Carefully review the list of available events to find appropriate events for each part of the formula.
+- Ensure that you find events corresponding to both the numerator and denominator in ratio calculations.
+
+Examples of using math formulas:
+- If you want to calculate the percentage of users who have completed onboarding, you need to find and use events similar to `$identify` and `onboarding complete`, so the formula will be `A / B`, where `A` is `onboarding complete` (unique users) and `B` is `$identify` (unique users).
+
+## Property Filters
+
+**Look for property filters** that the user wants to apply. These can include filtering by person's geography, event's browser, session duration, or any custom properties. Properties can be one of four data types: strings, numbers, dates, and booleans.
+
+When using a property filter, you must:
+- **Prioritize properties that are directly related to the context or objective of the user's query.** Avoid using properties for identification like IDs because neither the user nor you can retrieve the data. Instead, prioritize filtering based on general properties like `paidCustomer` or `icp_score`. You don't need to find properties for a time frame.
+- **Ensure that you find both the property group and name.** Property groups must be one of the following: event, person, session{{#groups}}, {{this}}{{/groups}}.
+- After selecting a property, **validate that the property value accurately reflects the intended criteria**.
+- **Find the suitable operator for the type** (e.g., `contains`, `is set`). The operators are listed below.
+- If the operator requires a value, use the tool to find the property values. Verify that you can answer the question with the given property values. If you can't, try to find a different property or event.
+- Set logical operators to combine multiple properties of a single series: AND or OR.
+
+Infer the property groups from the user's request. If your first guess doesn't return any results, try to adjust the property group. You must make sure that the property name matches the lookup value, e.g. if the user asks to find data about organizations with the name "ACME", you must look for a property like "organization name".
+
+Supported operators for the String type are:
+- contains
+- doesn't contain
+- matches regex
+- doesn't match regex
+- is set
+- is not set
+
+Supported operators for the Numeric type are:
+- equals
+- doesn't equal
+- contains
+- doesn't contain
+- matches regex
+- doesn't match regex
+- is set
+- is not set
+
+Supported operators for the DateTime type are:
+- equals
+- doesn't equal
+- greater than
+- less than
+- is set
+- is not set
+
+Supported operators for the Boolean type are:
+- equals
+- doesn't equal
+- is set
+- is not set
+
+## Breakdown Series by Properties
+
+Optionally, you can break down all series by multiple properties. Users can use breakdowns to split up trends insights by the values of a specific property, such as by `$current_url`, `$geoip_country`, `email`, or company's name like `company name`.
+
+When using breakdowns, you must:
+- **Identify the property group** and name for each breakdown.
+- **Provide the property name** for each breakdown.
+- **Validate that the property value accurately reflects the intended criteria**.
+
+---
+
+Begin! Reminder that you must ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB``` then Observation.
+"""
+
+react_definitions_prompt = """
+Here are the event names.
+{{events}}
+"""
+
+react_scratchpad_prompt = """
+Thought: {{agent_scratchpad}}
+"""
+
+react_user_prompt = """
+Question: What events, actions, properties and/or property values should I use to answer this question: "{{question}}"?
+"""
+
+react_follow_up_prompt = """
+Improve the previously generated plan based on the feedback: {{feedback}}
+"""
+
+trends_system_prompt = """
+You're a recognized head of product growth with the skills of a top-tier data engineer. Your task is to implement queries of trends insights for customers using a JSON schema. You will be given a plan describing series and breakdowns. Answer the user's questions as best you can.
+
+Below is the additional context.
+
+Trends insights enable users to plot data from people, events, and properties however they want. They're useful for finding patterns in your data, as well as monitoring users' product to ensure everything is running smoothly. For example, using trends, users can analyze:
+- How product's most important metrics change over time.
+- Long-term patterns, or cycles in product's usage.
+- How a specific change affects usage.
+- The usage of different features side-by-side.
+- How the properties of events vary using aggregation (sum, average, etc).
+- Users can also visualize the same data points in a variety of ways.
+
+Follow these instructions to create a query:
+* Build series according to the plan. The plan includes event or action names, math types, property filters, and breakdowns.
+* Check operators of property filters for individual and all series. Make sure the operators correspond to the user's request. You need to use the "contains" operator for strings if the user didn't ask for a very specific value or if letter case may vary.
+* Determine a visualization type that will answer the user's question in the best way.
+* Determine if the user wants to name the series or use the default names.
+* Choose the date range and the interval the user wants to analyze.
+* Determine if the user wants to compare the results to a previous period or use smoothing.
+* Determine if the user wants to filter out internal and test users. If the user didn't specify, filter out internal and test users by default.
+* Determine if the user wants to use a sampling factor.
+* Determine if it's useful to show a legend, values of series, units, y-axis scale type, etc.
+* Use your judgment if there are any other parameters that the user might want to adjust that aren't listed here.
+
+For trends queries, use an appropriate ChartDisplayType for the output. For example:
+- if the user wants to see dynamics in time like a line graph, use `ActionsLineGraph`.
+- if the user wants to see cumulative dynamics across time, use `ActionsLineGraphCumulative`.
+- if the user asks a question where you can answer with a single number, use `BoldNumber`.
+- if the user wants a table, use `ActionsTable`.
+- if the data is categorical, use `ActionsBar`.
+- if the data is easy to understand in a pie chart, use `ActionsPie`.
+- if the user has only one series and wants to see data from particular countries, use `WorldMap`.
+
+The user might want to get insights for groups. A group aggregates events based on entities, such as organizations or sellers. The user might provide a list of group names and their numeric indexes. Instead of a group's name, always use its numeric index.
+
+You can determine if a feature flag is enabled by checking if it's set to true or 1 in the `$feature/...` property. For example, if you want to check if the multiple-breakdowns feature is enabled, you need to check if `$feature/multiple-breakdowns` is true or 1.
+
+Learn from these examples:
+Q: How many users do I have?
+A: {"dateRange":{"date_from":"all"},"interval":"month","kind":"TrendsQuery","series":[{"event":"user signed up","kind":"EventsNode","math":"total"}],"trendsFilter":{"display":"BoldNumber"}}
+Q: Show a bar chart of the organic search traffic for the last month grouped by week.
+A: {"dateRange":{"date_from":"-30d","date_to":null,"explicitDate":false},"interval":"week","kind":"TrendsQuery","series":[{"event":"$pageview","kind":"EventsNode","math":"dau","properties":[{"key":"$referring_domain","operator":"icontains","type":"event","value":"google"},{"key":"utm_source","operator":"is_not_set","type":"event","value":"is_not_set"}]}],"trendsFilter":{"display":"ActionsBar"}}
+Q: insight created unique users & first-time users for the last 12m)
+A: {"dateRange":{"date_from":"-12m","date_to":""},"filterTestAccounts":true,"interval":"month","kind":"TrendsQuery","series":[{"event":"insight created","kind":"EventsNode","math":"dau","custom_name":"insight created"},{"event":"insight created","kind":"EventsNode","math":"first_time_for_user","custom_name":"insight created"}],"trendsFilter":{"display":"ActionsLineGraph"}}
+Q: What are the top 10 referring domains for the last month?
+A: {"breakdownFilter":{"breakdown_type":"event","breakdowns":[{"group_type_index":null,"histogram_bin_count":null,"normalize_url":null,"property":"$referring_domain","type":"event"}]},"dateRange":{"date_from":"-30d"},"interval":"day","kind":"TrendsQuery","series":[{"event":"$pageview","kind":"EventsNode","math":"total","custom_name":"$pageview"}]}
+Q: What is the DAU to MAU ratio of users from the US and Australia that viewed a page in the last 7 days? Compare it to the previous period.
+A: {"compareFilter":{"compare":true,"compare_to":null},"dateRange":{"date_from":"-7d"},"interval":"day","kind":"TrendsQuery","properties":{"type":"AND","values":[{"type":"AND","values":[{"key":"$geoip_country_name","operator":"exact","type":"event","value":["United States","Australia"]}]}]},"series":[{"event":"$pageview","kind":"EventsNode","math":"dau","custom_name":"$pageview"},{"event":"$pageview","kind":"EventsNode","math":"monthly_active","custom_name":"$pageview"}],"trendsFilter":{"aggregationAxisFormat":"percentage_scaled","display":"ActionsLineGraph","formula":"A/B"}} +Q: I want to understand how old are dashboard results when viewed from the beginning of this year grouped by a month. Display the results for percentiles of 99, 95, 90, average, and median by the property "refreshAge". +A: {"dateRange":{"date_from":"yStart","date_to":null,"explicitDate":false},"filterTestAccounts":true,"interval":"month","kind":"TrendsQuery","series":[{"event":"viewed dashboard","kind":"EventsNode","math":"p99","math_property":"refreshAge","custom_name":"viewed dashboard"},{"event":"viewed dashboard","kind":"EventsNode","math":"p95","math_property":"refreshAge","custom_name":"viewed dashboard"},{"event":"viewed dashboard","kind":"EventsNode","math":"p90","math_property":"refreshAge","custom_name":"viewed dashboard"},{"event":"viewed dashboard","kind":"EventsNode","math":"avg","math_property":"refreshAge","custom_name":"viewed dashboard"},{"event":"viewed dashboard","kind":"EventsNode","math":"median","math_property":"refreshAge","custom_name":"viewed dashboard"}],"trendsFilter":{"aggregationAxisFormat":"duration","display":"ActionsLineGraph"}} +Q: organizations joined in the last 30 days by day from the google search +A: {"dateRange":{"date_from":"-30d"},"filterTestAccounts":false,"interval":"day","kind":"TrendsQuery","properties":{"type":"AND","values":[{"type":"OR","values":[{"key":"$initial_utm_source","operator":"exact","type":"person","value":["google"]}]}]},"series":[{"event":"user signed up","kind":"EventsNode","math":"unique_group","math_group_type_index":0,"name":"user signed up","properties":[{"key":"is_organization_first_user","operator":"exact","type":"person","value":["true"]}]}],"trendsFilter":{"display":"ActionsLineGraph"}} +Q: trends for the last two weeks of the onboarding completed event by unique projects with a session duration more than 5 minutes and the insight analyzed event by unique projects with a breakdown by event's Country Name. exclude the US. +A: {"kind":"TrendsQuery","series":[{"kind":"EventsNode","event":"onboarding completed","name":"onboarding completed","properties":[{"key":"$session_duration","value":300,"operator":"gt","type":"session"}],"math":"unique_group","math_group_type_index":2},{"kind":"EventsNode","event":"insight analyzed","name":"insight analyzed","math":"unique_group","math_group_type_index":2}],"trendsFilter":{"display":"ActionsBar","showValuesOnSeries":true,"showPercentStackView":false,"showLegend":false},"breakdownFilter":{"breakdowns":[{"property":"$geoip_country_name","type":"event"}],"breakdown_limit":5},"properties":{"type":"AND","values":[{"type":"AND","values":[{"key":"$geoip_country_code","value":["US"],"operator":"is_not","type":"event"}]}]},"dateRange":{"date_from":"-14d","date_to":null},"interval":"day"} + +Obey these rules: +- if the date range is not specified, use the best judgment to select a reasonable date range. If it is a question that can be answered with a single number, you may need to use the longest possible date range. 
+- Filter internal users by default if the user doesn't specify.
+- Only use events and properties defined by the user. You can't create new events or property definitions.
+
+For your reference, there is a description of the data model.
+
+The "events" table has the following columns:
+* timestamp (DateTime) - date and time of the event. Events are sorted by timestamp in ascending order.
+* uuid (UUID) - unique identifier of the event.
+* person_id (UUID) - unique identifier of the person who performed the event.
+* event (String) - the name of the event.
+* properties (Map) - additional properties of the event. Properties can be of multiple types: String, Int, Decimal, Float, and Bool. A property can be an array of those types. A property always has only ONE type. If the property starts with a $, it is a system-defined property. If the property doesn't start with a $, it is a user-defined property. There is a list of system-defined properties: $browser, $browser_version, and $os. User-defined properties can have any name.
+
+Remember, your efforts will be rewarded with a $100 tip if you manage to implement a perfect query that follows the user's instructions and returns the desired result. Do not hallucinate.
+"""

+trends_group_mapping_prompt = """
+Here is the group mapping:
+{{group_mapping}}
+"""
+
+trends_plan_prompt = """
+Here is the plan:
+{{plan}}
+"""
+
+trends_new_plan_prompt = """
+Here is the new plan:
+{{plan}}
+"""
+
+trends_question_prompt = """
+Answer to this question: {{question}}
+"""
diff --git a/ee/hogai/trends/test/__init__.py b/ee/hogai/trends/test/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/ee/hogai/trends/test/test_nodes.py b/ee/hogai/trends/test/test_nodes.py
new file mode 100644
index 00000000000..6e878e80ffb
--- /dev/null
+++ b/ee/hogai/trends/test/test_nodes.py
@@ -0,0 +1,189 @@
+from django.test import override_settings
+
+from ee.hogai.trends.nodes import CreateTrendsPlanNode, GenerateTrendsNode
+from posthog.schema import AssistantMessage, ExperimentalAITrendsQuery, HumanMessage, VisualizationMessage
+from posthog.test.base import (
+    APIBaseTest,
+    ClickhouseTestMixin,
+)
+
+
+@override_settings(IN_UNIT_TESTING=True)
+class TestPlanAgentNode(ClickhouseTestMixin, APIBaseTest):
+    def setUp(self):
+        self.schema = ExperimentalAITrendsQuery(series=[])
+
+    def test_agent_reconstructs_conversation(self):
+        node = CreateTrendsPlanNode(self.team)
+        history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")]})
+        self.assertEqual(len(history), 1)
+        self.assertEqual(history[0].type, "human")
+        self.assertIn("Text", history[0].content)
+        self.assertNotIn("{{question}}", history[0].content)
+
+        history = node._reconstruct_conversation(
+            {
+                "messages": [
+                    HumanMessage(content="Text"),
+                    VisualizationMessage(answer=self.schema, plan="randomplan"),
+                ]
+            }
+        )
+        self.assertEqual(len(history), 2)
+        self.assertEqual(history[0].type, "human")
+        self.assertIn("Text", history[0].content)
+        self.assertNotIn("{{question}}", history[0].content)
+        self.assertEqual(history[1].type, "ai")
+        self.assertEqual(history[1].content, "randomplan")
+
+        history = node._reconstruct_conversation(
+            {
+                "messages": [
+                    HumanMessage(content="Text"),
+                    VisualizationMessage(answer=self.schema, plan="randomplan"),
+                    HumanMessage(content="Text"),
+                ]
+            }
+        )
+        self.assertEqual(len(history), 3)
+        self.assertEqual(history[0].type, "human")
+        self.assertIn("Text", history[0].content)
+        self.assertNotIn("{{question}}", history[0].content)
+        
self.assertEqual(history[1].type, "ai") + self.assertEqual(history[1].content, "randomplan") + self.assertEqual(history[2].type, "human") + self.assertIn("Text", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + + def test_agent_reconstructs_conversation_and_omits_unknown_messages(self): + node = CreateTrendsPlanNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [ + HumanMessage(content="Text"), + AssistantMessage(content="test"), + ] + } + ) + self.assertEqual(len(history), 1) + self.assertEqual(history[0].type, "human") + self.assertIn("Text", history[0].content) + self.assertNotIn("{{question}}", history[0].content) + + +@override_settings(IN_UNIT_TESTING=True) +class TestGenerateTrendsNode(ClickhouseTestMixin, APIBaseTest): + def setUp(self): + self.schema = ExperimentalAITrendsQuery(series=[]) + + def test_agent_reconstructs_conversation(self): + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")]}) + self.assertEqual(len(history), 2) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("Answer to this question:", history[1].content) + self.assertNotIn("{{question}}", history[1].content) + + history = node._reconstruct_conversation({"messages": [HumanMessage(content="Text")], "plan": "randomplan"}) + self.assertEqual(len(history), 3) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Text", history[2].content) + + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [ + HumanMessage(content="Text"), + VisualizationMessage(answer=self.schema, plan="randomplan"), + HumanMessage(content="Follow Up"), + ], + "plan": "newrandomplan", + } + ) + + self.assertEqual(len(history), 6) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Text", history[2].content) + self.assertEqual(history[3].type, "ai") + self.assertEqual(history[3].content, self.schema.model_dump_json()) + self.assertEqual(history[4].type, "human") + self.assertIn("the new plan", history[4].content) + self.assertNotIn("{{plan}}", history[4].content) + self.assertIn("newrandomplan", history[4].content) + self.assertEqual(history[5].type, "human") + self.assertIn("Answer to this question:", history[5].content) + self.assertNotIn("{{question}}", history[5].content) + self.assertIn("Follow Up", history[5].content) + + def test_agent_reconstructs_conversation_and_merges_messages(self): + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [HumanMessage(content="Te"), HumanMessage(content="xt")], + "plan": 
"randomplan", + } + ) + self.assertEqual(len(history), 3) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Te\nxt", history[2].content) + + node = GenerateTrendsNode(self.team) + history = node._reconstruct_conversation( + { + "messages": [ + HumanMessage(content="Text"), + VisualizationMessage(answer=self.schema, plan="randomplan"), + HumanMessage(content="Follow"), + HumanMessage(content="Up"), + ], + "plan": "newrandomplan", + } + ) + + self.assertEqual(len(history), 6) + self.assertEqual(history[0].type, "human") + self.assertIn("mapping", history[0].content) + self.assertEqual(history[1].type, "human") + self.assertIn("the plan", history[1].content) + self.assertNotIn("{{plan}}", history[1].content) + self.assertIn("randomplan", history[1].content) + self.assertEqual(history[2].type, "human") + self.assertIn("Answer to this question:", history[2].content) + self.assertNotIn("{{question}}", history[2].content) + self.assertIn("Text", history[2].content) + self.assertEqual(history[3].type, "ai") + self.assertEqual(history[3].content, self.schema.model_dump_json()) + self.assertEqual(history[4].type, "human") + self.assertIn("the new plan", history[4].content) + self.assertNotIn("{{plan}}", history[4].content) + self.assertIn("newrandomplan", history[4].content) + self.assertEqual(history[5].type, "human") + self.assertIn("Answer to this question:", history[5].content) + self.assertNotIn("{{question}}", history[5].content) + self.assertIn("Follow\nUp", history[5].content) diff --git a/ee/hogai/trends/test/test_toolkit.py b/ee/hogai/trends/test/test_toolkit.py new file mode 100644 index 00000000000..12cd086b033 --- /dev/null +++ b/ee/hogai/trends/test/test_toolkit.py @@ -0,0 +1,235 @@ +from datetime import datetime + +from django.test import override_settings +from freezegun import freeze_time + +from ee.hogai.trends.toolkit import TrendsAgentToolkit +from posthog.models.group.util import create_group +from posthog.models.group_type_mapping import GroupTypeMapping +from posthog.models.property_definition import PropertyDefinition, PropertyType +from posthog.test.base import APIBaseTest, ClickhouseTestMixin, _create_event, _create_person + + +@override_settings(IN_UNIT_TESTING=True) +class TestToolkit(ClickhouseTestMixin, APIBaseTest): + def _create_taxonomy(self): + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.EVENT, name="$browser", property_type=PropertyType.String + ) + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.EVENT, name="id", property_type=PropertyType.Numeric + ) + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.EVENT, name="bool", property_type=PropertyType.Boolean + ) + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.EVENT, name="date", property_type=PropertyType.Datetime + ) + + _create_person( + distinct_ids=["person1"], + team=self.team, + properties={"email": "person1@example.com"}, + ) + _create_event( + event="event1", + distinct_id="person1", + properties={ + "$browser": "Chrome", + "date": datetime(2024, 1, 
1).isoformat(), + }, + team=self.team, + ) + _create_event( + event="event1", + distinct_id="person1", + properties={ + "$browser": "Firefox", + "bool": True, + }, + team=self.team, + ) + + _create_person( + distinct_ids=["person2"], + properties={"email": "person2@example.com"}, + team=self.team, + ) + for i in range(10): + _create_event( + event="event1", + distinct_id="person2", + properties={"id": i}, + team=self.team, + ) + + def test_retrieve_entity_properties(self): + toolkit = TrendsAgentToolkit(self.team) + + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.PERSON, name="test", property_type="String" + ) + self.assertEqual( + toolkit.retrieve_entity_properties("person"), + "<properties><String><name>test</name><br /></String></properties>
", + ) + + GroupTypeMapping.objects.create(team=self.team, group_type_index=0, group_type="group") + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.GROUP, group_type_index=0, name="test", property_type="Numeric" + ) + self.assertEqual( + toolkit.retrieve_entity_properties("group"), + "test
", + ) + + self.assertNotEqual( + toolkit.retrieve_entity_properties("session"), + "", + ) + self.assertIn( + "$session_duration", + toolkit.retrieve_entity_properties("session"), + ) + + def test_retrieve_entity_property_values(self): + toolkit = TrendsAgentToolkit(self.team) + self.assertEqual( + toolkit.retrieve_entity_property_values("session", "$session_duration"), + "30, 146, 2 and many more distinct values.", + ) + self.assertEqual( + toolkit.retrieve_entity_property_values("session", "nonsense"), + "The property nonsense does not exist in the taxonomy.", + ) + + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.PERSON, name="email", property_type=PropertyType.String + ) + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.PERSON, name="id", property_type=PropertyType.Numeric + ) + + for i in range(5): + id = f"person{i}" + with freeze_time(f"2024-01-01T{i}:00:00Z"): + _create_person( + distinct_ids=[id], + properties={"email": f"{id}@example.com", "id": i}, + team=self.team, + ) + with freeze_time(f"2024-01-02T00:00:00Z"): + _create_person( + distinct_ids=["person5"], + properties={"email": "person5@example.com", "id": 5}, + team=self.team, + ) + + self.assertEqual( + toolkit.retrieve_entity_property_values("person", "email"), + '"person5@example.com", "person4@example.com", "person3@example.com", "person2@example.com", "person1@example.com" and 1 more distinct value.', + ) + self.assertEqual( + toolkit.retrieve_entity_property_values("person", "id"), + "5, 4, 3, 2, 1 and 1 more distinct value.", + ) + + toolkit = TrendsAgentToolkit(self.team) + GroupTypeMapping.objects.create(team=self.team, group_type_index=0, group_type="proj") + GroupTypeMapping.objects.create(team=self.team, group_type_index=1, group_type="org") + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.GROUP, group_type_index=0, name="test", property_type="Numeric" + ) + PropertyDefinition.objects.create( + team=self.team, type=PropertyDefinition.Type.GROUP, group_type_index=1, name="test", property_type="String" + ) + + for i in range(7): + id = f"group{i}" + with freeze_time(f"2024-01-01T{i}:00:00Z"): + create_group( + group_type_index=0, + group_key=id, + properties={"test": i}, + team_id=self.team.pk, + ) + with freeze_time(f"2024-01-02T00:00:00Z"): + create_group( + group_type_index=1, + group_key="org", + properties={"test": "7"}, + team_id=self.team.pk, + ) + + self.assertEqual( + toolkit.retrieve_entity_property_values("proj", "test"), + "6, 5, 4, 3, 2 and 2 more distinct values.", + ) + self.assertEqual(toolkit.retrieve_entity_property_values("org", "test"), '"7"') + + def test_group_names(self): + GroupTypeMapping.objects.create(team=self.team, group_type_index=0, group_type="proj") + GroupTypeMapping.objects.create(team=self.team, group_type_index=1, group_type="org") + toolkit = TrendsAgentToolkit(self.team) + self.assertEqual(toolkit._entity_names, ["person", "session", "proj", "org"]) + + def test_empty_events(self): + toolkit = TrendsAgentToolkit(self.team) + self.assertEqual( + toolkit.retrieve_event_properties("test"), "Properties do not exist in the taxonomy for the event test." 
+ ) + + _create_person( + distinct_ids=["person1"], + team=self.team, + properties={}, + ) + _create_event( + event="event1", + distinct_id="person1", + properties={}, + team=self.team, + ) + + toolkit = TrendsAgentToolkit(self.team) + self.assertEqual( + toolkit.retrieve_event_properties("event1"), + "Properties do not exist in the taxonomy for the event event1.", + ) + + def test_retrieve_event_properties(self): + self._create_taxonomy() + toolkit = TrendsAgentToolkit(self.team) + prompt = toolkit.retrieve_event_properties("event1") + + self.assertIn( + "<name>id</name>
", + prompt, + ) + self.assertIn( + "$browser
", + prompt, + ) + self.assertIn( + "date
", + prompt, + ) + self.assertIn( + "bool
", + prompt, + ) + + def test_retrieve_event_property_values(self): + self._create_taxonomy() + toolkit = TrendsAgentToolkit(self.team) + + self.assertIn('"Chrome"', toolkit.retrieve_event_property_values("event1", "$browser")) + self.assertIn('"Firefox"', toolkit.retrieve_event_property_values("event1", "$browser")) + self.assertEqual(toolkit.retrieve_event_property_values("event1", "bool"), "true") + self.assertEqual( + toolkit.retrieve_event_property_values("event1", "id"), + "9, 8, 7, 6, 5 and 5 more distinct values.", + ) + self.assertEqual( + toolkit.retrieve_event_property_values("event1", "date"), f'"{datetime(2024, 1, 1).isoformat()}"' + ) diff --git a/ee/hogai/trends/toolkit.py b/ee/hogai/trends/toolkit.py new file mode 100644 index 00000000000..23421847e10 --- /dev/null +++ b/ee/hogai/trends/toolkit.py @@ -0,0 +1,512 @@ +import json +import xml.etree.ElementTree as ET +from functools import cached_property +from textwrap import dedent +from typing import Any, Literal, Optional, TypedDict, Union + +from pydantic import BaseModel, Field, RootModel + +from ee.hogai.hardcoded_definitions import hardcoded_prop_defs +from posthog.hogql.database.schema.channel_type import POSSIBLE_CHANNEL_TYPES +from posthog.hogql_queries.ai.actors_property_taxonomy_query_runner import ActorsPropertyTaxonomyQueryRunner +from posthog.hogql_queries.ai.event_taxonomy_query_runner import EventTaxonomyQueryRunner +from posthog.hogql_queries.query_runner import ExecutionMode +from posthog.models.group_type_mapping import GroupTypeMapping +from posthog.models.property_definition import PropertyDefinition, PropertyType +from posthog.models.team.team import Team +from posthog.schema import ( + ActorsPropertyTaxonomyQuery, + CachedActorsPropertyTaxonomyQueryResponse, + CachedEventTaxonomyQueryResponse, + EventTaxonomyQuery, + ExperimentalAITrendsQuery, +) + + +class ToolkitTool(TypedDict): + name: str + signature: str + description: str + + +class RetrieveEntityPropertiesValuesArgsModel(BaseModel): + entity: str + property_name: str + + +class RetrieveEntityPropertiesValuesModel(BaseModel): + name: Literal["retrieve_entity_property_values"] + arguments: RetrieveEntityPropertiesValuesArgsModel + + +class RetrieveEventPropertiesValuesArgsModel(BaseModel): + event_name: str + property_name: str + + +class RetrieveEventPropertiesValuesModel(BaseModel): + name: Literal["retrieve_event_property_values"] + arguments: RetrieveEventPropertiesValuesArgsModel + + +class SingleArgumentTrendsAgentToolModel(BaseModel): + name: Literal[ + "retrieve_entity_properties", + "retrieve_event_properties", + "final_answer", + "handle_incorrect_response", + ] + arguments: str + + +class TrendsAgentToolModel( + RootModel[ + Union[ + SingleArgumentTrendsAgentToolModel, RetrieveEntityPropertiesValuesModel, RetrieveEventPropertiesValuesModel + ] + ] +): + root: Union[ + SingleArgumentTrendsAgentToolModel, RetrieveEntityPropertiesValuesModel, RetrieveEventPropertiesValuesModel + ] = Field(..., discriminator="name") + + +class TrendsAgentToolkit: + _team: Team + + def __init__(self, team: Team): + self._team = team + + @property + def groups(self): + return GroupTypeMapping.objects.filter(team=self._team).order_by("group_type_index") + + @cached_property + def _entity_names(self) -> list[str]: + """ + The schemas use `group_type_index` for groups complicating things for the agent. Instead, we use groups' names, + so the generation step will handle their indexes. 
Tools would need to support multiple arguments, or we would need + to create various tools for different group types. Since we don't use function calling here, we want to limit the + number of tools because non-function calling models can't handle many tools. + """ + entities = [ + "person", + "session", + *[group.group_type for group in self.groups], + ] + return entities + + @cached_property + def tools(self) -> list[ToolkitTool]: + """ + Our ReAct agent doesn't use function calling. Instead, it uses tools in natural language to decide next steps. The agent expects the following format: + + ``` + retrieve_entity_properties_tool(entity: "Literal['person', 'session', 'organization', 'instance', 'project']") - description. + ``` + + Events and other entities are intentionally separated for properties retrieval. Potentially, there can be different functions for each entity type. + """ + + stringified_entities = ", ".join([f"'{entity}'" for entity in self._entity_names]) + + tools: list[ToolkitTool] = [ + { + "name": tool["name"], + "signature": tool["signature"], + "description": dedent(tool["description"]), + } + for tool in [ + { + "name": "retrieve_event_properties", + "signature": "(event_name: str)", + "description": """ + Use this tool to retrieve property names of an event that the user has in their taxonomy. You will receive a list of properties, their value types and example values or a message that properties have not been found. + + - **Try other events** if the tool doesn't return any properties. + - **Prioritize properties that are directly related to the context or objective of the user's query.** + - **Avoid using ambiguous properties** unless their relevance is explicitly confirmed. + + Args: + event_name: The name of the event that you want to retrieve properties for. + """, + }, + { + "name": "retrieve_event_property_values", + "signature": "(event_name: str, property_name: str)", + "description": """ + Use this tool to retrieve property values for an event that the user has in their taxonomy. Adjust filters to these values. You will receive a list of property values or a message that property values have not been found. Some properties can have many values, so the output will be truncated. Use your judgement to find a proper value. + + Args: + event_name: The name of the event that you want to retrieve values for. + property_name: The name of the property that you want to retrieve values for. + """, + }, + { + "name": f"retrieve_entity_properties", + "signature": f"(entity: Literal[{stringified_entities}])", + "description": """ + Use this tool to retrieve property names for a property group (entity) that the user has in their taxonomy. You will receive a list of properties and their value types or a message that properties have not been found. + + - **Infer the property groups from the user's request.** + - **Try other entities** if the tool doesn't return any properties. + - **Prioritize properties that are directly related to the context or objective of the user's query.** + - **Avoid using ambiguous properties** unless their relevance is explicitly confirmed. + + Args: + entity: The type of the entity that you want to retrieve properties for. + """, + }, + { + "name": "retrieve_entity_property_values", + "signature": f"(entity: Literal[{stringified_entities}], property_name: str)", + "description": """ + Use this tool to retrieve property values for a property name that the user has in their taxonomy. Adjust filters to these values. 
You will receive a list of property values or a message that property values have not been found. Some properties can have many values, so the output will be truncated. Use your judgement to find a proper value. + + Args: + entity: The type of the entity that you want to retrieve properties for. + property_name: The name of the property that you want to retrieve values for. + """, + }, + { + "name": "final_answer", + "signature": "(final_response: str)", + "description": """ + Use this tool to provide the final answer to the user's question. + + Answer in the following format: + ``` + Events: + - event 1 + - math operation: total + - property filter 1: + - entity + - property name + - property type + - operator + - property value + - property filter 2... Repeat for each property filter. + - event 2 + - math operation: average by `property name`. + - property filter 1: + - entity + - property name + - property type + - operator + - property value + - property filter 2... Repeat for each property filter. + - Repeat for each event. + + (if a formula is used) + Formula: + `A/B`, where `A` is the first event and `B` is the second event. + + (if a breakdown is used) + Breakdown by: + - breakdown 1: + - entity + - property name + - Repeat for each breakdown. + ``` + + Args: + final_response: List all events, actions, and properties that you want to use to answer the question. + """, + }, + ] + ] + + return tools + + def render_text_description(self) -> str: + """ + Render the tool name and description in plain text. + + Returns: + The rendered text. + + Output will be in the format of: + + .. code-block:: markdown + + search: This tool is used for search + calculator: This tool is used for math + """ + descriptions = [] + for tool in self.tools: + description = f"{tool['name']}{tool['signature']} - {tool['description']}" + descriptions.append(description) + return "\n".join(descriptions) + + def _generate_properties_xml(self, children: list[tuple[str, str | None]]): + root = ET.Element("properties") + property_types = {property_type for _, property_type in children if property_type is not None} + property_type_to_tag = {property_type: ET.SubElement(root, property_type) for property_type in property_types} + + for name, property_type in children: + # Do not include properties that are ambiguous. + if property_type is None: + continue + + type_tag = property_type_to_tag[property_type] + ET.SubElement(type_tag, "name").text = name + # Add a line break between names. Doubtful that it does anything. + ET.SubElement(type_tag, "br") + + return ET.tostring(root, encoding="unicode") + + def retrieve_entity_properties(self, entity: str) -> str: + """ + Retrieve properties for an entity like person, session, or one of the groups. + """ + if entity not in ("person", "session", *[group.group_type for group in self.groups]): + return f"Entity {entity} does not exist in the taxonomy." + + if entity == "person": + qs = PropertyDefinition.objects.filter(team=self._team, type=PropertyDefinition.Type.PERSON).values_list( + "name", "property_type" + ) + props = list(qs) + elif entity == "session": + # Session properties are not in the DB. + props = [ + (prop_name, prop["type"]) + for prop_name, prop in hardcoded_prop_defs["session_properties"].items() + if prop.get("type") is not None + ] + else: + group_type_index = next( + (group.group_type_index for group in self.groups if group.group_type == entity), None + ) + if group_type_index is None: + return f"Group {entity} does not exist in the taxonomy."
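# The (name, property_type) pairs gathered below are rendered by _generate_properties_xml above.
# An illustrative sketch (assuming a lone String property named "test", as in test_toolkit.py):
#   self._generate_properties_xml([("test", "String")])
#   -> '<properties><String><name>test</name><br /></String></properties>'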
+ qs = PropertyDefinition.objects.filter( + team=self._team, type=PropertyDefinition.Type.GROUP, group_type_index=group_type_index + ).values_list("name", "property_type") + props = list(qs) + + return self._generate_properties_xml(props) + + def retrieve_event_properties(self, event_name: str) -> str: + """ + Retrieve properties for an event. + """ + runner = EventTaxonomyQueryRunner(EventTaxonomyQuery(event=event_name), self._team) + response = runner.run(ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE) + + if not isinstance(response, CachedEventTaxonomyQueryResponse): + return "Properties have not been found." + + if not response.results: + return f"Properties do not exist in the taxonomy for the event {event_name}." + + # Intersect properties with their types. + qs = PropertyDefinition.objects.filter( + team=self._team, type=PropertyDefinition.Type.EVENT, name__in=[item.property for item in response.results] + ) + property_to_type = {property_definition.name: property_definition.property_type for property_definition in qs} + + return self._generate_properties_xml( + [ + (item.property, property_to_type.get(item.property)) + for item in response.results + # Exclude properties that exist in the taxonomy, but don't have a type. + if item.property in property_to_type + ] + ) + + def _format_property_values( + self, sample_values: list, sample_count: Optional[int] = 0, format_as_string: bool = False + ) -> str: + if len(sample_values) == 0 or sample_count == 0: + return "The property does not have any values in the taxonomy." + + # Add quotes to the String type, so the LLM can easily infer a type. + # Strings like "true" or "10" are interpreted as booleans or numbers without quotes, so the schema generation fails. + # Remove the floating point if the value is an integer. + formatted_sample_values: list[str] = [] + for value in sample_values: + if format_as_string: + formatted_sample_values.append(f'"{value}"') + elif isinstance(value, float) and value.is_integer(): + formatted_sample_values.append(str(int(value))) + else: + formatted_sample_values.append(str(value)) + prop_values = ", ".join(formatted_sample_values) + + # If there wasn't an exact match with the user's search, we provide a hint that the LLM can use an arbitrary value. + if sample_count is None: + return f"{prop_values} and many more distinct values." + elif sample_count > len(sample_values): + diff = sample_count - len(sample_values) + return f"{prop_values} and {diff} more distinct value{'' if diff == 1 else 's'}." + + return prop_values + + def retrieve_event_property_values(self, event_name: str, property_name: str) -> str: + try: + property_definition = PropertyDefinition.objects.get( + team=self._team, name=property_name, type=PropertyDefinition.Type.EVENT + ) + except PropertyDefinition.DoesNotExist: + return f"The property {property_name} does not exist in the taxonomy." + + runner = EventTaxonomyQueryRunner(EventTaxonomyQuery(event=event_name), self._team) + response = runner.run(ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE) + + if not isinstance(response, CachedEventTaxonomyQueryResponse): + return f"The event {event_name} does not exist in the taxonomy." + + if not response.results: + return f"Property values for {property_name} do not exist in the taxonomy for the event {event_name}." + + prop = next((item for item in response.results if item.property == property_name), None) + if not prop: + return f"The property {property_name} does not exist in the taxonomy for the event {event_name}."
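# _format_property_values (defined above) quotes String/Datetime samples so the LLM can infer
# the type, and appends a truncation hint when sample_count exceeds the number of samples.
# Illustrative sketches matching the expectations in test_toolkit.py:
#   self._format_property_values([9, 8, 7, 6, 5], 10)  -> "9, 8, 7, 6, 5 and 5 more distinct values."
#   self._format_property_values(["Chrome"], 1, format_as_string=True)  -> '"Chrome"'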
+ + return self._format_property_values( + prop.sample_values, + prop.sample_count, + format_as_string=property_definition.property_type in (PropertyType.String, PropertyType.Datetime), + ) + + def _retrieve_session_properties(self, property_name: str) -> str: + """ + Session property example values are hardcoded. + """ + if property_name not in hardcoded_prop_defs["session_properties"]: + return f"The property {property_name} does not exist in the taxonomy." + + if property_name == "$channel_type": + sample_values = POSSIBLE_CHANNEL_TYPES.copy() + sample_count = len(sample_values) + is_str = True + elif ( + property_name in hardcoded_prop_defs["session_properties"] + and "examples" in hardcoded_prop_defs["session_properties"][property_name] + ): + sample_values = hardcoded_prop_defs["session_properties"][property_name]["examples"] + sample_count = None + is_str = hardcoded_prop_defs["session_properties"][property_name]["type"] == PropertyType.String + else: + return f"Property values for {property_name} do not exist in the taxonomy for the session entity." + + return self._format_property_values(sample_values, sample_count, format_as_string=is_str) + + def retrieve_entity_property_values(self, entity: str, property_name: str) -> str: + if entity not in self._entity_names: + return f"The entity {entity} does not exist in the taxonomy. You must use one of the following: {', '.join(self._entity_names)}." + + if entity == "session": + return self._retrieve_session_properties(property_name) + + if entity == "person": + query = ActorsPropertyTaxonomyQuery(property=property_name) + else: + group_index = next((group.group_type_index for group in self.groups if group.group_type == entity), None) + if group_index is None: + return f"The entity {entity} does not exist in the taxonomy." + query = ActorsPropertyTaxonomyQuery(group_type_index=group_index, property=property_name) + + try: + if query.group_type_index is not None: + prop_type = PropertyDefinition.Type.GROUP + group_type_index = query.group_type_index + else: + prop_type = PropertyDefinition.Type.PERSON + group_type_index = None + + property_definition = PropertyDefinition.objects.get( + team=self._team, + name=property_name, + type=prop_type, + group_type_index=group_type_index, + ) + except PropertyDefinition.DoesNotExist: + return f"The property {property_name} does not exist in the taxonomy for the entity {entity}." + + response = ActorsPropertyTaxonomyQueryRunner(query, self._team).run( + ExecutionMode.RECENT_CACHE_CALCULATE_BLOCKING_IF_STALE + ) + + if not isinstance(response, CachedActorsPropertyTaxonomyQueryResponse): + return f"The entity {entity} does not exist in the taxonomy." + + if not response.results: + return f"Property values for {property_name} do not exist in the taxonomy for the entity {entity}." + + return self._format_property_values( + response.results.sample_values, + response.results.sample_count, + format_as_string=property_definition.property_type in (PropertyType.String, PropertyType.Datetime), + ) + + def handle_incorrect_response(self, response: str) -> str: + """ + No-op tool. Take a parsing error and return a response that the LLM can use to correct itself. + Used to control the number of retries.
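A minimal usage sketch (the error text below is hypothetical):

```
toolkit.handle_incorrect_response("1 validation error for ExperimentalAITrendsQuery")
# -> "1 validation error for ExperimentalAITrendsQuery"
```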
+ """ + return response + + +class GenerateTrendTool: + def _replace_value_in_dict(self, item: Any, original_schema: Any): + if isinstance(item, list): + return [self._replace_value_in_dict(i, original_schema) for i in item] + elif isinstance(item, dict): + if list(item.keys()) == ["$ref"]: + definitions = item["$ref"][2:].split("/") + res = original_schema.copy() + for definition in definitions: + res = res[definition] + return res + else: + return {key: self._replace_value_in_dict(i, original_schema) for key, i in item.items()} + else: + return item + + def _flatten_schema(self): + schema = ExperimentalAITrendsQuery.model_json_schema() + + # Patch `numeric` types + schema["$defs"]["MathGroupTypeIndex"]["type"] = "number" + property_filters = ( + "EventPropertyFilter", + "PersonPropertyFilter", + "SessionPropertyFilter", + "FeaturePropertyFilter", + ) + + # Clean up the property filters + for key in property_filters: + property_schema = schema["$defs"][key] + property_schema["properties"]["key"]["description"] = ( + f"Use one of the properties the user has provided in the plan." + ) + + for _ in range(100): + if "$ref" not in json.dumps(schema): + break + schema = self._replace_value_in_dict(schema.copy(), schema.copy()) + del schema["$defs"] + return schema + + @cached_property + def schema(self): + return { + "name": "output_insight_schema", + "description": "Outputs the JSON schema of a product analytics insight", + "parameters": { + "type": "object", + "properties": { + "reasoning_steps": { + "type": "array", + "items": {"type": "string"}, + "description": "The reasoning steps leading to the final conclusion that will be shown to the user. Use 'you' if you want to refer to the user.", + }, + "answer": self._flatten_schema(), + }, + "additionalProperties": False, + "required": ["reasoning_steps", "answer"], + }, + } diff --git a/ee/hogai/trends/utils.py b/ee/hogai/trends/utils.py new file mode 100644 index 00000000000..080f85f0256 --- /dev/null +++ b/ee/hogai/trends/utils.py @@ -0,0 +1,10 @@ +from typing import Optional + +from pydantic import BaseModel + +from posthog.schema import ExperimentalAITrendsQuery + + +class GenerateTrendOutputModel(BaseModel): + reasoning_steps: Optional[list[str]] = None + answer: Optional[ExperimentalAITrendsQuery] = None diff --git a/ee/hogai/trends_function.py b/ee/hogai/trends_function.py deleted file mode 100644 index 6f57b475065..00000000000 --- a/ee/hogai/trends_function.py +++ /dev/null @@ -1,71 +0,0 @@ -import json -from functools import cached_property -from typing import Any - -from ee.hogai.team_prompt import TeamPrompt -from posthog.models.property_definition import PropertyDefinition -from posthog.schema import ExperimentalAITrendsQuery - - -class TrendsFunction: - def _replace_value_in_dict(self, item: Any, original_schema: Any): - if isinstance(item, list): - return [self._replace_value_in_dict(i, original_schema) for i in item] - elif isinstance(item, dict): - if list(item.keys()) == ["$ref"]: - definitions = item["$ref"][2:].split("/") - res = original_schema.copy() - for definition in definitions: - res = res[definition] - return res - else: - return {key: self._replace_value_in_dict(i, original_schema) for key, i in item.items()} - else: - return item - - @cached_property - def _flat_schema(self): - schema = ExperimentalAITrendsQuery.model_json_schema() - - # Patch `numeric` types - schema["$defs"]["MathGroupTypeIndex"]["type"] = "number" - - # Clean up the property filters - for key, title in ( - ("EventPropertyFilter", 
PropertyDefinition.Type.EVENT.label), - ("PersonPropertyFilter", PropertyDefinition.Type.PERSON.label), - ("SessionPropertyFilter", PropertyDefinition.Type.SESSION.label), - ("FeaturePropertyFilter", "feature"), - ("CohortPropertyFilter", "cohort"), - ): - property_schema = schema["$defs"][key] - property_schema["properties"]["key"]["description"] = ( - f"Use one of the properties the user has provided in the <{TeamPrompt.get_properties_tag_name(title)}> tag." - ) - - for _ in range(100): - if "$ref" not in json.dumps(schema): - break - schema = self._replace_value_in_dict(schema.copy(), schema.copy()) - del schema["$defs"] - return schema - - def generate_function(self): - return { - "type": "function", - "function": { - "name": "output_insight_schema", - "description": "Outputs the JSON schema of a product analytics insight", - "parameters": { - "type": "object", - "properties": { - "reasoning_steps": { - "type": "array", - "items": {"type": "string"}, - "description": "The reasoning steps leading to the final conclusion that will be shown to the user. Use 'you' if you want to refer to the user.", - }, - "answer": self._flat_schema, - }, - }, - }, - } diff --git a/ee/hogai/utils.py b/ee/hogai/utils.py new file mode 100644 index 00000000000..65de9303b3f --- /dev/null +++ b/ee/hogai/utils.py @@ -0,0 +1,52 @@ +import operator +from abc import ABC, abstractmethod +from collections.abc import Sequence +from enum import StrEnum +from typing import Annotated, Optional, TypedDict, Union + +from langchain_core.agents import AgentAction +from langchain_core.runnables import RunnableConfig +from langgraph.graph import END, START +from pydantic import BaseModel, Field + +from posthog.models.team.team import Team +from posthog.schema import AssistantMessage, HumanMessage, RootAssistantMessage, VisualizationMessage + +AssistantMessageUnion = Union[AssistantMessage, HumanMessage, VisualizationMessage] + + +class Conversation(BaseModel): + messages: list[RootAssistantMessage] = Field(..., min_length=1, max_length=20) + session_id: str + + +class AssistantState(TypedDict): + messages: Annotated[Sequence[AssistantMessageUnion], operator.add] + intermediate_steps: Optional[list[tuple[AgentAction, Optional[str]]]] + plan: Optional[str] + tool_argument: Optional[str] + + +class AssistantNodeName(StrEnum): + START = START + END = END + CREATE_TRENDS_PLAN = "create_trends_plan" + CREATE_TRENDS_PLAN_TOOLS = "create_trends_plan_tools" + GENERATE_TRENDS = "generate_trends_schema" + GENERATE_TRENDS_TOOLS = "generate_trends_tools" + + +class AssistantNode(ABC): + name: AssistantNodeName + _team: Team + + def __init__(self, team: Team): + self._team = team + + @abstractmethod + def run(cls, state: AssistantState, config: RunnableConfig): + raise NotImplementedError + + +def remove_line_breaks(line: str) -> str: + return line.replace("\n", " ") diff --git a/ee/settings.py b/ee/settings.py index 9844074a956..64e3bfc5b8b 100644 --- a/ee/settings.py +++ b/ee/settings.py @@ -4,11 +4,10 @@ Django settings for PostHog Enterprise Edition. 
import os -from posthog.settings import AUTHENTICATION_BACKENDS, DEMO, SITE_URL, DEBUG +from posthog.settings import AUTHENTICATION_BACKENDS, DEBUG, DEMO, SITE_URL from posthog.settings.utils import get_from_env from posthog.utils import str_to_bool - # SSO AUTHENTICATION_BACKENDS = [ *AUTHENTICATION_BACKENDS, @@ -69,3 +68,8 @@ PARALLEL_ASSET_GENERATION_MAX_TIMEOUT_MINUTES = get_from_env( ) HOOK_HOG_FUNCTION_TEAMS = get_from_env("HOOK_HOG_FUNCTION_TEAMS", "", type_cast=str) + +# Assistant +LANGFUSE_PUBLIC_KEY = get_from_env("LANGFUSE_PUBLIC_KEY", "", type_cast=str) +LANGFUSE_SECRET_KEY = get_from_env("LANGFUSE_SECRET_KEY", "", type_cast=str) +LANGFUSE_HOST = get_from_env("LANGFUSE_HOST", "https://us.cloud.langfuse.com", type_cast=str) diff --git a/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--dark.png b/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--dark.png index ebceafa1ecc..e86df67429b 100644 Binary files a/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--dark.png and b/frontend/__snapshots__/scenes-app-insights--funnel-top-to-bottom-edit--dark.png differ diff --git a/frontend/src/queries/schema.json b/frontend/src/queries/schema.json index cdb3c4449d0..de989c53b1c 100644 --- a/frontend/src/queries/schema.json +++ b/frontend/src/queries/schema.json @@ -124,9 +124,6 @@ { "$ref": "#/definitions/SessionPropertyFilter" }, - { - "$ref": "#/definitions/CohortPropertyFilter" - }, { "$ref": "#/definitions/GroupPropertyFilter" }, @@ -264,7 +261,20 @@ }, "sample_values": { "items": { - "type": "string" + "anyOf": [ + { + "type": "string" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "integer" + } + ] }, "type": "array" } @@ -607,6 +617,24 @@ } ] }, + "AssistantMessage": { + "additionalProperties": false, + "properties": { + "content": { + "type": "string" + }, + "type": { + "const": "ai", + "type": "string" + } + }, + "required": ["type", "content"], + "type": "object" + }, + "AssistantMessageType": { + "enum": ["human", "ai", "ai/viz"], + "type": "string" + }, "AutocompleteCompletionItem": { "additionalProperties": false, "properties": { @@ -6720,6 +6748,20 @@ "required": ["results"], "type": "object" }, + "HumanMessage": { + "additionalProperties": false, + "properties": { + "content": { + "type": "string" + }, + "type": { + "const": "human", + "type": "string" + } + }, + "required": ["type", "content"], + "type": "object" + }, "InsightActorsQuery": { "additionalProperties": false, "properties": { @@ -10584,6 +10626,19 @@ "required": ["count"], "type": "object" }, + "RootAssistantMessage": { + "anyOf": [ + { + "$ref": "#/definitions/VisualizationMessage" + }, + { + "$ref": "#/definitions/AssistantMessage" + }, + { + "$ref": "#/definitions/HumanMessage" + } + ] + }, "SamplingRate": { "additionalProperties": false, "properties": { @@ -11713,6 +11768,36 @@ "required": ["results"], "type": "object" }, + "VisualizationMessage": { + "additionalProperties": false, + "properties": { + "answer": { + "$ref": "#/definitions/ExperimentalAITrendsQuery" + }, + "plan": { + "type": "string" + }, + "reasoning_steps": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ] + }, + "type": { + "const": "ai/viz", + "type": "string" + } + }, + "required": ["type"], + "type": "object" + }, "VizSpecificOptions": { "additionalProperties": false, "description": "Chart specific rendering options. Use ChartRenderingMetadata for non-serializable values, e.g. 
onClick handlers", diff --git a/frontend/src/queries/schema.ts b/frontend/src/queries/schema.ts index b288c4f49a3..3fee10aa88c 100644 --- a/frontend/src/queries/schema.ts +++ b/frontend/src/queries/schema.ts @@ -9,7 +9,6 @@ import { BreakdownType, ChartDisplayCategory, ChartDisplayType, - CohortPropertyFilter, CountPerActorMathType, DurationType, EventPropertyFilter, @@ -898,15 +897,7 @@ export interface TrendsQuery extends InsightsQueryBase { export type AIPropertyFilter = | EventPropertyFilter | PersonPropertyFilter - // | ElementPropertyFilter | SessionPropertyFilter - | CohortPropertyFilter - // | RecordingPropertyFilter - // | LogEntryPropertyFilter - // | HogQLPropertyFilter - // | EmptyPropertyFilter - // | DataWarehousePropertyFilter - // | DataWarehousePersonPropertyFilter | GroupPropertyFilter | FeaturePropertyFilter @@ -2077,7 +2068,9 @@ export type EventTaxonomyQueryResponse = AnalyticsQueryResponseBase export interface ActorsPropertyTaxonomyResponse { - sample_values: string[] + // Values can be floats and integers. The comment below is to preserve the `integer` type. + // eslint-disable-next-line @typescript-eslint/no-duplicate-type-constituents + sample_values: (string | number | boolean | integer)[] sample_count: integer } @@ -2090,3 +2083,28 @@ export interface ActorsPropertyTaxonomyQuery extends DataNode export type CachedActorsPropertyTaxonomyQueryResponse = CachedQueryResponse + +export enum AssistantMessageType { + Human = 'human', + Assistant = 'ai', + Visualization = 'ai/viz', +} + +export interface HumanMessage { + type: AssistantMessageType.Human + content: string +} + +export interface AssistantMessage { + type: AssistantMessageType.Assistant + content: string +} + +export interface VisualizationMessage { + type: AssistantMessageType.Visualization + plan?: string + reasoning_steps?: string[] | null + answer?: ExperimentalAITrendsQuery +} + +export type RootAssistantMessage = VisualizationMessage | AssistantMessage | HumanMessage diff --git a/frontend/src/scenes/max/Thread.tsx b/frontend/src/scenes/max/Thread.tsx index dabc0cae374..93e9f403e81 100644 --- a/frontend/src/scenes/max/Thread.tsx +++ b/frontend/src/scenes/max/Thread.tsx @@ -6,13 +6,21 @@ import { BreakdownSummary, PropertiesSummary, SeriesSummary } from 'lib/componen import { TopHeading } from 'lib/components/Cards/InsightCard/TopHeading' import { IconOpenInNew } from 'lib/lemon-ui/icons' import posthog from 'posthog-js' -import React, { useRef, useState } from 'react' +import React, { useMemo, useRef, useState } from 'react' import { urls } from 'scenes/urls' import { Query } from '~/queries/Query/Query' -import { InsightQueryNode, InsightVizNode, NodeKind } from '~/queries/schema' +import { + AssistantMessageType, + HumanMessage, + InsightVizNode, + NodeKind, + TrendsQuery, + VisualizationMessage, +} from '~/queries/schema' -import { maxLogic, ThreadMessage, TrendGenerationResult } from './maxLogic' +import { maxLogic, MessageStatus, ThreadMessage } from './maxLogic' +import { isHumanMessage, isVisualizationMessage } from './utils' export function Thread(): JSX.Element | null { const { thread, threadLoading } = useValues(maxLogic) @@ -20,11 +28,11 @@ export function Thread(): JSX.Element | null { return (
{thread.map((message, index) => { - if (message.role === 'user' || typeof message.content === 'string') { + if (isHumanMessage(message)) { return ( {message.content || No text} @@ -32,16 +40,21 @@ export function Thread(): JSX.Element | null { ) } - return ( - - ) + if (isVisualizationMessage(message)) { + return ( + + ) + } + + return null })} {threadLoading && ( - +
Let me think… @@ -52,52 +65,59 @@ export function Thread(): JSX.Element | null { ) } -const Message = React.forwardRef< - HTMLDivElement, - React.PropsWithChildren<{ role: 'user' | 'assistant'; className?: string }> ->(function Message({ role, children, className }, ref): JSX.Element { - if (role === 'user') { +const Message = React.forwardRef>( + function Message({ type, children, className }, ref): JSX.Element { + if (type === AssistantMessageType.Human) { + return ( +
+ {children} +
+ ) + } + return ( -
+
{children}
) } - - return ( -
- {children} -
- ) -}) +) function Answer({ message, + status, previousMessage, }: { - message: ThreadMessage & { content: TrendGenerationResult } + message: VisualizationMessage + status?: MessageStatus previousMessage: ThreadMessage }): JSX.Element { - const query: InsightVizNode = { - kind: NodeKind.InsightVizNode, - source: message.content?.answer as InsightQueryNode, - showHeader: true, - } + const query = useMemo(() => { + if (message.answer) { + return { + kind: NodeKind.InsightVizNode, + source: message.answer as TrendsQuery, + showHeader: true, + } + } + + return null + }, [message]) return ( <> - {message.content?.reasoning_steps && ( - + {message.reasoning_steps && ( +
    - {message.content.reasoning_steps.map((step, index) => ( + {message.reasoning_steps.map((step, index) => (
  • {step}
  • ))}
)} - {message.status === 'completed' && message.content?.answer && ( + {status === 'completed' && query && ( <> - +
@@ -118,7 +138,9 @@ function Answer({
- + {isHumanMessage(previousMessage) && ( + + )} )} @@ -129,8 +151,8 @@ function AnswerActions({ message, previousMessage, }: { - message: ThreadMessage & { content: TrendGenerationResult } - previousMessage: ThreadMessage + message: VisualizationMessage + previousMessage: HumanMessage }): JSX.Element { const [rating, setRating] = useState<'good' | 'bad' | null>(null) const [feedback, setFeedback] = useState('') @@ -144,7 +166,7 @@ function AnswerActions({ setRating(newRating) posthog.capture('chat rating', { question: previousMessage.content, - answer: message.content, + answer: JSON.stringify(message.answer), answer_rating: rating, }) if (newRating === 'bad') { @@ -158,7 +180,7 @@ function AnswerActions({ } posthog.capture('chat feedback', { question: previousMessage.content, - answer: message.content, + answer: JSON.stringify(message.answer), feedback, }) setFeedbackInputStatus('submitted') @@ -188,7 +210,7 @@ function AnswerActions({
{feedbackInputStatus !== 'hidden' && ( { if (el && !hasScrolledFeedbackInputIntoView.current) { // When the feedback input is first rendered, scroll it into view diff --git a/frontend/src/scenes/max/__mocks__/chatResponse.json b/frontend/src/scenes/max/__mocks__/chatResponse.json index 8be5242f5b6..5fed25c08bf 100644 --- a/frontend/src/scenes/max/__mocks__/chatResponse.json +++ b/frontend/src/scenes/max/__mocks__/chatResponse.json @@ -1,4 +1,6 @@ { + "type": "ai/viz", + "plan": "Test plan", "reasoning_steps": [ "The user's query is to identify the most popular pages.", "To determine the most popular pages, we should analyze the '$pageview' event as it tracks when a user loads or reloads a page.", diff --git a/frontend/src/scenes/max/maxLogic.ts b/frontend/src/scenes/max/maxLogic.ts index be0ca6e2213..69d53bf956b 100644 --- a/frontend/src/scenes/max/maxLogic.ts +++ b/frontend/src/scenes/max/maxLogic.ts @@ -3,7 +3,7 @@ import { actions, kea, key, listeners, path, props, reducers, selectors } from ' import { loaders } from 'kea-loaders' import api from 'lib/api' -import { ExperimentalAITrendsQuery, NodeKind, SuggestedQuestionsQuery } from '~/queries/schema' +import { AssistantMessageType, NodeKind, RootAssistantMessage, SuggestedQuestionsQuery } from '~/queries/schema' import type { maxLogicType } from './maxLogicType' @@ -11,15 +11,10 @@ export interface MaxLogicProps { sessionId: string } -export interface TrendGenerationResult { - reasoning_steps?: string[] - answer?: ExperimentalAITrendsQuery -} +export type MessageStatus = 'loading' | 'completed' | 'error' -export interface ThreadMessage { - role: 'user' | 'assistant' - content?: string | TrendGenerationResult - status?: 'loading' | 'completed' | 'error' +export type ThreadMessage = RootAssistantMessage & { + status?: MessageStatus } export const maxLogic = kea([ @@ -114,16 +109,13 @@ export const maxLogic = kea([ actions.setVisibleSuggestions(allSuggestionsWithoutCurrentlyVisible.slice(0, 3)) }, askMax: async ({ prompt }) => { - actions.addMessage({ role: 'user', content: prompt }) + actions.addMessage({ type: AssistantMessageType.Human, content: prompt }) const newIndex = values.thread.length try { const response = await api.chat({ session_id: props.sessionId, - messages: values.thread.map(({ role, content }) => ({ - role, - content: typeof content === 'string' ? content : JSON.stringify(content), - })), + messages: values.thread.map(({ status, ...message }) => message), }) const reader = response.body?.getReader() const decoder = new TextDecoder() @@ -145,12 +137,11 @@ export const maxLogic = kea([ firstChunk = false if (parsedResponse) { - actions.addMessage({ role: 'assistant', content: parsedResponse, status: 'loading' }) + actions.addMessage({ ...parsedResponse, status: 'loading' }) } } else if (parsedResponse) { actions.replaceMessage(newIndex, { - role: 'assistant', - content: parsedResponse, + ...parsedResponse, status: 'loading', }) } @@ -172,10 +163,10 @@ export const maxLogic = kea([ * Parses the generation result from the API. Some generation chunks might be sent in batches. 
* @param response */ -function parseResponse(response: string, recursive = true): TrendGenerationResult | null { +function parseResponse(response: string, recursive = true): RootAssistantMessage | null { try { const parsed = JSON.parse(response) - return parsed as TrendGenerationResult + return parsed as RootAssistantMessage } catch { if (!recursive) { return null diff --git a/frontend/src/scenes/max/utils.ts b/frontend/src/scenes/max/utils.ts new file mode 100644 index 00000000000..263eb2f521b --- /dev/null +++ b/frontend/src/scenes/max/utils.ts @@ -0,0 +1,11 @@ +import { AssistantMessageType, HumanMessage, RootAssistantMessage, VisualizationMessage } from '~/queries/schema' + +export function isVisualizationMessage( + message: RootAssistantMessage | undefined | null +): message is VisualizationMessage { + return message?.type === AssistantMessageType.Visualization +} + +export function isHumanMessage(message: RootAssistantMessage | undefined | null): message is HumanMessage { + return message?.type === AssistantMessageType.Human +} diff --git a/mypy-baseline.txt b/mypy-baseline.txt index d1d7bf60c91..6f5de8a4ccf 100644 --- a/mypy-baseline.txt +++ b/mypy-baseline.txt @@ -73,10 +73,10 @@ posthog/models/subscription.py:0: error: Argument 2 to "SubscriptionResourceInfo posthog/models/exported_asset.py:0: error: Value of type variable "_StrOrPromiseT" of "slugify" cannot be "str | None" [type-var] posthog/models/action/action.py:0: error: Need type annotation for "events" [var-annotated] posthog/models/action/action.py:0: error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized" [arg-type] -posthog/hogql/ast.py:0: error: Incompatible return value type (got "bool | None", expected "bool") [return-value] posthog/hogql/database/schema/numbers.py:0: error: Incompatible types in assignment (expression has type "dict[str, IntegerDatabaseField]", variable has type "dict[str, FieldOrTable]") [assignment] posthog/hogql/database/schema/numbers.py:0: note: "Dict" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance posthog/hogql/database/schema/numbers.py:0: note: Consider using "Mapping" instead, which is covariant in the value type +posthog/hogql/ast.py:0: error: Incompatible return value type (got "bool | None", expected "bool") [return-value] ee/models/license.py:0: error: Incompatible return value type (got "_T", expected "License | None") [return-value] ee/models/license.py:0: error: Cannot use a covariant type variable as a parameter [misc] ee/models/license.py:0: error: "_T" has no attribute "plan" [attr-defined] @@ -126,10 +126,10 @@ posthog/models/user.py:0: error: "User" has no attribute "social_auth" [attr-de posthog/models/plugin.py:0: error: Argument 1 to "extract_plugin_code" has incompatible type "bytes | memoryview | None"; expected "bytes" [arg-type] posthog/models/plugin.py:0: error: Name "timezone.datetime" is not defined [name-defined] posthog/models/plugin.py:0: error: Name "timezone.datetime" is not defined [name-defined] -posthog/models/organization_invite.py:0: error: Argument "level" to "join" of "User" has incompatible type "int"; expected "Level" [arg-type] posthog/models/person/person.py:0: error: "_T" has no attribute "_add_distinct_ids" [attr-defined] posthog/models/person/person.py:0: error: Argument "version" to "create_person" has incompatible type "int | None"; expected "int" [arg-type] posthog/models/person/person.py:0: error: Incompatible types in assignment (expression has type "list[Never]", variable has type 
"QuerySet[PersonDistinctId, str]") [assignment] +posthog/models/organization_invite.py:0: error: Argument "level" to "join" of "User" has incompatible type "int"; expected "Level" [arg-type] posthog/hogql_queries/legacy_compatibility/filter_to_query.py:0: error: Dict entry 4 has incompatible type "str": "Literal[0, 1, 2, 3, 4] | None"; expected "str": "str | None" [dict-item] posthog/hogql_queries/legacy_compatibility/filter_to_query.py:0: error: Item "None" of "Any | None" has no attribute "__iter__" (not iterable) [union-attr] posthog/hogql_queries/legacy_compatibility/filter_to_query.py:0: error: Argument 1 to "float" has incompatible type "Any | None"; expected "str | Buffer | SupportsFloat | SupportsIndex" [arg-type] @@ -227,7 +227,6 @@ posthog/hogql/transforms/in_cohort.py:0: error: Item "Expr" of "Expr | Any" has posthog/hogql/transforms/in_cohort.py:0: error: List item 0 has incompatible type "SelectQueryType | None"; expected "SelectQueryType" [list-item] posthog/hogql/transforms/in_cohort.py:0: error: List item 0 has incompatible type "SelectQueryType | None"; expected "SelectQueryType" [list-item] posthog/hogql/database/database.py:0: error: Argument "week_start_day" to "Database" has incompatible type "int | Any | None"; expected "WeekStartDay | None" [arg-type] -posthog/hogql/database/database.py:0: error: "FieldOrTable" has no attribute "fields" [attr-defined] posthog/warehouse/models/datawarehouse_saved_query.py:0: error: Argument 1 to "create_hogql_database" has incompatible type "int | None"; expected "int" [arg-type] posthog/models/feature_flag/flag_matching.py:0: error: Statement is unreachable [unreachable] posthog/models/feature_flag/flag_matching.py:0: error: Value expression in dictionary comprehension has incompatible type "int"; expected type "Literal[0, 1, 2, 3, 4]" [misc] @@ -256,7 +255,13 @@ posthog/hogql/printer.py:0: error: Argument 1 to "_print_identifier" of "_Printe posthog/user_permissions.py:0: error: Incompatible return value type (got "int", expected "Level | None") [return-value] posthog/user_permissions.py:0: error: Incompatible return value type (got "int", expected "Level | None") [return-value] posthog/user_permissions.py:0: error: Incompatible return value type (got "int", expected "RestrictionLevel") [return-value] +posthog/tasks/update_survey_iteration.py:0: error: Incompatible types in assignment (expression has type "ForeignKey[Any, _ST] | Any", variable has type "FeatureFlag | Combinable | None") [assignment] +posthog/tasks/update_survey_iteration.py:0: error: Item "None" of "FeatureFlag | None" has no attribute "filters" [union-attr] +posthog/tasks/update_survey_iteration.py:0: error: Item "None" of "FeatureFlag | None" has no attribute "filters" [union-attr] +posthog/tasks/update_survey_iteration.py:0: error: Item "None" of "FeatureFlag | None" has no attribute "save" [union-attr] posthog/permissions.py:0: error: Argument 2 to "feature_enabled" has incompatible type "str | None"; expected "str" [arg-type] +posthog/models/event/util.py:0: error: Incompatible types in assignment (expression has type "str", variable has type "datetime") [assignment] +posthog/models/event/util.py:0: error: Module has no attribute "utc" [attr-defined] posthog/event_usage.py:0: error: Argument 1 to "capture" has incompatible type "str | None"; expected "str" [arg-type] posthog/event_usage.py:0: error: Argument 1 to "capture" has incompatible type "str | None"; expected "str" [arg-type] posthog/event_usage.py:0: error: Argument 1 to "alias" has incompatible type 
"str | None"; expected "str" [arg-type] @@ -269,12 +274,6 @@ posthog/event_usage.py:0: error: Argument 1 to "capture" has incompatible type " posthog/event_usage.py:0: error: Argument 1 to "capture" has incompatible type "str | None"; expected "str" [arg-type] posthog/event_usage.py:0: error: Argument 1 to "capture" has incompatible type "str | None"; expected "str" [arg-type] posthog/event_usage.py:0: error: Argument 1 to "capture" has incompatible type "str | None"; expected "str" [arg-type] -posthog/tasks/update_survey_iteration.py:0: error: Incompatible types in assignment (expression has type "ForeignKey[Any, _ST] | Any", variable has type "FeatureFlag | Combinable | None") [assignment] -posthog/tasks/update_survey_iteration.py:0: error: Item "None" of "FeatureFlag | None" has no attribute "filters" [union-attr] -posthog/tasks/update_survey_iteration.py:0: error: Item "None" of "FeatureFlag | None" has no attribute "filters" [union-attr] -posthog/tasks/update_survey_iteration.py:0: error: Item "None" of "FeatureFlag | None" has no attribute "save" [union-attr] -posthog/models/event/util.py:0: error: Incompatible types in assignment (expression has type "str", variable has type "datetime") [assignment] -posthog/models/event/util.py:0: error: Module has no attribute "utc" [attr-defined] posthog/demo/matrix/taxonomy_inference.py:0: error: Name "timezone.datetime" is not defined [name-defined] posthog/demo/matrix/matrix.py:0: error: Name "timezone.datetime" is not defined [name-defined] posthog/demo/matrix/matrix.py:0: error: Name "timezone.datetime" is not defined [name-defined] @@ -381,6 +380,8 @@ posthog/api/feature_flag.py:0: error: Incompatible type for lookup 'pk': (got "s posthog/api/feature_flag.py:0: error: Argument 2 to "get_all_feature_flags" has incompatible type "str | None"; expected "str" [arg-type] posthog/hogql_queries/web_analytics/web_analytics_query_runner.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined] posthog/hogql_queries/web_analytics/web_analytics_query_runner.py:0: error: Argument 1 to "append" of "list" has incompatible type "EventPropertyFilter"; expected "Expr" [arg-type] +posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined] +posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Statement is unreachable [unreachable] posthog/hogql_queries/insights/stickiness_query_runner.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined] posthog/hogql_queries/insights/retention_query_runner.py:0: error: Item "None" of "JoinExpr | None" has no attribute "sample" [union-attr] posthog/hogql_queries/insights/retention_query_runner.py:0: error: Unsupported operand types for - ("int" and "None") [operator] @@ -395,8 +396,6 @@ posthog/hogql_queries/insights/lifecycle_query_runner.py:0: note: "List" is inva posthog/hogql_queries/insights/lifecycle_query_runner.py:0: note: Consider using "Sequence" instead, which is covariant posthog/hogql_queries/insights/lifecycle_query_runner.py:0: error: Item "SelectUnionQuery" of "SelectQuery | SelectUnionQuery" has no attribute "select_from" [union-attr] posthog/hogql_queries/insights/lifecycle_query_runner.py:0: error: Item "None" of "JoinExpr | Any | None" has no attribute "sample" [union-attr] -posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Module "django.utils.timezone" does not 
explicitly export attribute "datetime" [attr-defined] -posthog/hogql_queries/insights/trends/trends_query_runner.py:0: error: Statement is unreachable [unreachable] posthog/hogql_queries/insights/funnels/funnels_query_runner.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined] posthog/api/survey.py:0: error: Incompatible types in assignment (expression has type "Any | Sequence[Any] | None", variable has type "Survey | None") [assignment] posthog/api/survey.py:0: error: Item "list[_ErrorFullDetails]" of "_FullDetailDict | list[_ErrorFullDetails] | dict[str, _ErrorFullDetails]" has no attribute "get" [union-attr] @@ -552,23 +551,6 @@ posthog/heatmaps/test/test_heatmaps_api.py:0: error: "HttpResponse" has no attri posthog/heatmaps/test/test_heatmaps_api.py:0: error: "HttpResponse" has no attribute "json" [attr-defined] posthog/api/uploaded_media.py:0: error: Argument 1 to "read_bytes" has incompatible type "str | None"; expected "str" [arg-type] posthog/api/uploaded_media.py:0: error: Argument 1 to "read_bytes" has incompatible type "str | None"; expected "str" [arg-type] -posthog/api/signup.py:0: error: Argument 1 to "create_user" of "UserManager" has incompatible type "str | None"; expected "str" [arg-type] -posthog/api/organization_member.py:0: error: "User" has no attribute "totpdevice_set" [attr-defined] -posthog/api/organization_member.py:0: error: "User" has no attribute "social_auth" [attr-defined] -posthog/api/organization_member.py:0: error: Signature of "update" incompatible with supertype "ModelSerializer" [override] -posthog/api/organization_member.py:0: note: Superclass: -posthog/api/organization_member.py:0: note: def update(self, instance: Any, validated_data: Any) -> Any -posthog/api/organization_member.py:0: note: Subclass: -posthog/api/organization_member.py:0: note: def update(self, updated_membership: Any, validated_data: Any, **kwargs: Any) -> Any -posthog/api/organization_member.py:0: error: Signature of "update" incompatible with supertype "BaseSerializer" [override] -posthog/api/organization_member.py:0: note: Superclass: -posthog/api/organization_member.py:0: note: def update(self, instance: Any, validated_data: Any) -> Any -posthog/api/organization_member.py:0: note: Subclass: -posthog/api/organization_member.py:0: note: def update(self, updated_membership: Any, validated_data: Any, **kwargs: Any) -> Any -posthog/api/organization_feature_flag.py:0: error: Invalid index type "str | None" for "dict[str, int]"; expected type "str" [index] -posthog/api/organization_feature_flag.py:0: error: Invalid index type "str | None" for "dict[str, int]"; expected type "str" [index] -posthog/api/organization_feature_flag.py:0: error: Invalid index type "str | None" for "dict[str, int]"; expected type "str" [index] -posthog/api/notebook.py:0: error: Incompatible types in assignment (expression has type "int", variable has type "str | None") [assignment] posthog/api/test/test_utils.py:0: error: Incompatible types in assignment (expression has type "dict[str, str]", variable has type "QueryDict") [assignment] posthog/api/test/test_survey.py:0: error: Item "None" of "FeatureFlag | None" has no attribute "active" [union-attr] posthog/api/test/test_stickiness.py:0: error: Module has no attribute "utc" [attr-defined] @@ -590,6 +572,23 @@ posthog/api/test/test_personal_api_keys.py:0: error: Item "None" of "str | None" posthog/api/test/test_personal_api_keys.py:0: error: Item "None" of "str | None" has no attribute "startswith" 
[union-attr] posthog/api/test/test_person.py:0: error: Argument "data" to "get" of "APIClient" has incompatible type "dict[str, object]"; expected "Mapping[str, str | bytes | int | Iterable[str | bytes | int]] | Iterable[tuple[str, str | bytes | int | Iterable[str | bytes | int]]] | None" [arg-type] posthog/api/test/test_organization_domain.py:0: error: Item "None" of "datetime | None" has no attribute "strftime" [union-attr] +posthog/api/signup.py:0: error: Argument 1 to "create_user" of "UserManager" has incompatible type "str | None"; expected "str" [arg-type] +posthog/api/organization_member.py:0: error: "User" has no attribute "totpdevice_set" [attr-defined] +posthog/api/organization_member.py:0: error: "User" has no attribute "social_auth" [attr-defined] +posthog/api/organization_member.py:0: error: Signature of "update" incompatible with supertype "ModelSerializer" [override] +posthog/api/organization_member.py:0: note: Superclass: +posthog/api/organization_member.py:0: note: def update(self, instance: Any, validated_data: Any) -> Any +posthog/api/organization_member.py:0: note: Subclass: +posthog/api/organization_member.py:0: note: def update(self, updated_membership: Any, validated_data: Any, **kwargs: Any) -> Any +posthog/api/organization_member.py:0: error: Signature of "update" incompatible with supertype "BaseSerializer" [override] +posthog/api/organization_member.py:0: note: Superclass: +posthog/api/organization_member.py:0: note: def update(self, instance: Any, validated_data: Any) -> Any +posthog/api/organization_member.py:0: note: Subclass: +posthog/api/organization_member.py:0: note: def update(self, updated_membership: Any, validated_data: Any, **kwargs: Any) -> Any +posthog/api/organization_feature_flag.py:0: error: Invalid index type "str | None" for "dict[str, int]"; expected type "str" [index] +posthog/api/organization_feature_flag.py:0: error: Invalid index type "str | None" for "dict[str, int]"; expected type "str" [index] +posthog/api/organization_feature_flag.py:0: error: Invalid index type "str | None" for "dict[str, int]"; expected type "str" [index] +posthog/api/notebook.py:0: error: Incompatible types in assignment (expression has type "int", variable has type "str | None") [assignment] posthog/warehouse/external_data_source/source.py:0: error: Incompatible types in assignment (expression has type "int", target has type "str") [assignment] posthog/warehouse/external_data_source/source.py:0: error: Incompatible types in assignment (expression has type "int", target has type "str") [assignment] posthog/warehouse/external_data_source/source.py:0: error: Incompatible types in assignment (expression has type "dict[str, Collection[str]]", variable has type "StripeSourcePayload") [assignment] @@ -690,6 +689,21 @@ posthog/hogql/test/test_parse_string_cpp.py:0: error: Unsupported dynamic base c posthog/hogql/database/test/test_view.py:0: error: Argument "dialect" to "print_ast" has incompatible type "str"; expected "Literal['hogql', 'clickhouse']" [arg-type] posthog/hogql/database/test/test_s3_table.py:0: error: Argument "dialect" to "print_ast" has incompatible type "str"; expected "Literal['hogql', 'clickhouse']" [arg-type] posthog/async_migrations/test/test_runner.py:0: error: Item "None" of "datetime | None" has no attribute "day" [union-attr] +posthog/api/test/test_insight.py:0: error: Argument "data" to "get" of "APIClient" has incompatible type "dict[str, object]"; expected "Mapping[str, str | bytes | int | Iterable[str | bytes | int]] | Iterable[tuple[str, str 
| bytes | int | Iterable[str | bytes | int]]] | None" [arg-type] +posthog/api/test/test_insight.py:0: error: Argument "data" to "get" of "APIClient" has incompatible type "dict[str, object]"; expected "Mapping[str, str | bytes | int | Iterable[str | bytes | int]] | Iterable[tuple[str, str | bytes | int | Iterable[str | bytes | int]]] | None" [arg-type] +posthog/api/test/test_insight.py:0: error: Argument "data" to "get" of "APIClient" has incompatible type "dict[str, object]"; expected "Mapping[str, str | bytes | int | Iterable[str | bytes | int]] | Iterable[tuple[str, str | bytes | int | Iterable[str | bytes | int]]] | None" [arg-type] +posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "tiles" [union-attr] +posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "name" [union-attr] +posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "description" [union-attr] +posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "filters" [union-attr] +posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "tiles" [union-attr] +posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "name" [union-attr] +posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "description" [union-attr] +posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "filters" [union-attr] +posthog/api/test/dashboards/test_dashboard.py:0: error: Value of type variable "_S" of "assertAlmostEqual" of "TestCase" cannot be "datetime | None" [type-var] +posthog/api/test/dashboards/test_dashboard.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "timedelta" [attr-defined] +posthog/api/test/dashboards/test_dashboard.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "timedelta" [attr-defined] +posthog/api/test/dashboards/test_dashboard.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "timedelta" [attr-defined] posthog/api/search.py:0: error: Argument "klass" to "class_queryset" has incompatible type "object"; expected "type[Model]" [arg-type] posthog/api/search.py:0: error: Argument "search_fields" to "class_queryset" has incompatible type "object"; expected "dict[str, str]" [arg-type] posthog/api/search.py:0: error: Argument "extra_fields" to "class_queryset" has incompatible type "object"; expected "dict[Any, Any] | None" [arg-type] @@ -713,21 +727,6 @@ posthog/api/property_definition.py:0: error: Incompatible types in assignment (e posthog/api/property_definition.py:0: error: Item "AnonymousUser" of "User | AnonymousUser" has no attribute "organization" [union-attr] posthog/api/property_definition.py:0: error: Item "None" of "Organization | Any | None" has no attribute "is_feature_available" [union-attr] posthog/api/event.py:0: error: Argument 1 to has incompatible type "*tuple[str, ...]"; expected "type[BaseRenderer]" [arg-type] -posthog/api/test/test_insight.py:0: error: Argument "data" to "get" of "APIClient" has incompatible type "dict[str, object]"; expected "Mapping[str, str | bytes | int | Iterable[str | bytes | int]] | Iterable[tuple[str, str | bytes | int | Iterable[str | bytes | int]]] | None" [arg-type] -posthog/api/test/test_insight.py:0: error: Argument "data" to "get" of 
"APIClient" has incompatible type "dict[str, object]"; expected "Mapping[str, str | bytes | int | Iterable[str | bytes | int]] | Iterable[tuple[str, str | bytes | int | Iterable[str | bytes | int]]] | None" [arg-type] -posthog/api/test/test_insight.py:0: error: Argument "data" to "get" of "APIClient" has incompatible type "dict[str, object]"; expected "Mapping[str, str | bytes | int | Iterable[str | bytes | int]] | Iterable[tuple[str, str | bytes | int | Iterable[str | bytes | int]]] | None" [arg-type] -posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "tiles" [union-attr] -posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "name" [union-attr] -posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "description" [union-attr] -posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "filters" [union-attr] -posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "tiles" [union-attr] -posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "name" [union-attr] -posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "description" [union-attr] -posthog/api/test/test_feature_flag.py:0: error: Item "None" of "Dashboard | None" has no attribute "filters" [union-attr] -posthog/api/test/dashboards/test_dashboard.py:0: error: Value of type variable "_S" of "assertAlmostEqual" of "TestCase" cannot be "datetime | None" [type-var] -posthog/api/test/dashboards/test_dashboard.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "timedelta" [attr-defined] -posthog/api/test/dashboards/test_dashboard.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "timedelta" [attr-defined] -posthog/api/test/dashboards/test_dashboard.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "timedelta" [attr-defined] posthog/admin/inlines/plugin_attachment_inline.py:0: error: Signature of "has_add_permission" incompatible with supertype "BaseModelAdmin" [override] posthog/admin/inlines/plugin_attachment_inline.py:0: note: Superclass: posthog/admin/inlines/plugin_attachment_inline.py:0: note: def has_add_permission(self, request: HttpRequest) -> bool @@ -767,10 +766,13 @@ posthog/temporal/tests/batch_exports/test_batch_exports.py:0: error: TypedDict k posthog/session_recordings/session_recording_api.py:0: error: Argument "team_id" to "get_realtime_snapshots" has incompatible type "int"; expected "str" [arg-type] posthog/session_recordings/session_recording_api.py:0: error: Value of type variable "SupportsRichComparisonT" of "sorted" cannot be "str | None" [type-var] posthog/session_recordings/session_recording_api.py:0: error: Argument 1 to "get" of "dict" has incompatible type "str | None"; expected "str" [arg-type] +posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type] +posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type] +posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to 
"AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type] posthog/queries/app_metrics/historical_exports.py:0: error: Argument 1 to "loads" has incompatible type "str | None"; expected "str | bytes | bytearray" [arg-type] -posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type] -posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type] -posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type] +posthog/api/test/test_decide.py:0: error: Item "None" of "User | None" has no attribute "toolbar_mode" [union-attr] +posthog/api/test/test_decide.py:0: error: Item "None" of "User | None" has no attribute "save" [union-attr] +posthog/api/test/test_authentication.py:0: error: Module has no attribute "utc" [attr-defined] posthog/api/plugin.py:0: error: Item "None" of "Team | None" has no attribute "organization" [union-attr] posthog/api/plugin.py:0: error: Item "None" of "Team | None" has no attribute "id" [union-attr] posthog/api/plugin.py:0: error: Item "None" of "Team | None" has no attribute "organization" [union-attr] @@ -784,9 +786,6 @@ posthog/api/plugin.py:0: error: Incompatible type for "file_size" of "PluginAtta posthog/api/plugin.py:0: error: Item "None" of "IO[Any] | None" has no attribute "read" [union-attr] posthog/api/plugin.py:0: error: Item "None" of "Team | None" has no attribute "organization" [union-attr] posthog/api/plugin.py:0: error: Item "None" of "Team | None" has no attribute "id" [union-attr] -posthog/api/test/test_decide.py:0: error: Item "None" of "User | None" has no attribute "toolbar_mode" [union-attr] -posthog/api/test/test_decide.py:0: error: Item "None" of "User | None" has no attribute "save" [union-attr] -posthog/api/test/test_authentication.py:0: error: Module has no attribute "utc" [attr-defined] posthog/admin/admins/plugin_config_admin.py:0: error: Item "None" of "Team | None" has no attribute "name" [union-attr] posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_execute_calls" (hint: "_execute_calls: list[] = ...") [var-annotated] posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_execute_async_calls" (hint: "_execute_async_calls: list[] = ...") [var-annotated] @@ -799,7 +798,6 @@ posthog/temporal/data_imports/workflow_activities/import_data.py:0: error: Argum posthog/temporal/data_imports/workflow_activities/import_data.py:0: error: Argument "source_type" to "sql_source_for_type" has incompatible type "str"; expected "Type" [arg-type] posthog/migrations/0237_remove_timezone_from_teams.py:0: error: Argument 2 to "RunPython" has incompatible type "Callable[[Migration, Any], None]"; expected "_CodeCallable | None" [arg-type] posthog/migrations/0228_fix_tile_layouts.py:0: error: Argument 2 to "RunPython" has incompatible type "Callable[[Migration, Any], None]"; expected "_CodeCallable | None" [arg-type] -posthog/api/query.py:0: error: Statement is unreachable [unreachable] 
posthog/api/plugin_log_entry.py:0: error: Name "timezone.datetime" is not defined [name-defined] posthog/api/plugin_log_entry.py:0: error: Module "django.utils.timezone" does not explicitly export attribute "datetime" [attr-defined] posthog/api/plugin_log_entry.py:0: error: Name "timezone.datetime" is not defined [name-defined] @@ -820,18 +818,6 @@ posthog/api/test/batch_exports/conftest.py:0: error: Argument "activities" to "T posthog/temporal/tests/data_imports/test_end_to_end.py:0: error: Unused "type: ignore" comment [unused-ignore] posthog/api/test/test_team.py:0: error: "HttpResponse" has no attribute "json" [attr-defined] posthog/api/test/test_team.py:0: error: "HttpResponse" has no attribute "json" [attr-defined] -posthog/api/test/test_capture.py:0: error: Statement is unreachable [unreachable] -posthog/api/test/test_capture.py:0: error: Incompatible return value type (got "_MonkeyPatchedWSGIResponse", expected "HttpResponse") [return-value] -posthog/api/test/test_capture.py:0: error: Module has no attribute "utc" [attr-defined] -posthog/api/test/test_capture.py:0: error: Unpacked dict entry 0 has incompatible type "Collection[str]"; expected "SupportsKeysAndGetItem[str, dict[Never, Never]]" [dict-item] -posthog/api/test/test_capture.py:0: error: Unpacked dict entry 0 has incompatible type "Collection[str]"; expected "SupportsKeysAndGetItem[str, dict[Never, Never]]" [dict-item] -posthog/api/test/test_capture.py:0: error: Unpacked dict entry 0 has incompatible type "Collection[str]"; expected "SupportsKeysAndGetItem[str, dict[Never, Never]]" [dict-item] -posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] -posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] -posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] -posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] -posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] -posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] posthog/test/test_middleware.py:0: error: Incompatible types in assignment (expression has type "_MonkeyPatchedWSGIResponse", variable has type "_MonkeyPatchedResponse") [assignment] posthog/management/commands/test/test_create_batch_export_from_app.py:0: error: Incompatible return value type (got "dict[str, Collection[str]]", expected "dict[str, str]") [return-value] posthog/management/commands/test/test_create_batch_export_from_app.py:0: error: Incompatible types in assignment (expression has type "dict[str, Collection[str]]", variable has type "dict[str, str]") [assignment] @@ -874,3 +860,16 @@ posthog/api/test/batch_exports/test_update.py:0: error: Value of type "BatchExpo posthog/api/test/batch_exports/test_update.py:0: error: Value of type "BatchExport" is not indexable [index] posthog/api/test/batch_exports/test_update.py:0: error: Value of type "BatchExport" is not indexable [index] posthog/api/test/batch_exports/test_pause.py:0: error: "batch_export_delete_schedule" does not return a value (it only ever returns None) [func-returns-value] +posthog/api/query.py:0: error: Statement is unreachable [unreachable] +posthog/api/test/test_capture.py:0: error: Statement is 
unreachable [unreachable] +posthog/api/test/test_capture.py:0: error: Incompatible return value type (got "_MonkeyPatchedWSGIResponse", expected "HttpResponse") [return-value] +posthog/api/test/test_capture.py:0: error: Module has no attribute "utc" [attr-defined] +posthog/api/test/test_capture.py:0: error: Unpacked dict entry 0 has incompatible type "Collection[str]"; expected "SupportsKeysAndGetItem[str, dict[Never, Never]]" [dict-item] +posthog/api/test/test_capture.py:0: error: Unpacked dict entry 0 has incompatible type "Collection[str]"; expected "SupportsKeysAndGetItem[str, dict[Never, Never]]" [dict-item] +posthog/api/test/test_capture.py:0: error: Unpacked dict entry 0 has incompatible type "Collection[str]"; expected "SupportsKeysAndGetItem[str, dict[Never, Never]]" [dict-item] +posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] +posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] +posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] +posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] +posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] +posthog/api/test/test_capture.py:0: error: Dict entry 0 has incompatible type "str": "float"; expected "str": "int" [dict-item] diff --git a/posthog/api/query.py b/posthog/api/query.py index 1d3bf3f67ed..d4d45ce66a2 100644 --- a/posthog/api/query.py +++ b/posthog/api/query.py @@ -1,4 +1,3 @@ -import json import re import uuid @@ -12,7 +11,8 @@ from rest_framework.request import Request from rest_framework.response import Response from sentry_sdk import capture_exception, set_tag -from ee.hogai.generate_trends_agent import Conversation, GenerateTrendsAgent +from ee.hogai.assistant import Assistant +from ee.hogai.utils import Conversation from posthog.api.documentation import extend_schema from posthog.api.mixins import PydanticModelMixin from posthog.api.monitoring import Feature, monitor @@ -37,11 +37,11 @@ from posthog.models.user import User from posthog.rate_limit import ( AIBurstRateThrottle, AISustainedRateThrottle, - HogQLQueryThrottle, ClickHouseBurstRateThrottle, ClickHouseSustainedRateThrottle, + HogQLQueryThrottle, ) -from posthog.schema import QueryRequest, QueryResponseAlternative, QueryStatusResponse +from posthog.schema import HumanMessage, QueryRequest, QueryResponseAlternative, QueryStatusResponse class ServerSentEventRenderer(BaseRenderer): @@ -179,23 +179,21 @@ class QueryViewSet(TeamAndOrgViewSetMixin, PydanticModelMixin, viewsets.ViewSet) def chat(self, request: Request, *args, **kwargs): assert request.user is not None validated_body = Conversation.model_validate(request.data) - chain = GenerateTrendsAgent(self.team).bootstrap(validated_body.messages) + assistant = Assistant(self.team) def generate(): last_message = None - for message in chain.stream({"question": validated_body.messages[0].content}): - if message: - last_message = message[0].model_dump_json() - yield last_message + for message in assistant.stream(validated_body): + last_message = message + yield last_message - if not last_message: - yield json.dumps({"reasoning_steps": ["Schema validation failed"]}) - - report_user_action( - request.user, # type: ignore - "chat with ai", - {"prompt": 
validated_body.messages[-1].content, "response": last_message}, - ) + human_message = validated_body.messages[-1].root + if isinstance(human_message, HumanMessage): + report_user_action( + request.user, # type: ignore + "chat with ai", + {"prompt": human_message.content, "response": last_message}, + ) return StreamingHttpResponse(generate(), content_type=ServerSentEventRenderer.media_type) diff --git a/posthog/api/test/__snapshots__/test_api_docs.ambr b/posthog/api/test/__snapshots__/test_api_docs.ambr index 6f47bd73221..6ef31c65301 100644 --- a/posthog/api/test/__snapshots__/test_api_docs.ambr +++ b/posthog/api/test/__snapshots__/test_api_docs.ambr @@ -97,6 +97,9 @@ '/home/runner/work/posthog/posthog/posthog/api/survey.py: Warning [SurveyViewSet > SurveySerializer]: unable to resolve type hint for function "get_conditions". Consider using a type hint or @extend_schema_field. Defaulting to string.', '/home/runner/work/posthog/posthog/posthog/api/web_experiment.py: Warning [WebExperimentViewSet]: could not derive type of path parameter "project_id" because model "posthog.models.web_experiment.WebExperiment" contained no such field. Consider annotating parameter with @extend_schema. Defaulting to "string".', 'Warning: encountered multiple names for the same choice set (HrefMatchingEnum). This may be unwanted even though the generated schema is technically correct. Add an entry to ENUM_NAME_OVERRIDES to fix the naming.', + 'Warning: enum naming encountered a non-optimally resolvable collision for fields named "kind". The same name has been used for multiple choice sets in multiple components. The collision was resolved with "Kind069Enum". add an entry to ENUM_NAME_OVERRIDES to fix the naming.', + 'Warning: enum naming encountered a non-optimally resolvable collision for fields named "kind". The same name has been used for multiple choice sets in multiple components. The collision was resolved with "KindCfaEnum". add an entry to ENUM_NAME_OVERRIDES to fix the naming.', + 'Warning: enum naming encountered a non-optimally resolvable collision for fields named "type". The same name has been used for multiple choice sets in multiple components. The collision was resolved with "TypeF73Enum". add an entry to ENUM_NAME_OVERRIDES to fix the naming.', 'Warning: encountered multiple names for the same choice set (EffectivePrivilegeLevelEnum). This may be unwanted even though the generated schema is technically correct. Add an entry to ENUM_NAME_OVERRIDES to fix the naming.', 'Warning: encountered multiple names for the same choice set (MembershipLevelEnum). This may be unwanted even though the generated schema is technically correct. Add an entry to ENUM_NAME_OVERRIDES to fix the naming.', 'Warning: operationId "environments_app_metrics_historical_exports_retrieve" has collisions [(\'/api/environments/{project_id}/app_metrics/{plugin_config_id}/historical_exports/\', \'get\'), (\'/api/environments/{project_id}/app_metrics/{plugin_config_id}/historical_exports/{id}/\', \'get\')]. 
resolving with numeral suffixes.', diff --git a/posthog/celery.py b/posthog/celery.py index ac7f5c90138..f6c7aa9d14b 100644 --- a/posthog/celery.py +++ b/posthog/celery.py @@ -81,13 +81,12 @@ task_timings: dict[str, float] = {} @setup_logging.connect def receiver_setup_logging(loglevel, logfile, format, colorize, **kwargs) -> None: - import logging + from logging import config as logging_config from posthog.settings import logs # following instructions from here https://django-structlog.readthedocs.io/en/latest/celery.html - # mypy thinks that there is no `logging.config` but there is ¯\_(ツ)_/¯ - logging.config.dictConfig(logs.LOGGING) # type: ignore + logging_config.dictConfig(logs.LOGGING) @receiver(signals.bind_extra_task_metadata) diff --git a/posthog/hogql/database/database.py b/posthog/hogql/database/database.py index e1e3fd26f82..5f5bb9c1a8b 100644 --- a/posthog/hogql/database/database.py +++ b/posthog/hogql/database/database.py @@ -1,32 +1,32 @@ import dataclasses from collections.abc import Callable -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, TypeAlias, cast, Union +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, TypeAlias, Union, cast from zoneinfo import ZoneInfo, ZoneInfoNotFoundError from django.db.models import Q -from pydantic import ConfigDict, BaseModel +from pydantic import BaseModel, ConfigDict from sentry_sdk import capture_exception from posthog.hogql import ast from posthog.hogql.context import HogQLContext from posthog.hogql.database.models import ( + BooleanDatabaseField, + DatabaseField, + DateDatabaseField, + DateTimeDatabaseField, + ExpressionField, FieldOrTable, FieldTraverser, - SavedQuery, - StringDatabaseField, - DatabaseField, - IntegerDatabaseField, - DateTimeDatabaseField, - BooleanDatabaseField, - StringJSONDatabaseField, - StringArrayDatabaseField, - LazyJoin, - VirtualTable, - Table, - DateDatabaseField, FloatDatabaseField, FunctionCallTable, - ExpressionField, + IntegerDatabaseField, + LazyJoin, + SavedQuery, + StringArrayDatabaseField, + StringDatabaseField, + StringJSONDatabaseField, + Table, + VirtualTable, ) from posthog.hogql.database.schema.channel_type import create_initial_channel_type, create_initial_domain_type from posthog.hogql.database.schema.cohort_people import CohortPeople, RawCohortPeople @@ -34,9 +34,9 @@ from posthog.hogql.database.schema.events import EventsTable from posthog.hogql.database.schema.groups import GroupsTable, RawGroupsTable from posthog.hogql.database.schema.heatmaps import HeatmapsTable from posthog.hogql.database.schema.log_entries import ( + BatchExportLogEntriesTable, LogEntriesTable, ReplayConsoleLogsLogEntriesTable, - BatchExportLogEntriesTable, ) from posthog.hogql.database.schema.numbers import NumbersTable from posthog.hogql.database.schema.person_distinct_id_overrides import ( @@ -60,8 +60,8 @@ from posthog.hogql.database.schema.session_replay_events import ( ) from posthog.hogql.database.schema.sessions_v1 import RawSessionsTableV1, SessionsTableV1 from posthog.hogql.database.schema.sessions_v2 import ( - SessionsTableV2, RawSessionsTableV2, + SessionsTableV2, join_events_table_to_sessions_table_v2, ) from posthog.hogql.database.schema.static_cohort_people import StaticCohortPeople @@ -213,13 +213,13 @@ def _use_person_id_from_person_overrides(database: Database) -> None: def create_hogql_database( team_id: int, modifiers: Optional[HogQLQueryModifiers] = None, team_arg: Optional["Team"] = None ) -> Database: - from posthog.models import Team from 
posthog.hogql.database.s3_table import S3Table from posthog.hogql.query import create_default_modifiers_for_team + from posthog.models import Team from posthog.warehouse.models import ( - DataWarehouseTable, - DataWarehouseSavedQuery, DataWarehouseJoin, + DataWarehouseSavedQuery, + DataWarehouseTable, ) team = team_arg or Team.objects.get(pk=team_id) @@ -238,7 +238,7 @@ def create_hogql_database( elif modifiers.personsOnEventsMode == PersonsOnEventsMode.PERSON_ID_OVERRIDE_PROPERTIES_ON_EVENTS: _use_person_id_from_person_overrides(database) _use_person_properties_from_events(database) - database.events.fields["poe"].fields["id"] = database.events.fields["person_id"] + cast(VirtualTable, database.events.fields["poe"]).fields["id"] = database.events.fields["person_id"] elif modifiers.personsOnEventsMode == PersonsOnEventsMode.PERSON_ID_OVERRIDE_PROPERTIES_JOINED: _use_person_id_from_person_overrides(database) @@ -268,14 +268,14 @@ def create_hogql_database( join_table=sessions, join_function=join_replay_table_to_sessions_table_v2, ) - replay_events.fields["events"].join_table = events + cast(LazyJoin, replay_events.fields["events"]).join_table = events raw_replay_events = database.raw_session_replay_events raw_replay_events.fields["session"] = LazyJoin( from_field=["session_id"], join_table=sessions, join_function=join_replay_table_to_sessions_table_v2, ) - raw_replay_events.fields["events"].join_table = events + cast(LazyJoin, raw_replay_events.fields["events"]).join_table = events database.persons.fields["$virt_initial_referring_domain_type"] = create_initial_domain_type( "$virt_initial_referring_domain_type" diff --git a/posthog/hogql/database/test/test_s3_table.py b/posthog/hogql/database/test/test_s3_table.py index 7211f75a5f8..7bc2f18506b 100644 --- a/posthog/hogql/database/test/test_s3_table.py +++ b/posthog/hogql/database/test/test_s3_table.py @@ -2,20 +2,21 @@ from posthog.hogql.constants import MAX_SELECT_RETURNED_ROWS from posthog.hogql.context import HogQLContext from posthog.hogql.database.database import create_hogql_database from posthog.hogql.database.s3_table import build_function_call +from posthog.hogql.database.test.tables import create_aapl_stock_s3_table +from posthog.hogql.errors import ExposedHogQLError from posthog.hogql.parser import parse_select from posthog.hogql.printer import print_ast from posthog.hogql.query import create_default_modifiers_for_team from posthog.test.base import BaseTest -from posthog.hogql.database.test.tables import create_aapl_stock_s3_table -from posthog.hogql.errors import ExposedHogQLError from posthog.warehouse.models.table import DataWarehouseTable class TestS3Table(BaseTest): def _init_database(self): self.database = create_hogql_database(self.team.pk) - self.database.aapl_stock = create_aapl_stock_s3_table() - self.database.aapl_stock_2 = create_aapl_stock_s3_table(name="aapl_stock_2") + self.database.add_warehouse_tables( + aapl_stock=create_aapl_stock_s3_table(), aapl_stock_2=create_aapl_stock_s3_table(name="aapl_stock_2") + ) self.context = HogQLContext( team_id=self.team.pk, enable_select_queries=True, diff --git a/posthog/hogql/database/test/test_view.py b/posthog/hogql/database/test/test_view.py index 747e80cc4e6..26d0e6382bf 100644 --- a/posthog/hogql/database/test/test_view.py +++ b/posthog/hogql/database/test/test_view.py @@ -1,15 +1,15 @@ from posthog.hogql.context import HogQLContext from posthog.hogql.database.database import create_hogql_database +from posthog.hogql.database.test.tables import ( + 
create_aapl_stock_s3_table, + create_aapl_stock_table_self_referencing, + create_aapl_stock_table_view, + create_nested_aapl_stock_view, +) from posthog.hogql.parser import parse_select from posthog.hogql.printer import print_ast from posthog.hogql.query import create_default_modifiers_for_team from posthog.test.base import BaseTest -from posthog.hogql.database.test.tables import ( - create_aapl_stock_table_view, - create_aapl_stock_s3_table, - create_nested_aapl_stock_view, - create_aapl_stock_table_self_referencing, -) class TestView(BaseTest): @@ -17,10 +17,12 @@ class TestView(BaseTest): def _init_database(self): self.database = create_hogql_database(self.team.pk) - self.database.aapl_stock_view = create_aapl_stock_table_view() - self.database.aapl_stock = create_aapl_stock_s3_table() - self.database.aapl_stock_nested_view = create_nested_aapl_stock_view() - self.database.aapl_stock_self = create_aapl_stock_table_self_referencing() + self.database.add_views( + aapl_stock_view=create_aapl_stock_table_view(), aapl_stock_nested_view=create_nested_aapl_stock_view() + ) + self.database.add_warehouse_tables( + aapl_stock=create_aapl_stock_s3_table(), aapl_stock_self=create_aapl_stock_table_self_referencing() + ) self.context = HogQLContext( team_id=self.team.pk, enable_select_queries=True, diff --git a/posthog/hogql_queries/insights/trends/trends_query_runner.py b/posthog/hogql_queries/insights/trends/trends_query_runner.py index e6675a26155..44cbf4cd5da 100644 --- a/posthog/hogql_queries/insights/trends/trends_query_runner.py +++ b/posthog/hogql_queries/insights/trends/trends_query_runner.py @@ -1073,7 +1073,7 @@ class TrendsQueryRunner(QueryRunner): return res_breakdown - def _is_other_breakdown(self, breakdown: BreakdownItem | list[BreakdownItem]) -> bool: + def _is_other_breakdown(self, breakdown: str | list[str]) -> bool: return ( breakdown == BREAKDOWN_OTHER_STRING_LABEL or isinstance(breakdown, list) diff --git a/posthog/schema.py b/posthog/schema.py index 8188cf8a021..3c7f04dd893 100644 --- a/posthog/schema.py +++ b/posthog/schema.py @@ -32,7 +32,7 @@ class ActorsPropertyTaxonomyResponse(BaseModel): extra="forbid", ) sample_count: int - sample_values: list[str] + sample_values: list[Union[str, float, bool, int]] class AggregationAxisFormat(StrEnum): @@ -63,6 +63,20 @@ class AlertState(StrEnum): SNOOZED = "Snoozed" +class AssistantMessage(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + content: str + type: Literal["ai"] = "ai" + + +class AssistantMessageType(StrEnum): + HUMAN = "human" + AI = "ai" + AI_VIZ = "ai/viz" + + class Kind(StrEnum): METHOD = "Method" FUNCTION = "Function" @@ -751,6 +765,14 @@ class HogQueryResponse(BaseModel): stdout: Optional[str] = None +class HumanMessage(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + content: str + type: Literal["human"] = "human" + + class Compare(StrEnum): CURRENT = "current" PREVIOUS = "previous" @@ -4972,7 +4994,6 @@ class AIActionsNode(BaseModel): EventPropertyFilter, PersonPropertyFilter, SessionPropertyFilter, - CohortPropertyFilter, GroupPropertyFilter, FeaturePropertyFilter, ] @@ -5000,7 +5021,6 @@ class AIActionsNode(BaseModel): EventPropertyFilter, PersonPropertyFilter, SessionPropertyFilter, - CohortPropertyFilter, GroupPropertyFilter, FeaturePropertyFilter, ] @@ -5021,7 +5041,6 @@ class AIEventsNode(BaseModel): EventPropertyFilter, PersonPropertyFilter, SessionPropertyFilter, - CohortPropertyFilter, GroupPropertyFilter, FeaturePropertyFilter, ] @@ -5049,7 +5068,6 @@ class 
AIEventsNode(BaseModel): EventPropertyFilter, PersonPropertyFilter, SessionPropertyFilter, - CohortPropertyFilter, GroupPropertyFilter, FeaturePropertyFilter, ] @@ -5181,7 +5199,6 @@ class ExperimentalAITrendsQuery(BaseModel): EventPropertyFilter, PersonPropertyFilter, SessionPropertyFilter, - CohortPropertyFilter, GroupPropertyFilter, FeaturePropertyFilter, ] @@ -5427,6 +5444,16 @@ class TrendsQuery(BaseModel): trendsFilter: Optional[TrendsFilter] = Field(default=None, description="Properties specific to the trends insight") +class VisualizationMessage(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + answer: Optional[ExperimentalAITrendsQuery] = None + plan: Optional[str] = None + reasoning_steps: Optional[list[str]] = None + type: Literal["ai/viz"] = "ai/viz" + + class ErrorTrackingQuery(BaseModel): model_config = ConfigDict( extra="forbid", @@ -5923,6 +5950,10 @@ class QueryResponseAlternative( ] +class RootAssistantMessage(RootModel[Union[VisualizationMessage, AssistantMessage, HumanMessage]]): + root: Union[VisualizationMessage, AssistantMessage, HumanMessage] + + class DatabaseSchemaQueryResponse(BaseModel): model_config = ConfigDict( extra="forbid", diff --git a/requirements-dev.in b/requirements-dev.in index f1158cbc549..eab4262ae6d 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -15,7 +15,7 @@ ruff~=0.6.1 mypy~=1.11.1 mypy-baseline~=0.7.0 mypy-extensions==1.0.0 -datamodel-code-generator==0.25.6 +datamodel-code-generator==0.26.1 djangorestframework-stubs~=3.14.5 django-stubs==5.0.4 Faker==17.5.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index 72fe00092b6..cd0caa1e8d1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,7 +10,7 @@ aiosignal==1.2.0 # via # -c requirements.txt # aiohttp -annotated-types==0.5.0 +annotated-types==0.7.0 # via # -c requirements.txt # pydantic @@ -69,7 +69,7 @@ cryptography==39.0.2 # via # -c requirements.txt # types-paramiko -datamodel-code-generator==0.25.6 +datamodel-code-generator==0.26.1 # via -r requirements-dev.in django==4.2.15 # via @@ -113,7 +113,7 @@ genson==1.2.2 # via datamodel-code-generator icdiff==2.0.5 # via pytest-icdiff -idna==2.8 +idna==3.10 # via # -c requirements.txt # email-validator @@ -199,11 +199,11 @@ pycparser==2.20 # via # -c requirements.txt # cffi -pydantic==2.5.3 +pydantic==2.9.2 # via # -c requirements.txt # datamodel-code-generator -pydantic-core==2.14.6 +pydantic-core==2.23.4 # via # -c requirements.txt # pydantic diff --git a/requirements.in b/requirements.in index 45151b4d5d3..619c6a06cda 100644 --- a/requirements.in +++ b/requirements.in @@ -44,9 +44,11 @@ gunicorn==20.1.0 infi-clickhouse-orm@ git+https://github.com/PostHog/infi.clickhouse_orm@9578c79f29635ee2c1d01b7979e89adab8383de2 kafka-python==2.0.2 kombu==5.3.2 -langchain==0.2.15 -langchain-openai==0.1.23 -langsmith==0.1.106 +langchain==0.3.3 +langchain-openai==0.2.2 +langfuse==2.52.1 +langgraph==0.2.34 +langsmith==0.1.132 lzstring==1.0.4 natsort==8.4.0 nanoid==2.0.0 @@ -64,7 +66,7 @@ pymssql==2.3.0 PyMySQL==1.1.1 psycopg[binary]==3.1.20 pyarrow==17.0.0 -pydantic==2.5.3 +pydantic==2.9.2 pyjwt==2.4.0 pyodbc==5.1.0 python-dateutil>=2.8.2 @@ -100,8 +102,8 @@ mimesis==5.2.1 more-itertools==9.0.0 django-two-factor-auth==1.14.0 phonenumberslite==8.13.6 -openai==1.43.0 -tiktoken==0.7.0 +openai==1.51.2 +tiktoken==0.8.0 nh3==0.2.14 hogql-parser==1.0.45 zxcvbn==4.4.28 diff --git a/requirements.txt b/requirements.txt index b9fdf3b435d..82798318498 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,13 +21,14 
@@ aiosignal==1.2.0 # via aiohttp amqp==5.1.1 # via kombu -annotated-types==0.5.0 +annotated-types==0.7.0 # via pydantic antlr4-python3-runtime==4.13.1 # via -r requirements.in -anyio==4.2.0 +anyio==4.6.2.post1 # via # httpx + # langfuse # openai asgiref==3.7.2 # via django @@ -48,7 +49,9 @@ attrs==23.2.0 # trio # zeep backoff==2.2.1 - # via posthoganalytics + # via + # langfuse + # posthoganalytics bcrypt==4.1.3 # via paramiko billiard==4.1.0 @@ -264,8 +267,6 @@ googleapis-common-protos==1.60.0 # via # google-api-core # grpcio-status -greenlet==3.1.1 - # via sqlalchemy grpcio==1.57.0 # via # google-api-core @@ -287,14 +288,16 @@ httpcore==1.0.2 # via httpx httpx==0.26.0 # via + # langfuse # langsmith # openai humanize==4.9.0 # via dlt -idna==2.8 +idna==3.10 # via # anyio # httpx + # langfuse # requests # snowflake-connector-python # trio @@ -336,18 +339,26 @@ kombu==5.3.2 # via # -r requirements.in # celery -langchain==0.2.15 +langchain==0.3.3 # via -r requirements.in -langchain-core==0.2.36 +langchain-core==0.3.10 # via # langchain # langchain-openai # langchain-text-splitters -langchain-openai==0.1.23 + # langgraph + # langgraph-checkpoint +langchain-openai==0.2.2 # via -r requirements.in -langchain-text-splitters==0.2.2 +langchain-text-splitters==0.3.0 # via langchain -langsmith==0.1.106 +langfuse==2.52.1 + # via -r requirements.in +langgraph==0.2.34 + # via -r requirements.in +langgraph-checkpoint==2.0.1 + # via langgraph +langsmith==0.1.132 # via # -r requirements.in # langchain @@ -373,6 +384,8 @@ more-itertools==9.0.0 # via # -r requirements.in # simple-salesforce +msgpack==1.1.0 + # via langgraph-checkpoint multidict==6.0.2 # via # aiohttp @@ -395,7 +408,7 @@ oauthlib==3.1.0 # via # requests-oauthlib # social-auth-core -openai==1.43.0 +openai==1.51.2 # via # -r requirements.in # langchain-openai @@ -415,6 +428,7 @@ packaging==24.1 # dlt # google-cloud-bigquery # langchain-core + # langfuse # snowflake-connector-python # sqlalchemy-bigquery # webdriver-manager @@ -480,14 +494,15 @@ pyasn1-modules==0.3.0 # via google-auth pycparser==2.20 # via cffi -pydantic==2.5.3 +pydantic==2.9.2 # via # -r requirements.in # langchain # langchain-core + # langfuse # langsmith # openai -pydantic-core==2.14.6 +pydantic-core==2.23.4 # via pydantic pyjwt==2.4.0 # via @@ -587,7 +602,9 @@ requests-file==2.1.0 requests-oauthlib==1.3.0 # via social-auth-core requests-toolbelt==1.0.0 - # via zeep + # via + # langsmith + # zeep requirements-parser==0.5.0 # via dlt retry==0.9.2 @@ -689,7 +706,7 @@ tenacity==8.2.3 # langchain-core threadpoolctl==3.3.0 # via scikit-learn -tiktoken==0.7.0 +tiktoken==0.8.0 # via # -r requirements.in # langchain-openai @@ -765,7 +782,9 @@ wheel==0.42.0 whitenoise==6.5.0 # via -r requirements.in wrapt==1.15.0 - # via aiobotocore + # via + # aiobotocore + # langfuse wsproto==1.2.0 # via trio-websocket xmlsec==1.3.13
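
Editor's note: the `posthog/api/query.py` hunk above swaps `GenerateTrendsAgent` for the new `Assistant` (imported from `ee.hogai.assistant`, with `Conversation` now coming from `ee.hogai.utils`) and pipes whatever `assistant.stream(validated_body)` yields straight through a `StreamingHttpResponse`, while `posthog/schema.py` gains typed chat messages (`HumanMessage`, `AssistantMessage`, `VisualizationMessage`) unified under the `RootAssistantMessage` root model. Below is a minimal, hypothetical sketch of how a client might consume that streaming chat endpoint. The URL, payload envelope, and use of `requests` are illustrative assumptions not present in this diff; only the message classes and the `Conversation.model_validate(request.data)` entry point come from the changes above.

    # sketch: hypothetical client for the reworked streaming /chat endpoint
    import requests

    from posthog.schema import HumanMessage, RootAssistantMessage, VisualizationMessage

    payload = {
        # Conversation.model_validate(request.data) expects a list of messages;
        # the exact envelope lives in ee/hogai/utils.py, which this diff does not
        # show, so this shape is an assumption.
        "messages": [HumanMessage(content="Show a trend of pageviews").model_dump()],
    }

    # Endpoint path and host are placeholders for illustration only.
    with requests.post(
        "https://example.posthog.test/api/projects/1/query/chat/",
        json=payload,
        stream=True,
    ) as response:
        for chunk in response.iter_lines():
            if not chunk:
                continue
            # Assuming each streamed chunk is one JSON-serialized message, it can
            # be validated against the new union type from posthog/schema.py.
            message = RootAssistantMessage.model_validate_json(chunk).root
            if isinstance(message, VisualizationMessage):
                # reasoning_steps and answer are Optional on VisualizationMessage,
                # so early partial chunks may carry None for either field.
                print(message.reasoning_steps, message.answer)

Because the response is served with `ServerSentEventRenderer.media_type` and each message type carries a discriminating `type` literal (`"human"`, `"ai"`, `"ai/viz"`), a consumer can rely on pydantic's union validation rather than inspecting raw JSON keys; treating each `VisualizationMessage` chunk as a full snapshot of the current partial answer (rather than a delta) is consistent with how the `generate()` wrapper simply forwards and records the last yielded message.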