feat: read snapshots from s3 (#15328)
parent 295dd4e2cc
commit c2399b05de
@@ -149,6 +149,7 @@ export const FEATURE_FLAGS = {
    RECORDINGS_DOM_EXPLORER: 'recordings-dom-explorer', // owner: #team-session-recordings
    AUTO_REDIRECT: 'auto-redirect', // owner: @lharries
    DATA_MANAGEMENT_HISTORY: 'data-management-history', // owner: @pauldambra
    SESSION_RECORDING_BLOB_REPLAY: 'session-recording-blob-replay', // owner: #team-monitoring
}

/** Which self-hosted plan's features are available with Cloud's "Standard" plan (aka card attached). */
@@ -25,26 +25,35 @@ import { userLogic } from 'scenes/userLogic'
import { chainToElements } from 'lib/utils/elements-chain'
import { captureException } from '@sentry/react'
import { createSegments, mapSnapshotsToWindowId } from './utils/segmenter'
import { decompressSync, strFromU8 } from 'fflate'
import { featureFlagLogic } from 'lib/logic/featureFlagLogic'
import { FEATURE_FLAGS } from 'lib/constants'

const IS_TEST_MODE = process.env.NODE_ENV === 'test'
const BUFFER_MS = 60000 // +- before and after start and end of a recording to query for.

export const prepareRecordingSnapshots = (
    newSnapshots?: RecordingSnapshot[],
    existingSnapshots?: RecordingSnapshot[]
): RecordingSnapshot[] => {
    return (newSnapshots || [])
        .concat(existingSnapshots ? existingSnapshots ?? [] : [])
        .sort((a, b) => a.timestamp - b.timestamp)
}

// Until we change the API to return a simple list of snapshots, we need to convert this ourselves
export const convertSnapshotsResponse = (
    snapshotsByWindowId: { [key: string]: eventWithTime[] },
    existingSnapshots?: RecordingSnapshot[]
): RecordingSnapshot[] => {
    const snapshots: RecordingSnapshot[] = Object.entries(snapshotsByWindowId)
        .flatMap(([windowId, snapshots]) => {
            return snapshots.map((snapshot) => ({
                ...snapshot,
                windowId,
            }))
        })
        .concat(existingSnapshots ? existingSnapshots ?? [] : [])
        .sort((a, b) => a.timestamp - b.timestamp)
    const snapshots: RecordingSnapshot[] = Object.entries(snapshotsByWindowId).flatMap(([windowId, snapshots]) => {
        return snapshots.map((snapshot) => ({
            ...snapshot,
            windowId,
        }))
    })

    return snapshots
    return prepareRecordingSnapshots(snapshots, existingSnapshots)
}

const generateRecordingReportDurations = (
@@ -84,7 +93,7 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
    key(({ sessionRecordingId }) => sessionRecordingId || 'no-session-recording-id'),
    connect({
        logic: [eventUsageLogic],
        values: [teamLogic, ['currentTeamId'], userLogic, ['hasAvailableFeature']],
        values: [teamLogic, ['currentTeamId'], userLogic, ['hasAvailableFeature'], featureFlagLogic, ['featureFlags']],
    }),
    defaults({
        sessionPlayerMetaData: null as SessionRecordingType | null,
@@ -118,6 +127,8 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
            0,
            {
                loadRecordingSnapshotsSuccess: (state) => state + 1,
                // load recording blob snapshots will call loadRecordingSnapshotsSuccess again when complete
                loadRecordingBlobSnapshots: () => 0,
            },
        ],

@@ -141,7 +152,19 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
            actions.loadEvents()
            actions.loadPerformanceEvents()
        },
        loadRecordingBlobSnapshotsSuccess: () => {
            if (values.sessionPlayerSnapshotData?.blob_keys?.length) {
                actions.loadRecordingBlobSnapshots(null)
            } else {
                actions.loadRecordingSnapshotsSuccess(values.sessionPlayerSnapshotData)
            }
        },
        loadRecordingSnapshotsSuccess: () => {
            if (values.sessionPlayerSnapshotData?.blob_keys?.length) {
                actions.loadRecordingBlobSnapshots(null)
                return
            }

            // If there is more data to poll for load the next batch.
            // This will keep calling loadRecording until `next` is empty.
            if (!!values.sessionPlayerSnapshotData?.next) {
@@ -246,6 +269,40 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
        sessionPlayerSnapshotData: [
            null as SessionPlayerSnapshotData | null,
            {
                loadRecordingBlobSnapshots: async (_, breakpoint): Promise<SessionPlayerSnapshotData | null> => {
                    const snapshotDataClone = { ...values.sessionPlayerSnapshotData } as SessionPlayerSnapshotData

                    if (!snapshotDataClone?.blob_keys?.length) {
                        // only call this loader action when there are blob_keys to load
                        return snapshotDataClone
                    }

                    await breakpoint(1)

                    const blob_key = snapshotDataClone.blob_keys.shift()

                    const response = await api.getResponse(
                        `api/projects/${values.currentTeamId}/session_recordings/${props.sessionRecordingId}/snapshot_file/?blob_key=${blob_key}`
                    )
                    breakpoint()

                    const contentBuffer = new Uint8Array(await response.arrayBuffer())
                    const jsonLines = strFromU8(decompressSync(contentBuffer)).trim().split('\n')
                    const snapshots: RecordingSnapshot[] = jsonLines.flatMap((l) => {
                        const snapshotLine = JSON.parse(l)
                        const snapshotData = JSON.parse(snapshotLine['data'])

                        return snapshotData.map((d: any) => ({
                            windowId: snapshotLine['window_id'],
                            ...d,
                        }))
                    })

                    return {
                        blob_keys: snapshotDataClone.blob_keys,
                        snapshots: prepareRecordingSnapshots(snapshots, snapshotDataClone.snapshots),
                    }
                },
                loadRecordingSnapshots: async ({ nextUrl }, breakpoint): Promise<SessionPlayerSnapshotData | null> => {
                    cache.snapshotsStartTime = performance.now()

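Each blob the loader above downloads through snapshot_file is newline-delimited JSON: every line carries a window_id plus a data field whose value is itself a JSON-encoded array of rrweb events, and the whole file is compressed (the decompressSync call implies gzip/deflate). A minimal decoding sketch in Python, under those assumptions and with a hypothetical helper name:

import gzip
import json


def decode_snapshot_blob(blob: bytes) -> list[dict]:
    # Assumption: the blob is gzip-compressed, newline-delimited JSON where each
    # line looks like {"window_id": "...", "data": "[...]"} and "data" holds a
    # JSON-encoded array of rrweb events, mirroring the fflate logic above.
    raw = gzip.decompress(blob).decode("utf-8")

    snapshots: list[dict] = []
    for line in raw.strip().split("\n"):
        snapshot_line = json.loads(line)
        for event in json.loads(snapshot_line["data"]):
            snapshots.append({"windowId": snapshot_line["window_id"], **event})

    # Same contract as prepareRecordingSnapshots: merged events ordered by timestamp.
    return sorted(snapshots, key=lambda s: s["timestamp"])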
@@ -256,6 +313,7 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([

                    const params = toParams({
                        recording_start_time: props.recordingStartTime,
                        blob_loading_enabled: !!values.featureFlags[FEATURE_FLAGS.SESSION_RECORDING_BLOB_REPLAY],
                    })
                    const apiUrl =
                        nextUrl ||
@@ -266,13 +324,20 @@ export const sessionRecordingDataLogic = kea<sessionRecordingDataLogicType>([
                    // NOTE: This might seem backwards as we translate the snapshotsByWindowId to an array and then derive it again later but
                    // this is for future support of the API that will return them as a simple array

                    const snapshots = convertSnapshotsResponse(
                        response.snapshot_data_by_window_id,
                        nextUrl ? values.sessionPlayerSnapshotData?.snapshots ?? [] : []
                    )
                    return {
                        snapshots,
                        next: response.next,
                    if (!response.blob_keys) {
                        const snapshots = convertSnapshotsResponse(
                            response.snapshot_data_by_window_id,
                            nextUrl ? values.sessionPlayerSnapshotData?.snapshots ?? [] : []
                        )
                        return {
                            snapshots,
                            next: response.next,
                        }
                    } else {
                        return {
                            snapshots: [],
                            blob_keys: response.blob_keys,
                        }
                    }
                },
            },
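The loader now handles two response shapes: the existing paged snapshot_data_by_window_id payload, and a blob_keys list pointing at S3-backed files that are fetched one at a time through snapshot_file. A rough sketch of that client flow in Python; the base URL, ids and authentication are placeholders, and decode_snapshot_blob is the hypothetical helper sketched earlier:

import requests


def load_recording(base_url: str, team_id: int, session_id: str) -> list[dict]:
    recordings = f"{base_url}/api/projects/{team_id}/session_recordings/{session_id}"
    data = requests.get(f"{recordings}/snapshots", params={"blob_loading_enabled": "true"}).json()

    if not data.get("blob_keys"):
        # Legacy path: snapshots grouped by window id (paging via "next" omitted here).
        return [
            {"windowId": window_id, **event}
            for window_id, events in data["snapshot_data_by_window_id"].items()
            for event in events
        ]

    # Blob path: fetch each S3-backed file through snapshot_file and merge the results.
    snapshots: list[dict] = []
    for blob_key in data["blob_keys"]:
        blob = requests.get(f"{recordings}/snapshot_file/", params={"blob_key": blob_key})
        snapshots.extend(decode_snapshot_blob(blob.content))
    return sorted(snapshots, key=lambda s: s["timestamp"])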
@@ -618,6 +618,7 @@ export type RecordingSnapshot = eventWithTime & {
export interface SessionPlayerSnapshotData {
    snapshots: RecordingSnapshot[]
    next?: string
    blob_keys?: string[]
}

export interface SessionPlayerData {
@@ -101,6 +101,7 @@
        "expr-eval": "^2.0.2",
        "express": "^4.17.1",
        "fast-deep-equal": "^3.1.3",
        "fflate": "^0.7.4",
        "fs-extra": "^10.0.0",
        "fuse.js": "^6.6.2",
        "husky": "^7.0.4",
@@ -121,6 +121,9 @@ dependencies:
  fast-deep-equal:
    specifier: ^3.1.3
    version: 3.1.3
  fflate:
    specifier: ^0.7.4
    version: 0.7.4
  fs-extra:
    specifier: ^10.0.0
    version: 10.1.0
@@ -10275,6 +10278,10 @@ packages:
    resolution: {integrity: sha512-FJqqoDBR00Mdj9ppamLa/Y7vxm+PRmNWA67N846RvsoYVMKB4q3y/de5PA7gUmRMYK/8CMz2GDZQmCRN1wBcWA==}
    dev: false

  /fflate@0.7.4:
    resolution: {integrity: sha512-5u2V/CDW15QM1XbbgS+0DfPxVB+jUKhWEKuuFuHncbk3tEEqzmoXL+2KyOFuKGqOnmdIy0/davWF1CkuwtibCw==}
    dev: false

  /figgy-pudding@3.5.2:
    resolution: {integrity: sha512-0btnI/H8f2pavGMN8w40mlSKOfTK2SVJmBfBeVIj3kNw0swwgzyRq0d5TJVOwodFmtvpPeWPN/MCcfuWF0Ezbw==}
    dev: true
@@ -3,8 +3,9 @@ from typing import Any, List, Type, cast

import structlog
from dateutil import parser
import requests
from django.db.models import Count, Prefetch
from django.http import JsonResponse
from django.http import JsonResponse, HttpResponse
from loginas.utils import is_impersonated_session
from rest_framework import exceptions, request, serializers, viewsets
from rest_framework.decorators import action
@@ -26,6 +27,7 @@ from posthog.permissions import ProjectMembershipNecessaryPermissions, TeamMembe
from posthog.queries.session_recordings.session_recording_list import SessionRecordingList, SessionRecordingListV2
from posthog.queries.session_recordings.session_recording_properties import SessionRecordingProperties
from posthog.rate_limit import ClickHouseBurstRateThrottle, ClickHouseSustainedRateThrottle
from posthog.storage import object_storage
from posthog.utils import format_query_params_absolute_url

DEFAULT_RECORDING_CHUNK_LIMIT = 20  # Should be tuned to find the best value
@@ -136,19 +138,51 @@ class SessionRecordingViewSet(StructuredViewSetMixin, viewsets.ViewSet):

        return Response({"success": True})

    @action(methods=["GET"], detail=True)
    def snapshot_file(self, request: request.Request, **kwargs) -> HttpResponse:
        blob_key = request.GET.get("blob_key")

        if not blob_key:
            raise exceptions.ValidationError("Must provide a snapshot file blob key")

        # very short-lived pre-signed URL
        file_key = f"session_recordings/team_id/{self.team.pk}/session_id/{self.kwargs['pk']}/{blob_key}"
        url = object_storage.get_presigned_url(file_key, expiration=60)
        if not url:
            raise exceptions.NotFound("Snapshot file not found")

        with requests.get(url=url, stream=True) as r:
            r.raise_for_status()
            response = HttpResponse(content=r.raw, content_type="application/json")
            response["Content-Disposition"] = "inline"
            return response

    # Paginated endpoint that returns the snapshots for the recording
    @action(methods=["GET"], detail=True)
    def snapshots(self, request: request.Request, **kwargs):
        # TODO: Why do we use a Filter? Just swap to normal offset, limit pagination
        filter = Filter(request=request)
        limit = filter.limit if filter.limit else DEFAULT_RECORDING_CHUNK_LIMIT
        offset = filter.offset if filter.offset else 0

        recording = SessionRecording.get_or_build(session_id=kwargs["pk"], team=self.team)

        if recording.deleted:
            raise exceptions.NotFound("Recording not found")

        if request.GET.get("blob_loading_enabled", "false") == "true":
            blob_prefix = f"session_recordings/team_id/{self.team.pk}/session_id/{recording.session_id}"
            blob_keys = object_storage.list_objects(blob_prefix)

            if blob_keys:
                return Response(
                    {
                        "snapshot_data_by_window_id": [],
                        "blob_keys": [x.replace(blob_prefix + "/", "") for x in blob_keys],
                        "next": None,
                    }
                )

        # TODO: Why do we use a Filter? Just swap to normal offset, limit pagination
        filter = Filter(request=request)
        limit = filter.limit if filter.limit else DEFAULT_RECORDING_CHUNK_LIMIT
        offset = filter.offset if filter.offset else 0

        # Optimisation step if passed to speed up retrieval of CH data
        if not recording.start_time:
            recording_start_time = (
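The two endpoints above share one S3 key layout: snapshots lists everything under a per-recording prefix and returns only the relative blob keys, while snapshot_file re-joins that prefix with the requested blob_key before asking for a presigned URL. A small sketch of that round trip, with placeholder values:

team_id = 1
session_id = "some-session-id"
blob_prefix = f"session_recordings/team_id/{team_id}/session_id/{session_id}"

# Hypothetical keys that object_storage.list_objects(blob_prefix) might return:
listed = [f"{blob_prefix}/data/1682608337071", f"{blob_prefix}/data/1682608368213"]

# `snapshots` strips the prefix before handing keys to the client...
blob_keys = [key.replace(blob_prefix + "/", "") for key in listed]
assert blob_keys == ["data/1682608337071", "data/1682608368213"]

# ...and `snapshot_file` re-joins it when building the presigned URL lookup key.
file_key = f"{blob_prefix}/{blob_keys[0]}"
assert file_key == listed[0]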
@@ -297,14 +297,14 @@
(SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id,
person_distinct_id2.distinct_id
FROM person_distinct_id2
WHERE equals(person_distinct_id2.team_id, 83)
WHERE equals(person_distinct_id2.team_id, 81)
GROUP BY person_distinct_id2.distinct_id
HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id)
INNER JOIN
(SELECT argMax(replaceRegexpAll(JSONExtractRaw(person.properties, 'email'), '^"|"$', ''), person.version) AS properties___email,
person.id
FROM person
WHERE equals(person.team_id, 83)
WHERE equals(person.team_id, 81)
GROUP BY person.id
HAVING equals(argMax(person.is_deleted, person.version), 0)) AS events__pdi__person ON equals(events__pdi.person_id, events__pdi__person.id)
WHERE and(equals(events.team_id, 2), equals(events__pdi__person.properties___email, 'tom@posthog.com'), less(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-10 12:14:05.000000', 6, 'UTC')), greater(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-09 12:00:00.000000', 6, 'UTC')))
@@ -327,14 +327,14 @@
(SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id,
person_distinct_id2.distinct_id
FROM person_distinct_id2
WHERE equals(person_distinct_id2.team_id, 84)
WHERE equals(person_distinct_id2.team_id, 82)
GROUP BY person_distinct_id2.distinct_id
HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id)
INNER JOIN
(SELECT argMax(person.pmat_email, person.version) AS properties___email,
person.id
FROM person
WHERE equals(person.team_id, 84)
WHERE equals(person.team_id, 82)
GROUP BY person.id
HAVING equals(argMax(person.is_deleted, person.version), 0)) AS events__pdi__person ON equals(events__pdi.person_id, events__pdi__person.id)
WHERE and(equals(events.team_id, 2), equals(events__pdi__person.properties___email, 'tom@posthog.com'), less(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-10 12:14:05.000000', 6, 'UTC')), greater(toTimeZone(events.timestamp, 'UTC'), toDateTime64('2020-01-09 12:00:00.000000', 6, 'UTC')))
@@ -1,5 +1,6 @@
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import ANY
from unittest.mock import ANY, patch
from urllib.parse import urlencode

from dateutil.parser import parse
@@ -329,7 +330,7 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.headers.get("Content-Encoding", None), "gzip")

    def test_get_snapshots_for_chunked_session_recording(self):
    def _test_body_for_chunked_session_recording(self, include_feature_flag_param: bool):
        chunked_session_id = "chunk_id"
        expected_num_requests = 3
        num_chunks = 60
@@ -346,7 +347,8 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest)
                    window_id="1" if index % 2 == 0 else "2",
                )

        next_url = f"/api/projects/{self.team.id}/session_recordings/{chunked_session_id}/snapshots"
        blob_flag = "?blob_loading_enabled=false" if include_feature_flag_param else ""
        next_url = f"/api/projects/{self.team.id}/session_recordings/{chunked_session_id}/snapshots{blob_flag}"

        for i in range(expected_num_requests):
            response = self.client.get(next_url)
@@ -367,6 +369,68 @@ class TestSessionRecordings(APIBaseTest, ClickhouseTestMixin, QueryMatchingTest)

            next_url = response_data["next"]

    def test_get_snapshots_for_chunked_session_recording_with_blob_flag_included_and_off(self):
        self._test_body_for_chunked_session_recording(include_feature_flag_param=True)

    def test_get_snapshots_for_chunked_session_recording_with_blob_flag_not_included(self):
        self._test_body_for_chunked_session_recording(include_feature_flag_param=False)

    @patch("posthog.api.session_recording.object_storage.list_objects")
    def test_get_snapshots_can_load_blobs_when_available(self, mock_list_objects) -> None:
        blob_objects = ["session_recordings/something/data", "session_recordings/something_else/data"]
        mock_list_objects.return_value = blob_objects
        chunked_session_id = "chunk_id"
        # only needs enough data so that the test fails if the blobs aren't loaded
        num_chunks = 2
        snapshots_per_chunk = 2

        with freeze_time("2020-09-13T12:26:40.000Z"):
            start_time = now()
            for index, s in enumerate(range(num_chunks)):
                self.create_chunked_snapshots(
                    snapshots_per_chunk,
                    "user",
                    chunked_session_id,
                    start_time + relativedelta(minutes=s),
                    window_id="1" if index % 2 == 0 else "2",
                )

        response = self.client.get(
            f"/api/projects/{self.team.id}/session_recordings/{chunked_session_id}/snapshots?blob_loading_enabled=true"
        )
        response_data = response.json()
        assert response_data == {"snapshot_data_by_window_id": [], "next": None, "blob_keys": blob_objects}

    @patch("posthog.api.session_recording.object_storage.get_presigned_url")
    @patch("posthog.api.session_recording.requests")
    def test_can_get_session_recording_blob(self, _mock_requests, mock_presigned_url) -> None:
        session_id = str(uuid.uuid4())
        """API will add session_recordings/team_id/{self.team.pk}/session_id/{session_id}"""
        blob_key = f"data/1682608337071"
        url = f"/api/projects/{self.team.pk}/session_recordings/{session_id}/snapshot_file/?blob_key={blob_key}"

        def presigned_url_sideeffect(key: str, **kwargs):
            if key == f"session_recordings/team_id/{self.team.pk}/session_id/{session_id}/{blob_key}":
                return f"https://test.com/"
            else:
                return None

        mock_presigned_url.side_effect = presigned_url_sideeffect

        response = self.client.get(url)
        assert response.status_code == status.HTTP_200_OK

    @patch("posthog.api.session_recording.object_storage.get_presigned_url")
    def test_can_not_get_session_recording_blob_that_does_not_exist(self, mock_presigned_url) -> None:
        session_id = str(uuid.uuid4())
        blob_key = f"session_recordings/team_id/{self.team.pk}/session_id/{session_id}/data/1682608337071"
        url = f"/api/projects/{self.team.pk}/session_recordings/{session_id}/snapshot_file/?blob_key={blob_key}"

        mock_presigned_url.return_value = None

        response = self.client.get(url)
        assert response.status_code == status.HTTP_404_NOT_FOUND

    def test_get_metadata_for_chunked_session_recording(self):

        with freeze_time("2020-09-13T12:26:40.000Z"):
@@ -1,5 +1,5 @@
import abc
from typing import Optional, Union
from typing import Optional, Union, List

import structlog
from boto3 import client
@@ -21,6 +21,14 @@ class ObjectStorageClient(metaclass=abc.ABCMeta):
    def head_bucket(self, bucket: str) -> bool:
        pass

    @abc.abstractmethod
    def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) -> Optional[str]:
        pass

    @abc.abstractmethod
    def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]:
        pass

    @abc.abstractmethod
    def read(self, bucket: str, key: str) -> Optional[str]:
        pass
@@ -38,6 +46,12 @@ class UnavailableStorage(ObjectStorageClient):
    def head_bucket(self, bucket: str):
        return False

    def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) -> Optional[str]:
        pass

    def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]:
        pass

    def read(self, bucket: str, key: str) -> Optional[str]:
        pass

@@ -59,6 +73,31 @@ class ObjectStorage(ObjectStorageClient):
            logger.warn("object_storage.health_check_failed", bucket=bucket, error=e)
            return False

    def get_presigned_url(self, bucket: str, file_key: str, expiration: int = 3600) -> Optional[str]:
        try:
            return self.aws_client.generate_presigned_url(
                ClientMethod="get_object",
                Params={"Bucket": bucket, "Key": file_key},
                ExpiresIn=expiration,
                HttpMethod="GET",
            )
        except Exception as e:
            logger.error("object_storage.get_presigned_url_failed", file_name=file_key, error=e)
            capture_exception(e)
            return None

    def list_objects(self, bucket: str, prefix: str) -> Optional[List[str]]:
        try:
            s3_response = self.aws_client.list_objects_v2(Bucket=bucket, Prefix=prefix)
            if s3_response.get("Contents"):
                return [obj["Key"] for obj in s3_response["Contents"]]
            else:
                return None
        except Exception as e:
            logger.error("object_storage.list_objects_failed", bucket=bucket, prefix=prefix, error=e)
            capture_exception(e)
            return None

    def read(self, bucket: str, key: str) -> Optional[str]:
        object_bytes = self.read_bytes(bucket, key)
        if object_bytes:
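One caveat worth noting: list_objects_v2 returns at most 1,000 keys per call, so a recording split across more blob files than that would be truncated by the implementation above. A hedged sketch of how the listing could be paginated with boto3's built-in paginator, offered as a possible extension rather than what this diff does:

from typing import List, Optional


def list_all_objects(aws_client, bucket: str, prefix: str) -> Optional[List[str]]:
    # The paginator follows ContinuationToken across pages, so prefixes with
    # more than 1,000 objects are still listed in full.
    keys: List[str] = []
    paginator = aws_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys.extend(obj["Key"] for obj in page.get("Contents", []))
    return keys or None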
@@ -121,5 +160,15 @@ def read_bytes(file_name: str) -> Optional[bytes]:
    return object_storage_client().read_bytes(bucket=settings.OBJECT_STORAGE_BUCKET, key=file_name)


def list_objects(prefix: str) -> Optional[List[str]]:
    return object_storage_client().list_objects(bucket=settings.OBJECT_STORAGE_BUCKET, prefix=prefix)


def get_presigned_url(file_key: str, expiration: int = 3600) -> Optional[str]:
    return object_storage_client().get_presigned_url(
        bucket=settings.OBJECT_STORAGE_BUCKET, file_key=file_key, expiration=expiration
    )


def health_check() -> bool:
    return object_storage_client().head_bucket(bucket=settings.OBJECT_STORAGE_BUCKET)
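Taken together, these module-level helpers are what the API layer leans on: list the blob keys under a recording's prefix, then mint a short-lived presigned URL for whichever file the client requests. A minimal usage sketch; the prefix is illustrative and the 60-second expiry mirrors the one used by snapshot_file:

from posthog.storage import object_storage

prefix = "session_recordings/team_id/1/session_id/some-session-id"

for key in object_storage.list_objects(prefix) or []:
    # Each listed key is already relative to the bucket, so it can be passed
    # straight to get_presigned_url.
    url = object_storage.get_presigned_url(key, expiration=60)
    print(key, url)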
@@ -10,7 +10,7 @@ from posthog.settings import (
    OBJECT_STORAGE_ENDPOINT,
    OBJECT_STORAGE_SECRET_ACCESS_KEY,
)
from posthog.storage.object_storage import health_check, read, write
from posthog.storage.object_storage import health_check, read, write, get_presigned_url, list_objects
from posthog.test.base import APIBaseTest

TEST_BUCKET = "test_storage_bucket"
@@ -52,3 +52,54 @@ class TestStorage(APIBaseTest):
            file_name = f"{TEST_BUCKET}/test_write_and_read_works_with_known_content/{name}"
            write(file_name, "my content".encode("utf-8"))
            self.assertEqual(read(file_name), "my content")

    def test_can_generate_presigned_url_for_existing_file(self) -> None:
        with self.settings(OBJECT_STORAGE_ENABLED=True):
            session_id = str(uuid.uuid4())
            chunk_id = uuid.uuid4()
            name = f"{session_id}/{0}-{chunk_id}"
            file_name = f"{TEST_BUCKET}/test_can_generate_presigned_url_for_existing_file/{name}"
            write(file_name, "my content".encode("utf-8"))

            presigned_url = get_presigned_url(file_name)
            assert presigned_url is not None
            self.assertRegex(
                presigned_url,
                r"^http://localhost:\d+/posthog/test_storage_bucket/test_can_generate_presigned_url_for_existing_file/.*\?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=.*$",
            )

    def test_can_generate_presigned_url_for_non_existent_file(self) -> None:
        with self.settings(OBJECT_STORAGE_ENABLED=True):
            name = "a/b-c"
            file_name = f"{TEST_BUCKET}/test_can_ignore_presigned_url_for_non_existent_file/{name}"

            presigned_url = get_presigned_url(file_name)
            assert presigned_url is not None
            self.assertRegex(
                presigned_url,
                r"^http://localhost:\d+/posthog/test_storage_bucket/test_can_ignore_presigned_url_for_non_existent_file/.*?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=.*$",
            )

    def test_can_list_objects_with_prefix(self) -> None:
        with self.settings(OBJECT_STORAGE_ENABLED=True):
            shared_prefix = "a_shared_prefix"

            for file in ["a", "b", "c"]:
                file_name = f"{TEST_BUCKET}/{shared_prefix}/{file}"
                write(file_name, "my content".encode("utf-8"))

            listing = list_objects(prefix=f"{TEST_BUCKET}/{shared_prefix}")

            assert listing == [
                "test_storage_bucket/a_shared_prefix/a",
                "test_storage_bucket/a_shared_prefix/b",
                "test_storage_bucket/a_shared_prefix/c",
            ]

    def test_can_list_unknown_prefix(self) -> None:
        with self.settings(OBJECT_STORAGE_ENABLED=True):
            shared_prefix = str(uuid.uuid4())

            listing = list_objects(prefix=shared_prefix)

            assert listing is None