From c2399b05dea0d0e7a0d4bbd00f75c387007b36d5 Mon Sep 17 00:00:00 2001
From: Paul D'Ambra
Date: Tue, 2 May 2023 15:33:40 +0100
Subject: [PATCH] feat: read snapshots from s3 (#15328)

---
 frontend/src/lib/constants.tsx                  |   1 +
 .../player/sessionRecordingDataLogic.ts         | 101 ++++++++++++++----
 frontend/src/types.ts                           |   1 +
 package.json                                    |   1 +
 pnpm-lock.yaml                                  |   7 ++
 posthog/api/session_recording.py                |  46 ++++++--
 .../api/test/__snapshots__/test_query.ambr      |   8 +-
 posthog/api/test/test_session_recordings.py     |  70 +++++++++++-
 posthog/storage/object_storage.py               |  51 ++++++++-
 posthog/storage/test/test_object_storage.py     |  53 ++++++++-
 10 files changed, 306 insertions(+), 33 deletions(-)

diff --git a/frontend/src/lib/constants.tsx b/frontend/src/lib/constants.tsx
index 3187cfb71b2..88f4d41196e 100644
--- a/frontend/src/lib/constants.tsx
+++ b/frontend/src/lib/constants.tsx
@@ -149,6 +149,7 @@ export const FEATURE_FLAGS = {
     RECORDINGS_DOM_EXPLORER: 'recordings-dom-explorer', // owner: #team-session-recordings
     AUTO_REDIRECT: 'auto-redirect', // owner: @lharries
     DATA_MANAGEMENT_HISTORY: 'data-management-history', // owner: @pauldambra
+    SESSION_RECORDING_BLOB_REPLAY: 'session-recording-blob-replay', // owner: #team-monitoring
 }
 
 /** Which self-hosted plan's features are available with Cloud's "Standard" plan (aka card attached). */
diff --git a/frontend/src/scenes/session-recordings/player/sessionRecordingDataLogic.ts b/frontend/src/scenes/session-recordings/player/sessionRecordingDataLogic.ts
index 059df4c1e99..608800aea8c 100644
--- a/frontend/src/scenes/session-recordings/player/sessionRecordingDataLogic.ts
+++ b/frontend/src/scenes/session-recordings/player/sessionRecordingDataLogic.ts
@@ -25,26 +25,35 @@ import { userLogic } from 'scenes/userLogic'
 import { chainToElements } from 'lib/utils/elements-chain'
 import { captureException } from '@sentry/react'
 import { createSegments, mapSnapshotsToWindowId } from './utils/segmenter'
+import { decompressSync, strFromU8 } from 'fflate'
+import { featureFlagLogic } from 'lib/logic/featureFlagLogic'
+import { FEATURE_FLAGS } from 'lib/constants'
 
 const IS_TEST_MODE = process.env.NODE_ENV === 'test'
 const BUFFER_MS = 60000 // +- before and after start and end of a recording to query for.
 
+export const prepareRecordingSnapshots = (
+    newSnapshots?: RecordingSnapshot[],
+    existingSnapshots?: RecordingSnapshot[]
+): RecordingSnapshot[] => {
+    return (newSnapshots || [])
+        .concat(existingSnapshots ? existingSnapshots ?? [] : [])
+        .sort((a, b) => a.timestamp - b.timestamp)
+}
+
 // Until we change the API to return a simple list of snapshots, we need to convert this ourselves
 export const convertSnapshotsResponse = (
     snapshotsByWindowId: { [key: string]: eventWithTime[] },
     existingSnapshots?: RecordingSnapshot[]
 ): RecordingSnapshot[] => {
-    const snapshots: RecordingSnapshot[] = Object.entries(snapshotsByWindowId)
-        .flatMap(([windowId, snapshots]) => {
-            return snapshots.map((snapshot) => ({
-                ...snapshot,
-                windowId,
-            }))
-        })
-        .concat(existingSnapshots ? existingSnapshots ?? [] : [])
-        .sort((a, b) => a.timestamp - b.timestamp)
+    const snapshots: RecordingSnapshot[] = Object.entries(snapshotsByWindowId).flatMap(([windowId, snapshots]) => {
+        return snapshots.map((snapshot) => ({
+            ...snapshot,
+            windowId,
+        }))
+    })
 
-    return snapshots
+    return prepareRecordingSnapshots(snapshots, existingSnapshots)
 }
 
 const generateRecordingReportDurations = (
@@ -84,7 +93,7 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
     key(({ sessionRecordingId }) => sessionRecordingId || 'no-session-recording-id'),
     connect({
         logic: [eventUsageLogic],
-        values: [teamLogic, ['currentTeamId'], userLogic, ['hasAvailableFeature']],
+        values: [teamLogic, ['currentTeamId'], userLogic, ['hasAvailableFeature'], featureFlagLogic, ['featureFlags']],
     }),
     defaults({
         sessionPlayerMetaData: null as SessionRecordingType | null,
@@ -118,6 +127,8 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
             0,
             {
                 loadRecordingSnapshotsSuccess: (state) => state + 1,
+                // load recording blob snapshots will call loadRecordingSnapshotsSuccess again when complete
+                loadRecordingBlobSnapshots: () => 0,
             },
         ],
 
@@ -141,7 +152,19 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
             actions.loadEvents()
             actions.loadPerformanceEvents()
         },
+        loadRecordingBlobSnapshotsSuccess: () => {
+            if (values.sessionPlayerSnapshotData?.blob_keys?.length) {
+                actions.loadRecordingBlobSnapshots(null)
+            } else {
+                actions.loadRecordingSnapshotsSuccess(values.sessionPlayerSnapshotData)
+            }
+        },
         loadRecordingSnapshotsSuccess: () => {
+            if (values.sessionPlayerSnapshotData?.blob_keys?.length) {
+                actions.loadRecordingBlobSnapshots(null)
+                return
+            }
+
             // If there is more data to poll for load the next batch.
             // This will keep calling loadRecording until `next` is empty.
             if (!!values.sessionPlayerSnapshotData?.next) {
@@ -246,6 +269,40 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
         sessionPlayerSnapshotData: [
             null as SessionPlayerSnapshotData | null,
             {
+                loadRecordingBlobSnapshots: async (_, breakpoint): Promise<SessionPlayerSnapshotData | null> => {
+                    const snapshotDataClone = { ...values.sessionPlayerSnapshotData } as SessionPlayerSnapshotData
+
+                    if (!snapshotDataClone?.blob_keys?.length) {
+                        // only call this loader action when there are blob_keys to load
+                        return snapshotDataClone
+                    }
+
+                    await breakpoint(1)
+
+                    const blob_key = snapshotDataClone.blob_keys.shift()
+
+                    const response = await api.getResponse(
+                        `api/projects/${values.currentTeamId}/session_recordings/${props.sessionRecordingId}/snapshot_file/?blob_key=${blob_key}`
+                    )
+                    breakpoint()
+
+                    const contentBuffer = new Uint8Array(await response.arrayBuffer())
+                    const jsonLines = strFromU8(decompressSync(contentBuffer)).trim().split('\n')
+                    const snapshots: RecordingSnapshot[] = jsonLines.flatMap((l) => {
+                        const snapshotLine = JSON.parse(l)
+                        const snapshotData = JSON.parse(snapshotLine['data'])
+
+                        return snapshotData.map((d: any) => ({
+                            windowId: snapshotLine['window_id'],
+                            ...d,
+                        }))
+                    })
+
+                    return {
+                        blob_keys: snapshotDataClone.blob_keys,
+                        snapshots: prepareRecordingSnapshots(snapshots, snapshotDataClone.snapshots),
+                    }
+                },
                 loadRecordingSnapshots: async ({ nextUrl }, breakpoint): Promise<SessionPlayerSnapshotData | null> => {
                     cache.snapshotsStartTime = performance.now()
 
@@ -256,6 +313,7 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
 
                     const params = toParams({
                         recording_start_time: props.recordingStartTime,
+                        blob_loading_enabled: !!values.featureFlags[FEATURE_FLAGS.SESSION_RECORDING_BLOB_REPLAY],
                     })
                     const apiUrl =
                         nextUrl ||
@@ -266,13 +324,20 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
 
                     // NOTE: This might seem backwards as we translate the snapshotsByWindowId to an array and then derive it again later but
                     // this is for future support of the API that will return them as a simple array
-                    const snapshots = convertSnapshotsResponse(
-                        response.snapshot_data_by_window_id,
-                        nextUrl ? values.sessionPlayerSnapshotData?.snapshots ?? [] : []
-                    )
-                    return {
-                        snapshots,
-                        next: response.next,
+                    if (!response.blob_keys) {
+                        const snapshots = convertSnapshotsResponse(
+                            response.snapshot_data_by_window_id,
+                            nextUrl ? values.sessionPlayerSnapshotData?.snapshots ?? [] : []
+                        )
+                        return {
+                            snapshots,
+                            next: response.next,
+                        }
+                    } else {
+                        return {
+                            snapshots: [],
+                            blob_keys: response.blob_keys,
+                        }
                     }
                 },
             },
diff --git a/frontend/src/types.ts b/frontend/src/types.ts
index 98911134031..0f851722431 100644
--- a/frontend/src/types.ts
+++ b/frontend/src/types.ts
@@ -618,6 +618,7 @@ export type RecordingSnapshot = eventWithTime & {
 export interface SessionPlayerSnapshotData {
     snapshots: RecordingSnapshot[]
     next?: string
+    blob_keys?: string[]
 }
 
 export interface SessionPlayerData {
diff --git a/package.json b/package.json
index dc64a1ef010..c05d3c9ba0a 100644
--- a/package.json
+++ b/package.json
@@ -101,6 +101,7 @@
         "expr-eval": "^2.0.2",
         "express": "^4.17.1",
         "fast-deep-equal": "^3.1.3",
+        "fflate": "^0.7.4",
         "fs-extra": "^10.0.0",
         "fuse.js": "^6.6.2",
         "husky": "^7.0.4",
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 66e51b612ba..136fc04f8a1 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -121,6 +121,9 @@ dependencies:
   fast-deep-equal:
     specifier: ^3.1.3
     version: 3.1.3
+  fflate:
+    specifier: ^0.7.4
+    version: 0.7.4
   fs-extra:
     specifier: ^10.0.0
     version: 10.1.0
@@ -10275,6 +10278,10 @@ packages:
     resolution: {integrity: sha512-FJqqoDBR00Mdj9ppamLa/Y7vxm+PRmNWA67N846RvsoYVMKB4q3y/de5PA7gUmRMYK/8CMz2GDZQmCRN1wBcWA==}
     dev: false
 
+  /fflate@0.7.4:
+    resolution: {integrity: sha512-5u2V/CDW15QM1XbbgS+0DfPxVB+jUKhWEKuuFuHncbk3tEEqzmoXL+2KyOFuKGqOnmdIy0/davWF1CkuwtibCw==}
+    dev: false
+
   /figgy-pudding@3.5.2:
     resolution: {integrity: sha512-0btnI/H8f2pavGMN8w40mlSKOfTK2SVJmBfBeVIj3kNw0swwgzyRq0d5TJVOwodFmtvpPeWPN/MCcfuWF0Ezbw==}
     dev: true
diff --git a/posthog/api/session_recording.py b/posthog/api/session_recording.py
index 70084f56fbf..207c351414a 100644
--- a/posthog/api/session_recording.py
+++ b/posthog/api/session_recording.py
@@ -3,8 +3,9 @@ from typing import Any, List, Type, cast
 
 import structlog
 from dateutil import parser
+import requests
 from django.db.models import Count, Prefetch
-from django.http import JsonResponse
+from django.http import JsonResponse, HttpResponse
 from loginas.utils import is_impersonated_session
 from rest_framework import exceptions, request, serializers, viewsets
 from rest_framework.decorators import action
@@ -26,6 +27,7 @@ from posthog.permissions import ProjectMembershipNecessaryPermissions, TeamMembe
 from posthog.queries.session_recordings.session_recording_list import SessionRecordingList, SessionRecordingListV2
 from posthog.queries.session_recordings.session_recording_properties import SessionRecordingProperties
 from posthog.rate_limit import ClickHouseBurstRateThrottle, ClickHouseSustainedRateThrottle
+from posthog.storage import object_storage
 from posthog.utils import format_query_params_absolute_url
 
 DEFAULT_RECORDING_CHUNK_LIMIT = 20  # Should be tuned to find the best value
@@ -136,19 +138,51 @@ class SessionRecordingViewSet(StructuredViewSetMixin, viewsets.ViewSet):
 
         return Response({"success": True})
 
+    @action(methods=["GET"], detail=True)
+    def snapshot_file(self, request: request.Request, **kwargs) -> HttpResponse:
+        blob_key = request.GET.get("blob_key")
+
+        if not blob_key:
+            raise exceptions.ValidationError("Must provide a snapshot file blob key")
+
+        # very short-lived pre-signed URL
+        file_key = f"session_recordings/team_id/{self.team.pk}/session_id/{self.kwargs['pk']}/{blob_key}"
+        url = object_storage.get_presigned_url(file_key, expiration=60)
+        if not url:
+            raise exceptions.NotFound("Snapshot file not found")
+
+        with requests.get(url=url, stream=True) as r:
+            r.raise_for_status()
+            response = HttpResponse(content=r.raw, content_type="application/json")
+            response["Content-Disposition"] = "inline"
+            return response
+
     # Paginated endpoint that returns the snapshots for the recording
     @action(methods=["GET"], detail=True)
     def snapshots(self, request: request.Request, **kwargs):
-        # TODO: Why do we use a Filter? Just swap to norma, offset, limit pagination
-        filter = Filter(request=request)
-        limit = filter.limit if filter.limit else DEFAULT_RECORDING_CHUNK_LIMIT
-        offset = filter.offset if filter.offset else 0
-
         recording = SessionRecording.get_or_build(session_id=kwargs["pk"], team=self.team)
 
         if recording.deleted:
             raise exceptions.NotFound("Recording not found")
 
+        if request.GET.get("blob_loading_enabled", "false") == "true":
+            blob_prefix = f"session_recordings/team_id/{self.team.pk}/session_id/{recording.session_id}"
+            blob_keys = object_storage.list_objects(blob_prefix)
+
+            if blob_keys:
+                return Response(
+                    {
+                        "snapshot_data_by_window_id": [],
+                        "blob_keys": [x.replace(blob_prefix + "/", "") for x in blob_keys],
+                        "next": None,
+                    }
+                )
+
+        # TODO: Why do we use a Filter? Just swap to normal offset, limit pagination
+        filter = Filter(request=request)
+        limit = filter.limit if filter.limit else DEFAULT_RECORDING_CHUNK_LIMIT
+        offset = filter.offset if filter.offset else 0
+
         # Optimisation step if passed to speed up retrieval of CH data
         if not recording.start_time:
             recording_start_time = (
diff --git a/posthog/api/test/__snapshots__/test_query.ambr b/posthog/api/test/__snapshots__/test_query.ambr
index 4c0d3d38d20..6c75b1df690 100644
--- a/posthog/api/test/__snapshots__/test_query.ambr
+++ b/posthog/api/test/__snapshots__/test_query.ambr
@@ -297,14 +297,14 @@
   (SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, person_distinct_id2.distinct_id
    FROM person_distinct_id2
-   WHERE equals(person_distinct_id2.team_id, 83)
+   WHERE equals(person_distinct_id2.team_id, 81)
    GROUP BY person_distinct_id2.distinct_id
    HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id)
 INNER JOIN
   (SELECT argMax(replaceRegexpAll(JSONExtractRaw(person.properties, 'email'), '^"|"$', ''), person.version) AS properties___email, person.id
    FROM person
-   WHERE equals(person.team_id, 83)
+   WHERE equals(person.team_id, 81)
    GROUP BY person.id
    HAVING equals(argMax(person.is_deleted, person.version), 0)) AS events__pdi__person ON equals(events__pdi.person_id, events__pdi__person.id)
 WHERE and(equals(events.team_id, 2), equals(events__pdi__person.properties___email, 'tom@posthog.com'), less(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-10 12:14:05.000000', 6, 'UTC')), greater(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-09 12:00:00.000000', 6, 'UTC')))
@@ -327,14 +327,14 @@
   (SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, person_distinct_id2.distinct_id
    FROM person_distinct_id2
-   WHERE equals(person_distinct_id2.team_id, 84)
+   WHERE equals(person_distinct_id2.team_id, 82)
    GROUP BY person_distinct_id2.distinct_id
    HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id)
 INNER JOIN
   (SELECT argMax(person.pmat_email, person.version) AS properties___email, person.id
    FROM person
-   WHERE equals(person.team_id, 84)
+   WHERE equals(person.team_id, 82)
    GROUP BY person.id
    HAVING equals(argMax(person.is_deleted, person.version), 0)) AS events__pdi__person ON equals(events__pdi.person_id, events__pdi__person.id)
 WHERE and(equals(events.team_id, 2), equals(events__pdi__person.properties___email, 'tom@posthog.com'), less(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-10 12:14:05.000000', 6, 'UTC')), greater(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-09 12:00:00.000000', 6, 'UTC')))
diff --git a/posthog/api/test/test_session_recordings.py b/posthog/api/test/test_session_recordings.py
index 9b911a1741a..e318f8bb593 100644
--- a/posthog/api/test/test_session_recordings.py
+++ b/posthog/api/test/test_session_recordings.py
@@ -1,5 +1,6 @@
+import uuid
 from datetime import datetime, timedelta, timezone
-from unittest.mock import ANY
+from unittest.mock import ANY, patch
 from urllib.parse import urlencode
 
 from dateutil.parser import parse
@@ -329,7 +330,7 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest)
         self.assertEqual(response.status_code, status.HTTP_200_OK)
         self.assertEqual(response.headers.get("Content-Encoding", None), "gzip")
 
-    def test_get_snapshots_for_chunked_session_recording(self):
+    def _test_body_for_chunked_session_recording(self, include_feature_flag_param: bool):
         chunked_session_id = "chunk_id"
         expected_num_requests = 3
         num_chunks = 60
@@ -346,7 +347,8 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest)
                     window_id="1" if index % 2 == 0 else "2",
                 )
 
-        next_url = f"/api/projects/{self.team.id}/session_recordings/{chunked_session_id}/snapshots"
+        blob_flag = "?blob_loading_enabled=false" if include_feature_flag_param else ""
+        next_url = f"/api/projects/{self.team.id}/session_recordings/{chunked_session_id}/snapshots{blob_flag}"
 
         for i in range(expected_num_requests):
             response = self.client.get(next_url)
@@ -367,6 +369,68 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest)
 
             next_url = response_data["next"]
 
+    def test_get_snapshots_for_chunked_session_recording_with_blob_flag_included_and_off(self):
+        self._test_body_for_chunked_session_recording(include_feature_flag_param=True)
+
+    def test_get_snapshots_for_chunked_session_recording_with_blob_flag_not_included(self):
+        self._test_body_for_chunked_session_recording(include_feature_flag_param=False)
+
+    @patch("posthog.api.session_recording.object_storage.list_objects")
+    def test_get_snapshots_can_load_blobs_when_available(self, mock_list_objects) -> None:
+        blob_objects = ["session_recordings/something/data", "session_recordings/something_else/data"]
+        mock_list_objects.return_value = blob_objects
+        chunked_session_id = "chunk_id"
+        # only needs enough data so that the test fails if the blobs aren't loaded
+        num_chunks = 2
+        snapshots_per_chunk = 2
+
+        with freeze_time("2020-09-13T12:26:40.000Z"):
+            start_time = now()
+            for index, s in enumerate(range(num_chunks)):
+                self.create_chunked_snapshots(
+                    snapshots_per_chunk,
+                    "user",
+                    chunked_session_id,
+                    start_time + relativedelta(minutes=s),
+                    window_id="1" if index % 2 == 0 else "2",
+                )
+
+        response = self.client.get(
+            f"/api/projects/{self.team.id}/session_recordings/{chunked_session_id}/snapshots?blob_loading_enabled=true"
+        )
+        response_data = response.json()
+        assert response_data == {"snapshot_data_by_window_id": [], "next": None, "blob_keys": blob_objects}
+
+    @patch("posthog.api.session_recording.object_storage.get_presigned_url")
+    @patch("posthog.api.session_recording.requests")
+    def test_can_get_session_recording_blob(self, _mock_requests, mock_presigned_url) -> None:
+        session_id = str(uuid.uuid4())
+        # the API will add session_recordings/team_id/{self.team.pk}/session_id/{session_id} to the blob key
+        blob_key = "data/1682608337071"
+        url = f"/api/projects/{self.team.pk}/session_recordings/{session_id}/snapshot_file/?blob_key={blob_key}"
+
+        def presigned_url_sideeffect(key: str, **kwargs):
+            if key == f"session_recordings/team_id/{self.team.pk}/session_id/{session_id}/{blob_key}":
+                return "https://test.com/"
+            else:
+                return None
+
+        mock_presigned_url.side_effect = presigned_url_sideeffect
+
+        response = self.client.get(url)
+        assert response.status_code == status.HTTP_200_OK
+
+    @patch("posthog.api.session_recording.object_storage.get_presigned_url")
+    def test_can_not_get_session_recording_blob_that_does_not_exist(self, mock_presigned_url) -> None:
+        session_id = str(uuid.uuid4())
+        blob_key = f"session_recordings/team_id/{self.team.pk}/session_id/{session_id}/data/1682608337071"
+        url = f"/api/projects/{self.team.pk}/session_recordings/{session_id}/snapshot_file/?blob_key={blob_key}"
+
+        mock_presigned_url.return_value = None
+
+        response = self.client.get(url)
+        assert response.status_code == status.HTTP_404_NOT_FOUND
+
     def test_get_metadata_for_chunked_session_recording(self):
         with freeze_time("2020-09-13T12:26:40.000Z"):
diff --git a/posthog/storage/object_storage.py b/posthog/storage/object_storage.py
index 1fd683b8b94..d53fc3583da 100644
--- a/posthog/storage/object_storage.py
+++ b/posthog/storage/object_storage.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Optional, Union
+from typing import Optional, Union, List
 
 import structlog
 from boto3 import client
@@ -21,6 +21,14 @@ class ObjectStorageClient(metaclass=abc.ABCMeta):
     def head_bucket(self, bucket: str) -> bool:
         pass
 
+    @abc.abstractmethod
+    def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) -> Optional[str]:
+        pass
+
+    @abc.abstractmethod
+    def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]:
+        pass
+
     @abc.abstractmethod
     def read(self, bucket: str, key: str) -> Optional[str]:
         pass
@@ -38,6 +46,12 @@ class UnavailableStorage(ObjectStorageClient):
     def head_bucket(self, bucket: str):
         return False
 
+    def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) -> Optional[str]:
+        pass
+
+    def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]:
+        pass
+
     def read(self, bucket: str, key: str) -> Optional[str]:
         pass
 
@@ -59,6 +73,31 @@ class ObjectStorage(ObjectStorageClient):
             logger.warn("object_storage.health_check_failed", bucket=bucket, error=e)
             return False
 
+    def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) -> Optional[str]:
+        try:
+            return self.aws_client.generate_presigned_url(
+                ClientMethod="get_object",
+                Params={"Bucket": bucket, "Key": file_key},
+                ExpiresIn=expiration,
+                HttpMethod="GET",
+            )
+        except Exception as e:
+            logger.error("object_storage.get_presigned_url_failed", file_name=file_key, error=e)
+            capture_exception(e)
+            return None
+
+    def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]:
+        try:
+            s3_response = self.aws_client.list_objects_v2(Bucket=bucket, Prefix=prefix)
+            if s3_response.get("Contents"):
+                return [obj["Key"] for obj in s3_response["Contents"]]
+            else:
+                return None
+        except Exception as e:
+            logger.error("object_storage.list_objects_failed", bucket=bucket, prefix=prefix, error=e)
+            capture_exception(e)
+            return None
+
     def read(self, bucket: str, key: str) -> Optional[str]:
         object_bytes = self.read_bytes(bucket, key)
         if object_bytes:
@@ -121,5 +160,15 @@ def read_bytes(file_name: str) -> Optional[bytes]:
     return object_storage_client().read_bytes(bucket=settings.OBJECT_STORAGE_BUCKET, key=file_name)
 
 
+def list_objects(prefix: str) -> Optional[List[str]]:
+    return object_storage_client().list_objects(bucket=settings.OBJECT_STORAGE_BUCKET, prefix=prefix)
+
+
+def get_presigned_url(file_key: str, expiration: int = 3600) -> Optional[str]:
+    return object_storage_client().get_presigned_url(
+        bucket=settings.OBJECT_STORAGE_BUCKET, file_key=file_key, expiration=expiration
+    )
+
+
 def health_check() -> bool:
     return object_storage_client().head_bucket(bucket=settings.OBJECT_STORAGE_BUCKET)
diff --git a/posthog/storage/test/test_object_storage.py b/posthog/storage/test/test_object_storage.py
index 79979c1e127..41f2d648cee 100644
--- a/posthog/storage/test/test_object_storage.py
+++ b/posthog/storage/test/test_object_storage.py
@@ -10,7 +10,7 @@ from posthog.settings import (
     OBJECT_STORAGE_ENDPOINT,
     OBJECT_STORAGE_SECRET_ACCESS_KEY,
 )
-from posthog.storage.object_storage import health_check, read, write
+from posthog.storage.object_storage import health_check, read, write, get_presigned_url, list_objects
 from posthog.test.base import APIBaseTest
 
 TEST_BUCKET = "test_storage_bucket"
@@ -52,3 +52,54 @@ class TestStorage(APIBaseTest):
             file_name = f"{TEST_BUCKET}/test_write_and_read_works_with_known_content/{name}"
             write(file_name, "my content".encode("utf-8"))
             self.assertEqual(read(file_name), "my content")
+
+    def test_can_generate_presigned_url_for_existing_file(self) -> None:
+        with self.settings(OBJECT_STORAGE_ENABLED=True):
+            session_id = str(uuid.uuid4())
+            chunk_id = uuid.uuid4()
+            name = f"{session_id}/{0}-{chunk_id}"
+            file_name = f"{TEST_BUCKET}/test_can_generate_presigned_url_for_existing_file/{name}"
+            write(file_name, "my content".encode("utf-8"))
+
+            presigned_url = get_presigned_url(file_name)
+            assert presigned_url is not None
+            self.assertRegex(
+                presigned_url,
+                r"^http://localhost:\d+/posthog/test_storage_bucket/test_can_generate_presigned_url_for_existing_file/.*\?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=.*$",
+            )
+
+    def test_can_generate_presigned_url_for_non_existent_file(self) -> None:
+        with self.settings(OBJECT_STORAGE_ENABLED=True):
+            name = "a/b-c"
+            file_name = f"{TEST_BUCKET}/test_can_ignore_presigned_url_for_non_existent_file/{name}"
+
+            presigned_url = get_presigned_url(file_name)
+            assert presigned_url is not None
+            self.assertRegex(
+                presigned_url,
+                r"^http://localhost:\d+/posthog/test_storage_bucket/test_can_ignore_presigned_url_for_non_existent_file/.*\?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=.*$",
+            )
+
+    def test_can_list_objects_with_prefix(self) -> None:
+        with self.settings(OBJECT_STORAGE_ENABLED=True):
+            shared_prefix = "a_shared_prefix"
+
+            for file in ["a", "b", "c"]:
+                file_name = f"{TEST_BUCKET}/{shared_prefix}/{file}"
+                write(file_name, "my content".encode("utf-8"))
+
+            listing = list_objects(prefix=f"{TEST_BUCKET}/{shared_prefix}")
+
+            assert listing == [
+                "test_storage_bucket/a_shared_prefix/a",
+                "test_storage_bucket/a_shared_prefix/b",
+                "test_storage_bucket/a_shared_prefix/c",
+            ]
+
+    def test_can_list_unknown_prefix(self) -> None:
+        with self.settings(OBJECT_STORAGE_ENABLED=True):
+            shared_prefix = str(uuid.uuid4())
+
+            listing = list_objects(prefix=shared_prefix)
+
+            assert listing is None
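
A note on the blob format this patch implies: each object stored under
session_recordings/team_id/{team_id}/session_id/{session_id}/ is a gzip-compressed
JSON-Lines file, where every line is an object carrying a window_id and a data field,
and data is itself a JSON-encoded array of rrweb events. The sketch below isolates that
decode step from loadRecordingBlobSnapshots; the helper name decodeSnapshotBlob is
hypothetical, while decompressSync and strFromU8 are the fflate functions the patch
imports, and the line shape is taken from the loader's own parsing code.

    import { decompressSync, strFromU8 } from 'fflate'

    // Decode one blob fetched from the snapshot_file endpoint: gunzip the body,
    // split it into JSON lines, and tag each rrweb event with the window that
    // captured it, mirroring what loadRecordingBlobSnapshots does inline.
    export function decodeSnapshotBlob(buffer: ArrayBuffer): Record<string, any>[] {
        const jsonLines = strFromU8(decompressSync(new Uint8Array(buffer))).trim().split('\n')
        return jsonLines.flatMap((line) => {
            const snapshotLine = JSON.parse(line)
            const events: any[] = JSON.parse(snapshotLine['data'])
            return events.map((event) => ({ windowId: snapshotLine['window_id'], ...event }))
        })
    }

Tagging every event with its originating windowId is what lets prepareRecordingSnapshots
merge blobs from interleaved windows into a single timestamp-sorted stream, and the loader
pops one key off blob_keys per call and re-dispatches itself, so snapshots stream into the
player one blob at a time rather than in a single response.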