0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-21 13:39:22 +01:00

feat(hogql): placeholder expressions (attempt 2) (#25216)

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Marius Andra 2024-09-26 12:33:37 +02:00 committed by GitHub
parent f2f10a9fe5
commit 9ff42460b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 99 additions and 59 deletions

View File

@ -578,7 +578,7 @@ void hogqlparserParserInitialize() {
1257,5,106,0,0,1254,1257,3,148,74,0,1255,1257,3,150,75,0,1256,1253,1,
0,0,0,1256,1254,1,0,0,0,1256,1255,1,0,0,0,1257,157,1,0,0,0,1258,1259,
3,162,81,0,1259,1260,5,123,0,0,1260,1261,3,144,72,0,1261,159,1,0,0,0,
1262,1263,5,129,0,0,1263,1264,3,130,65,0,1264,1265,5,148,0,0,1265,161,
1262,1263,5,129,0,0,1263,1264,3,116,58,0,1264,1265,5,148,0,0,1265,161,
1,0,0,0,1266,1269,5,111,0,0,1267,1269,3,164,82,0,1268,1266,1,0,0,0,1268,
1267,1,0,0,0,1269,163,1,0,0,0,1270,1274,5,143,0,0,1271,1273,3,166,83,
0,1272,1271,1,0,0,0,1273,1276,1,0,0,0,1274,1272,1,0,0,0,1274,1275,1,0,
@ -11599,8 +11599,8 @@ tree::TerminalNode* HogQLParser::PlaceholderContext::LBRACE() {
return getToken(HogQLParser::LBRACE, 0);
}
HogQLParser::NestedIdentifierContext* HogQLParser::PlaceholderContext::nestedIdentifier() {
return getRuleContext<HogQLParser::NestedIdentifierContext>(0);
HogQLParser::ColumnExprContext* HogQLParser::PlaceholderContext::columnExpr() {
return getRuleContext<HogQLParser::ColumnExprContext>(0);
}
tree::TerminalNode* HogQLParser::PlaceholderContext::RBRACE() {
@ -11636,7 +11636,7 @@ HogQLParser::PlaceholderContext* HogQLParser::placeholder() {
setState(1262);
match(HogQLParser::LBRACE);
setState(1263);
nestedIdentifier();
columnExpr(0);
setState(1264);
match(HogQLParser::RBRACE);

View File

@ -2407,7 +2407,7 @@ public:
PlaceholderContext(antlr4::ParserRuleContext *parent, size_t invokingState);
virtual size_t getRuleIndex() const override;
antlr4::tree::TerminalNode *LBRACE();
NestedIdentifierContext *nestedIdentifier();
ColumnExprContext *columnExpr();
antlr4::tree::TerminalNode *RBRACE();

File diff suppressed because one or more lines are too long

View File

@ -2570,11 +2570,7 @@ class HogQLParseTreeConverter : public HogQLParserBaseVisitor {
}
VISIT(Placeholder) {
auto nested_identifier_ctx = ctx->nestedIdentifier();
vector<string> nested =
nested_identifier_ctx ? any_cast<vector<string>>(visit(nested_identifier_ctx)) : vector<string>();
RETURN_NEW_AST_NODE("Placeholder", "{s:N}", "chain", X_PyList_FromStrings(nested));
RETURN_NEW_AST_NODE("Placeholder", "{s:N}", "expr", visitAsPyObject(ctx->columnExpr()));
}
VISIT_UNSUPPORTED(EnumValue)

View File

@ -17,7 +17,7 @@ archs = [ # We could also build a universal wheel, but separate ones are lighter
"arm64",
]
before-build = [ # We need to install the libraries for each architecture separately
"brew uninstall --force boost antlr4-cpp-runtime",
"brew uninstall --ignore-dependencies --force boost antlr4-cpp-runtime",
"brew fetch --force --bottle-tag=${ARCHFLAGS##'-arch '}_monterey boost antlr4-cpp-runtime",
"brew install $(brew --cache --bottle-tag=${ARCHFLAGS##'-arch '}_monterey boost antlr4-cpp-runtime)",
]

View File

@ -32,7 +32,7 @@ module = Extension(
setup(
name="hogql_parser",
version="1.0.40",
version="1.0.45",
url="https://github.com/PostHog/posthog/tree/master/hogql_parser",
author="PostHog Inc.",
author_email="hey@posthog.com",

View File

@ -2852,7 +2852,7 @@ class TestInsight(ClickhouseTestMixin, APIBaseTest, QueryMatchingTest):
)
self.assertEqual(
response_placeholder.json(),
self.validation_error_response("Placeholders, such as {team_id}, are not supported in this context"),
self.validation_error_response("Unresolved placeholder: {team_id}"),
)
@also_test_with_materialized_columns(event_properties=["int_value"], person_properties=["fish"])

View File

@ -698,11 +698,18 @@ class Field(Expr):
@dataclass(kw_only=True)
class Placeholder(Expr):
chain: list[str | int]
expr: Expr
@property
def field(self):
return ".".join(str(chain) for chain in self.chain)
def chain(self) -> list[str | int] | None:
expr = self.expr
while isinstance(expr, Alias):
expr = expr.expr
return expr.chain if isinstance(expr, Field) else None
@property
def field(self) -> str | None:
return ".".join(str(chain) for chain in self.chain) if self.chain else None
@dataclass(kw_only=True)

View File

@ -726,7 +726,11 @@ class BytecodeCompiler(Visitor):
def visit_function(self, node: ast.Function):
# add an implicit return if none at the end of the function
body = node.body
if isinstance(node.body, ast.Block):
# Sometimes blocks like `fn x() {foo}` get parsed as placeholders
if isinstance(body, ast.Placeholder):
body = ast.Block(declarations=[ast.ExprStatement(expr=body.expr), ast.ReturnStatement(expr=None)])
elif isinstance(node.body, ast.Block):
if len(node.body.declarations) == 0 or not isinstance(node.body.declarations[-1], ast.ReturnStatement):
body = ast.Block(declarations=[*node.body.declarations, ast.ReturnStatement(expr=None)])
elif not isinstance(node.body, ast.ReturnStatement):
@ -753,7 +757,11 @@ class BytecodeCompiler(Visitor):
def visit_lambda(self, node: ast.Lambda):
# add an implicit return if none at the end of the function
expr: ast.Expr | ast.Statement = node.expr
if isinstance(expr, ast.Block):
# Sometimes blocks like `x -> {foo}` get parsed as placeholders
if isinstance(expr, ast.Placeholder):
expr = ast.Block(declarations=[ast.ExprStatement(expr=expr.expr), ast.ReturnStatement(expr=None)])
elif isinstance(expr, ast.Block):
if len(expr.declarations) == 0 or not isinstance(expr.declarations[-1], ast.ReturnStatement):
expr = ast.Block(declarations=[*expr.declarations, ast.ReturnStatement(expr=None)])
elif not isinstance(expr, ast.ReturnStatement):

View File

@ -293,7 +293,7 @@ keywordForAlias
alias: IDENTIFIER | keywordForAlias; // |interval| can't be an alias, otherwise 'INTERVAL 1 SOMETHING' becomes ambiguous.
identifier: IDENTIFIER | interval | keyword;
enumValue: string EQ_SINGLE numberLiteral;
placeholder: LBRACE nestedIdentifier RBRACE;
placeholder: LBRACE columnExpr RBRACE;
string: STRING_LITERAL | templateString;
templateString : QUOTE_SINGLE_TEMPLATE stringContents* QUOTE_SINGLE ;

File diff suppressed because one or more lines are too long

View File

@ -511,7 +511,7 @@ def serializedATN():
5,106,0,0,1254,1257,3,148,74,0,1255,1257,3,150,75,0,1256,1253,1,
0,0,0,1256,1254,1,0,0,0,1256,1255,1,0,0,0,1257,157,1,0,0,0,1258,
1259,3,162,81,0,1259,1260,5,123,0,0,1260,1261,3,144,72,0,1261,159,
1,0,0,0,1262,1263,5,129,0,0,1263,1264,3,130,65,0,1264,1265,5,148,
1,0,0,0,1262,1263,5,129,0,0,1263,1264,3,116,58,0,1264,1265,5,148,
0,0,1265,161,1,0,0,0,1266,1269,5,111,0,0,1267,1269,3,164,82,0,1268,
1266,1,0,0,0,1268,1267,1,0,0,0,1269,163,1,0,0,0,1270,1274,5,143,
0,0,1271,1273,3,166,83,0,1272,1271,1,0,0,0,1273,1276,1,0,0,0,1274,
@ -9701,8 +9701,8 @@ class HogQLParser ( Parser ):
def LBRACE(self):
return self.getToken(HogQLParser.LBRACE, 0)
def nestedIdentifier(self):
return self.getTypedRuleContext(HogQLParser.NestedIdentifierContext,0)
def columnExpr(self):
return self.getTypedRuleContext(HogQLParser.ColumnExprContext,0)
def RBRACE(self):
@ -9729,7 +9729,7 @@ class HogQLParser ( Parser ):
self.state = 1262
self.match(HogQLParser.LBRACE)
self.state = 1263
self.nestedIdentifier()
self.columnExpr(0)
self.state = 1264
self.match(HogQLParser.RBRACE)
except RecognitionException as re:

View File

@ -1122,8 +1122,7 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
return ast.HogQLXAttribute(name=name, value=ast.Constant(value=True))
def visitPlaceholder(self, ctx: HogQLParser.PlaceholderContext):
nested = self.visit(ctx.nestedIdentifier()) if ctx.nestedIdentifier() else []
return ast.Placeholder(chain=nested)
return ast.Placeholder(expr=self.visit(ctx.columnExpr()))
def visitColumnExprTemplateString(self, ctx: HogQLParser.ColumnExprTemplateStringContext):
return self.visit(ctx.templateString())

View File

@ -24,6 +24,8 @@ class FindPlaceholders(TraversingVisitor):
super().visit(node.expr)
def visit_placeholder(self, node: ast.Placeholder):
if node.field is None:
raise QueryError("Placeholder expressions are not yet supported")
self.found.add(node.field)
@ -34,7 +36,7 @@ class ReplacePlaceholders(CloningVisitor):
def visit_placeholder(self, node):
if not self.placeholders:
raise QueryError(f"Placeholders, such as {{{node.field}}}, are not supported in this context")
raise QueryError(f"Unresolved placeholder: {{{node.field}}}")
if node.field in self.placeholders and self.placeholders[node.field] is not None:
new_node = self.placeholders[node.field]
new_node.start = node.start

View File

@ -1165,7 +1165,9 @@ class _Printer(Visitor):
raise QueryError(f"Unsupported function call '{node.name}(...)'")
def visit_placeholder(self, node: ast.Placeholder):
raise QueryError(f"Placeholders, such as {{{node.field}}}, are not supported in this context")
if node.field is None:
raise QueryError("You can not use expressions inside placeholders")
raise QueryError(f"Unresolved placeholder: {{{node.field}}}")
def visit_alias(self, node: ast.Alias):
# Skip hidden aliases completely.

View File

@ -77,12 +77,14 @@ def execute_hogql_query(
f"Query contains 'filters' placeholder, yet filters are also provided as a standalone query parameter."
)
if "filters" in placeholders_in_query or any(
placeholder.startswith("filters.") for placeholder in placeholders_in_query
placeholder and placeholder.startswith("filters.") for placeholder in placeholders_in_query
):
select_query = replace_filters(select_query, filters, team)
leftover_placeholders: list[str] = []
for placeholder in placeholders_in_query:
if placeholder is None:
raise ValueError("Placeholder expressions are not yet supported")
if placeholder != "filters" and not placeholder.startswith("filters."):
leftover_placeholders.append(placeholder)
@ -91,7 +93,7 @@ def execute_hogql_query(
if len(placeholders_in_query) > 0:
if len(placeholders) == 0:
raise ValueError(
f"Query contains placeholders, but none were provided. Placeholders in query: {', '.join(placeholders_in_query)}"
f"Query contains placeholders, but none were provided. Placeholders in query: {', '.join(s for s in placeholders_in_query if s is not None)}"
)
select_query = replace_placeholders(select_query, placeholders)

View File

@ -735,7 +735,7 @@ def parser_test_factory(backend: Literal["python", "cpp"]):
def test_placeholders(self):
self.assertEqual(
self._expr("{foo}"),
ast.Placeholder(chain=["foo"]),
ast.Placeholder(expr=ast.Field(chain=["foo"])),
)
self.assertEqual(
self._expr("{foo}", {"foo": ast.Constant(value="bar")}),
@ -946,7 +946,7 @@ def parser_test_factory(backend: Literal["python", "cpp"]):
self._select("select 1 from {placeholder}"),
ast.SelectQuery(
select=[ast.Constant(value=1)],
select_from=ast.JoinExpr(table=ast.Placeholder(chain=["placeholder"])),
select_from=ast.JoinExpr(table=ast.Placeholder(expr=ast.Field(chain=["placeholder"]))),
),
)
self.assertEqual(
@ -1336,7 +1336,7 @@ def parser_test_factory(backend: Literal["python", "cpp"]):
where=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Constant(value=1),
right=ast.Placeholder(chain=["hogql_val_1"]),
right=ast.Placeholder(expr=ast.Field(chain=["hogql_val_1"])),
),
),
)
@ -1355,6 +1355,29 @@ def parser_test_factory(backend: Literal["python", "cpp"]):
),
)
def test_placeholder_expressions(self):
actual = self._select("select 1 where 1 == {1 ? hogql_val_1 : hogql_val_2}")
expected = clear_locations(
ast.SelectQuery(
select=[ast.Constant(value=1)],
where=ast.CompareOperation(
op=ast.CompareOperationOp.Eq,
left=ast.Constant(value=1),
right=ast.Placeholder(
expr=ast.Call(
name="if",
args=[
ast.Constant(value=1),
ast.Field(chain=["hogql_val_1"]),
ast.Field(chain=["hogql_val_2"]),
],
)
),
),
)
)
self.assertEqual(actual, expected)
def test_select_union_all(self):
self.assertEqual(
self._select("select 1 union all select 2 union all select 3"),

View File

@ -3,6 +3,7 @@ from posthog.hogql import ast
from posthog.hogql.errors import QueryError
from posthog.hogql.parser import parse_expr, parse_select
from posthog.hogql.placeholders import replace_placeholders, find_placeholders
from posthog.hogql.visitor import clear_locations
from posthog.test.base import BaseTest
@ -12,23 +13,23 @@ class TestParser(BaseTest):
self.assertEqual(sorted(find_placeholders(expr)), sorted(["foo", "bar"]))
def test_replace_placeholders_simple(self):
expr = parse_expr("{foo}")
expr = clear_locations(parse_expr("{foo}"))
self.assertEqual(
expr,
ast.Placeholder(chain=["foo"], start=0, end=5),
ast.Placeholder(expr=ast.Field(chain=["foo"])),
)
expr2 = replace_placeholders(expr, {"foo": ast.Constant(value="bar")})
self.assertEqual(
expr2,
ast.Constant(value="bar", start=0, end=5),
ast.Constant(value="bar"),
)
def test_replace_placeholders_error(self):
expr = ast.Placeholder(chain=["foo"])
expr = ast.Placeholder(expr=ast.Field(chain=["foo"]))
with self.assertRaises(QueryError) as context:
replace_placeholders(expr, {})
self.assertEqual(
"Placeholders, such as {foo}, are not supported in this context",
"Unresolved placeholder: {foo}",
str(context.exception),
)
with self.assertRaises(QueryError) as context:
@ -39,35 +40,31 @@ class TestParser(BaseTest):
)
def test_replace_placeholders_comparison(self):
expr = parse_expr("timestamp < {timestamp}")
expr = clear_locations(parse_expr("timestamp < {timestamp}"))
self.assertEqual(
expr,
ast.CompareOperation(
start=0,
end=23,
op=ast.CompareOperationOp.Lt,
left=ast.Field(chain=["timestamp"], start=0, end=9),
right=ast.Placeholder(chain=["timestamp"], start=12, end=23),
left=ast.Field(chain=["timestamp"]),
right=ast.Placeholder(expr=ast.Field(chain=["timestamp"])),
),
)
expr2 = replace_placeholders(expr, {"timestamp": ast.Constant(value=123)})
self.assertEqual(
expr2,
ast.CompareOperation(
start=0,
end=23,
op=ast.CompareOperationOp.Lt,
left=ast.Field(chain=["timestamp"], start=0, end=9),
right=ast.Constant(value=123, start=12, end=23),
left=ast.Field(chain=["timestamp"]),
right=ast.Constant(value=123),
),
)
def test_assert_no_placeholders(self):
expr = ast.Placeholder(chain=["foo"])
expr = ast.Placeholder(expr=ast.Field(chain=["foo"]))
with self.assertRaises(QueryError) as context:
replace_placeholders(expr, None)
self.assertEqual(
"Placeholders, such as {foo}, are not supported in this context",
"Unresolved placeholder: {foo}",
str(context.exception),
)

View File

@ -745,7 +745,7 @@ class TestPrinter(BaseTest):
self._assert_expr_error("this makes little sense", "mismatched input 'makes' expecting <EOF>")
self._assert_expr_error("1;2", "mismatched input ';' expecting <EOF>")
self._assert_expr_error("b.a(bla)", "You can only call simple functions in HogQL, not expressions")
self._assert_expr_error("a -> { print(2) }", "You can not use blocks in HogQL")
self._assert_expr_error("a -> { print(2) }", "You can not use expressions inside placeholders")
def test_logic(self):
self.assertEqual(

View File

@ -58,7 +58,7 @@ class TestVisitor(BaseTest):
args=[
ast.Alias(
alias="d",
expr=ast.Placeholder(chain=["e"]),
expr=ast.Placeholder(expr=ast.Field(chain=["e"])),
),
ast.OrderExpr(
expr=ast.Field(chain=["c"]),

View File

@ -93,7 +93,7 @@ class TraversingVisitor(Visitor[None]):
self.visit(node.type)
def visit_placeholder(self, node: ast.Placeholder):
self.visit(node.type)
self.visit(node.expr)
def visit_call(self, node: ast.Call):
for expr in node.args:
@ -498,7 +498,7 @@ class CloningVisitor(Visitor[Any]):
start=None if self.clear_locations else node.start,
end=None if self.clear_locations else node.end,
type=None if self.clear_types else node.type,
chain=node.chain,
expr=self.visit(node.expr),
)
def visit_call(self, node: ast.Call):

View File

@ -105,7 +105,7 @@ class ErrorTrackingQueryRunner(QueryRunner):
left=ast.Field(chain=["event"]),
right=ast.Constant(value="$exception"),
),
ast.Placeholder(chain=["filters"]),
ast.Placeholder(expr=ast.Field(chain=["filters"])),
]
groups = []

View File

@ -51,9 +51,13 @@ class AggregationOperations(DataWarehouseInsightQueryMixin):
actor = "e.distinct_id" if self.team.aggregate_users_by_distinct_id else "e.person_id"
return parse_expr(f"count(DISTINCT {actor})")
elif self.series.math == "weekly_active":
return ast.Placeholder(chain=["replaced"]) # This gets replaced when doing query orchestration
return ast.Placeholder(
expr=ast.Field(chain=["replaced"])
) # This gets replaced when doing query orchestration
elif self.series.math == "monthly_active":
return ast.Placeholder(chain=["replaced"]) # This gets replaced when doing query orchestration
return ast.Placeholder(
expr=ast.Field(chain=["replaced"])
) # This gets replaced when doing query orchestration
elif self.series.math == "unique_session":
return parse_expr('count(DISTINCT e."$session_id")')
elif self.series.math == "unique_group" and self.series.math_group_type_index is not None:

View File

@ -103,7 +103,7 @@ phonenumberslite==8.13.6
openai==1.43.0
tiktoken==0.7.0
nh3==0.2.14
hogql-parser==1.0.40
hogql-parser==1.0.45
zxcvbn==4.4.28
zstd==1.5.5.1
xmlsec==1.3.13 # Do not change this version - it will break SAML

View File

@ -279,7 +279,7 @@ h11==0.13.0
# wsproto
hexbytes==1.0.0
# via dlt
hogql-parser==1.0.40
hogql-parser==1.0.45
# via -r requirements.in
httpcore==1.0.2
# via httpx