mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-21 13:39:22 +01:00
feat: add universal 60 sec event counter to livestream (#24380)
This commit is contained in:
parent
826f5a8b68
commit
8cecfbfdad
@ -7,20 +7,37 @@ import (
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
)
|
||||
|
||||
type TeamStats struct {
|
||||
Store map[string]*expirable.LRU[string, string]
|
||||
const (
|
||||
COUNTER_TTL = 60
|
||||
)
|
||||
|
||||
type Stats struct {
|
||||
Store map[string]*expirable.LRU[string, string]
|
||||
GlobalStore *expirable.LRU[string, string]
|
||||
Counter *SlidingWindowCounter
|
||||
}
|
||||
|
||||
func (ts *TeamStats) keepStats(statsChan chan PostHogEvent) {
|
||||
func newStatsKeeper() *Stats {
|
||||
return &Stats{
|
||||
Store: make(map[string]*expirable.LRU[string, string]),
|
||||
GlobalStore: expirable.NewLRU[string, string](0, nil, time.Second*COUNTER_TTL),
|
||||
Counter: NewSlidingWindowCounter(COUNTER_TTL),
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *Stats) keepStats(statsChan chan PostHogEvent) {
|
||||
log.Println("starting stats keeper...")
|
||||
|
||||
for { // ignore the range warning here - it's wrong
|
||||
select {
|
||||
case event := <-statsChan:
|
||||
ts.Counter.Increment()
|
||||
token := event.Token
|
||||
if _, ok := ts.Store[token]; !ok {
|
||||
ts.Store[token] = expirable.NewLRU[string, string](1000000, nil, time.Second*30)
|
||||
ts.Store[token] = expirable.NewLRU[string, string](0, nil, time.Second*COUNTER_TTL)
|
||||
}
|
||||
ts.Store[token].Add(event.DistinctId, "much wow")
|
||||
ts.Store[token].Add(event.DistinctId, "1")
|
||||
ts.GlobalStore.Add(event.DistinctId, "1")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -12,7 +12,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
"github.com/spf13/viper"
|
||||
@ -63,16 +62,14 @@ func main() {
|
||||
log.Fatalf("Failed to open MMDB: %v", err)
|
||||
}
|
||||
|
||||
teamStats := &TeamStats{
|
||||
Store: make(map[string]*expirable.LRU[string, string]),
|
||||
}
|
||||
stats := newStatsKeeper()
|
||||
|
||||
phEventChan := make(chan PostHogEvent)
|
||||
statsChan := make(chan PostHogEvent)
|
||||
subChan := make(chan Subscription)
|
||||
unSubChan := make(chan Subscription)
|
||||
|
||||
go teamStats.keepStats(statsChan)
|
||||
go stats.keepStats(statsChan)
|
||||
|
||||
kafkaSecurityProtocol := "SSL"
|
||||
if !isProd {
|
||||
@ -109,43 +106,14 @@ func main() {
|
||||
// Routes
|
||||
e.GET("/", index)
|
||||
|
||||
e.GET("/stats", func(c echo.Context) error {
|
||||
e.GET("/served", servedHandler(stats))
|
||||
|
||||
type stats struct {
|
||||
UsersOnProduct int `json:"users_on_product,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return errors.New("authorization header is required")
|
||||
}
|
||||
|
||||
claims, err := decodeAuthToken(authHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token := fmt.Sprint(claims["api_token"])
|
||||
|
||||
var hash *expirable.LRU[string, string]
|
||||
var ok bool
|
||||
if hash, ok = teamStats.Store[token]; !ok {
|
||||
resp := stats{
|
||||
Error: "no stats",
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
siteStats := stats{
|
||||
UsersOnProduct: hash.Len(),
|
||||
}
|
||||
return c.JSON(http.StatusOK, siteStats)
|
||||
})
|
||||
e.GET("/stats", statsHandler(stats))
|
||||
|
||||
e.GET("/events", func(c echo.Context) error {
|
||||
e.Logger.Printf("SSE client connected, ip: %v", c.RealIP())
|
||||
|
||||
teamId := c.QueryParam("teamId")
|
||||
var teamId string
|
||||
eventType := c.QueryParam("eventType")
|
||||
distinctId := c.QueryParam("distinctId")
|
||||
geo := c.QueryParam("geo")
|
||||
|
@ -36,18 +36,18 @@ func TestStatsHandler(t *testing.T) {
|
||||
req.Header.Set("Authorization", "Bearer mock_token")
|
||||
|
||||
// Create a mock TeamStats
|
||||
teamStats := &TeamStats{
|
||||
stats := &Stats{
|
||||
Store: make(map[string]*expirable.LRU[string, string]),
|
||||
}
|
||||
teamStats.Store["mock_token"] = expirable.NewLRU[string, string](100, nil, time.Minute)
|
||||
teamStats.Store["mock_token"].Add("user1", "data1")
|
||||
stats.Store["mock_token"] = expirable.NewLRU[string, string](100, nil, time.Minute)
|
||||
stats.Store["mock_token"].Add("user1", "data1")
|
||||
|
||||
// Add the teamStats to the context
|
||||
c.Set("teamStats", teamStats)
|
||||
c.Set("teamStats", stats)
|
||||
|
||||
handler := func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"users_on_product": teamStats.Store["mock_token"].Len(),
|
||||
"users_on_product": stats.Store["mock_token"].Len(),
|
||||
})
|
||||
}
|
||||
|
||||
|
62
livestream/served.go
Normal file
62
livestream/served.go
Normal file
@ -0,0 +1,62 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
type Counter struct {
|
||||
EventCount uint32
|
||||
UserCount uint32
|
||||
}
|
||||
|
||||
func servedHandler(stats *Stats) func(c echo.Context) error {
|
||||
return func(c echo.Context) error {
|
||||
userCount := stats.GlobalStore.Len()
|
||||
count := stats.Counter.Count()
|
||||
resp := Counter{
|
||||
EventCount: uint32(count),
|
||||
UserCount: uint32(userCount),
|
||||
}
|
||||
return c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func statsHandler(stats *Stats) func(c echo.Context) error {
|
||||
return func(c echo.Context) error {
|
||||
|
||||
type resp struct {
|
||||
UsersOnProduct int `json:"users_on_product,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return errors.New("authorization header is required")
|
||||
}
|
||||
|
||||
claims, err := decodeAuthToken(authHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token := fmt.Sprint(claims["api_token"])
|
||||
|
||||
var hash *expirable.LRU[string, string]
|
||||
var ok bool
|
||||
if hash, ok = stats.Store[token]; !ok {
|
||||
resp := resp{
|
||||
Error: "no stats",
|
||||
}
|
||||
return c.JSON(http.StatusNotFound, resp)
|
||||
}
|
||||
|
||||
siteStats := resp{
|
||||
UsersOnProduct: hash.Len(),
|
||||
}
|
||||
return c.JSON(http.StatusOK, siteStats)
|
||||
}
|
||||
}
|
48
livestream/ttl_counter.go
Normal file
48
livestream/ttl_counter.go
Normal file
@ -0,0 +1,48 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SlidingWindowCounter struct {
|
||||
mu sync.Mutex
|
||||
events []time.Time
|
||||
windowSize time.Duration
|
||||
}
|
||||
|
||||
func NewSlidingWindowCounter(windowSize time.Duration) *SlidingWindowCounter {
|
||||
return &SlidingWindowCounter{
|
||||
events: make([]time.Time, 0),
|
||||
windowSize: windowSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (swc *SlidingWindowCounter) Increment() {
|
||||
swc.mu.Lock()
|
||||
defer swc.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
swc.events = append(swc.events, now)
|
||||
swc.removeOldEvents(now)
|
||||
}
|
||||
|
||||
func (swc *SlidingWindowCounter) Count() int {
|
||||
swc.mu.Lock()
|
||||
defer swc.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
swc.removeOldEvents(now)
|
||||
return len(swc.events)
|
||||
}
|
||||
|
||||
func (swc *SlidingWindowCounter) removeOldEvents(now time.Time) {
|
||||
cutoff := now.Add(-swc.windowSize)
|
||||
i := 0
|
||||
for ; i < len(swc.events); i++ {
|
||||
if swc.events[i].After(cutoff) {
|
||||
break
|
||||
}
|
||||
}
|
||||
swc.events = swc.events[i:]
|
||||
}
|
78
livestream/ttl_counter_test.go
Normal file
78
livestream/ttl_counter_test.go
Normal file
@ -0,0 +1,78 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewSlidingWindowCounter(t *testing.T) {
|
||||
windowSize := time.Minute
|
||||
swc := NewSlidingWindowCounter(windowSize)
|
||||
|
||||
assert.Equal(t, windowSize, swc.windowSize, "Window size should match")
|
||||
assert.Empty(t, swc.events, "Events slice should be empty")
|
||||
}
|
||||
|
||||
func TestIncrement(t *testing.T) {
|
||||
swc := NewSlidingWindowCounter(time.Minute)
|
||||
|
||||
swc.Increment()
|
||||
assert.Equal(t, 1, swc.Count(), "Count should be 1 after first increment")
|
||||
|
||||
swc.Increment()
|
||||
assert.Equal(t, 2, swc.Count(), "Count should be 2 after second increment")
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
swc := NewSlidingWindowCounter(time.Second)
|
||||
|
||||
swc.Increment()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
swc.Increment()
|
||||
|
||||
assert.Equal(t, 2, swc.Count(), "Count should be 2 within the time window")
|
||||
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
assert.Equal(t, 1, swc.Count(), "Count should be 1 after oldest event expires")
|
||||
}
|
||||
|
||||
func TestRemoveOldEvents(t *testing.T) {
|
||||
swc := NewSlidingWindowCounter(time.Second)
|
||||
|
||||
now := time.Now()
|
||||
swc.events = []time.Time{
|
||||
now.Add(-2 * time.Second),
|
||||
now.Add(-1500 * time.Millisecond),
|
||||
now.Add(-500 * time.Millisecond),
|
||||
now,
|
||||
}
|
||||
|
||||
swc.removeOldEvents(now)
|
||||
|
||||
require.Len(t, swc.events, 2, "Should have 2 events after removal")
|
||||
assert.Equal(t, now.Add(-500*time.Millisecond), swc.events[0], "First event should be 500ms ago")
|
||||
assert.Equal(t, now, swc.events[1], "Second event should be now")
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
swc := NewSlidingWindowCounter(time.Minute)
|
||||
iterations := 1000
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(iterations)
|
||||
for i := 0; i < iterations; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
swc.Increment()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, iterations, swc.Count(), "Count should match the number of increments")
|
||||
}
|
Loading…
Reference in New Issue
Block a user