mirror of
https://github.com/PostHog/posthog.git
synced 2024-11-21 13:39:22 +01:00
feat(cdp): Oauth refresh flow support (hubspot) (#23811)
This commit is contained in:
parent
2413b661f5
commit
b956712cb5
@ -2,4 +2,4 @@
|
||||
# Important: Add new queues to make Celery consume tasks from them.
|
||||
|
||||
# NOTE: Keep in sync with posthog/tasks/utils.py
|
||||
CELERY_WORKER_QUEUES=celery,stats,email,analytics_queries,analytics_limited,long_running,exports,subscription_delivery,usage_reports,session_replay_embeddings,session_replay_general,session_replay_persistence
|
||||
CELERY_WORKER_QUEUES=celery,stats,email,analytics_queries,analytics_limited,long_running,exports,subscription_delivery,usage_reports,session_replay_embeddings,session_replay_general,session_replay_persistence,integrations
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 20 KiB |
Binary file not shown.
Before Width: | Height: | Size: 100 KiB After Width: | Height: | Size: 99 KiB |
Binary file not shown.
Before Width: | Height: | Size: 162 KiB After Width: | Height: | Size: 162 KiB |
@ -1,3 +1,5 @@
|
||||
import { LemonBanner } from '@posthog/lemon-ui'
|
||||
import api from 'lib/api'
|
||||
import { UserActivityIndicator } from 'lib/components/UserActivityIndicator/UserActivityIndicator'
|
||||
|
||||
import { IntegrationType } from '~/types'
|
||||
@ -9,26 +11,50 @@ export function IntegrationView({
|
||||
integration: IntegrationType
|
||||
suffix?: JSX.Element
|
||||
}): JSX.Element {
|
||||
const errors = (integration.errors && integration.errors?.split(',')) || []
|
||||
|
||||
return (
|
||||
<div className="rounded border flex justify-between items-center p-2 bg-bg-light">
|
||||
<div className="flex items-center gap-4 ml-2">
|
||||
<img src={integration.icon_url} className="h-10 w-10 rounded" />
|
||||
<div>
|
||||
<div className="rounded border bg-bg-light">
|
||||
<div className="flex justify-between items-center p-2">
|
||||
<div className="flex items-center gap-4 ml-2">
|
||||
<img src={integration.icon_url} className="h-10 w-10 rounded" />
|
||||
<div>
|
||||
Connected to <strong>{integration.name}</strong>
|
||||
<div>
|
||||
Connected to <strong>{integration.display_name}</strong>
|
||||
</div>
|
||||
{integration.created_by ? (
|
||||
<UserActivityIndicator
|
||||
at={integration.created_at}
|
||||
by={integration.created_by}
|
||||
prefix="Updated"
|
||||
className="text-muted"
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
{integration.created_by ? (
|
||||
<UserActivityIndicator
|
||||
at={integration.created_at}
|
||||
by={integration.created_by}
|
||||
prefix="Updated"
|
||||
className="text-muted"
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
|
||||
{suffix}
|
||||
</div>
|
||||
|
||||
{suffix}
|
||||
{errors.length > 0 && (
|
||||
<div className="p-2">
|
||||
<LemonBanner
|
||||
type="error"
|
||||
action={{
|
||||
children: 'Reconnect',
|
||||
disableClientSideRouting: true,
|
||||
to: api.integrations.authorizeUrl({
|
||||
kind: integration.kind,
|
||||
next: window.location.pathname,
|
||||
}),
|
||||
}}
|
||||
>
|
||||
{errors[0] === 'TOKEN_REFRESH_FAILED'
|
||||
? 'Authentication token could not be refreshed. Please reconnect.'
|
||||
: `There was an error with this integration: ${errors[0]}`}
|
||||
</LemonBanner>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import { loaders } from 'kea-loaders'
|
||||
import { router, urlToAction } from 'kea-router'
|
||||
import api from 'lib/api'
|
||||
import { fromParamsGivenUrl } from 'lib/utils'
|
||||
import IconHubspot from 'public/services/hubspot.png'
|
||||
import IconSalesforce from 'public/services/salesforce.png'
|
||||
import IconSlack from 'public/services/slack.png'
|
||||
import { preflightLogic } from 'scenes/PreflightCheck/preflightLogic'
|
||||
@ -13,9 +14,10 @@ import { IntegrationKind, IntegrationType } from '~/types'
|
||||
|
||||
import type { integrationsLogicType } from './integrationsLogicType'
|
||||
|
||||
const ICONS = {
|
||||
const ICONS: Record<IntegrationKind, any> = {
|
||||
slack: IconSlack,
|
||||
salesforce: IconSalesforce,
|
||||
hubspot: IconHubspot,
|
||||
}
|
||||
|
||||
export const integrationsLogic = kea<integrationsLogicType>([
|
||||
@ -41,12 +43,6 @@ export const integrationsLogic = kea<integrationsLogicType>([
|
||||
return res.results.map((integration) => {
|
||||
return {
|
||||
...integration,
|
||||
name:
|
||||
integration.kind === 'slack'
|
||||
? integration.config.team.name
|
||||
: integration.kind === 'salesforce'
|
||||
? integration.config.instance_url
|
||||
: 'Unknown',
|
||||
// TODO: Make the icons endpoint independent of hog functions
|
||||
icon_url: ICONS[integration.kind],
|
||||
}
|
||||
|
@ -75,6 +75,7 @@ export type SideAction = Pick<
|
||||
LemonButtonProps,
|
||||
| 'onClick'
|
||||
| 'to'
|
||||
| 'disableClientSideRouting'
|
||||
| 'disabled'
|
||||
| 'icon'
|
||||
| 'type'
|
||||
|
@ -273,6 +273,7 @@ function HogFunctionInputSchemaControls({ value, onChange, onDone }: HogFunction
|
||||
options={[
|
||||
{ label: 'Slack', value: 'slack' },
|
||||
{ label: 'Salesforce', value: 'salesforce' },
|
||||
{ label: 'Hubspot', value: 'hubspot' },
|
||||
]}
|
||||
placeholder="Choose kind"
|
||||
/>
|
||||
|
@ -50,7 +50,7 @@ function HogFunctionIntegrationChoice({
|
||||
icon: <img src={integration.icon_url} className="w-6 h-6 rounded" />,
|
||||
onClick: () => onChange?.(integration.id),
|
||||
active: integration.id === value,
|
||||
label: integration.name,
|
||||
label: integration.display_name,
|
||||
})) || []),
|
||||
],
|
||||
}
|
||||
|
@ -219,7 +219,7 @@ export const mockIntegration: IntegrationType = {
|
||||
},
|
||||
},
|
||||
icon_url: '',
|
||||
name: '',
|
||||
display_name: '',
|
||||
created_at: '2022-01-01T00:09:00',
|
||||
created_by: mockBasicUser,
|
||||
}
|
||||
|
@ -3539,16 +3539,17 @@ export enum EventDefinitionType {
|
||||
EventPostHog = 'event_posthog',
|
||||
}
|
||||
|
||||
export type IntegrationKind = 'slack' | 'salesforce'
|
||||
export type IntegrationKind = 'slack' | 'salesforce' | 'hubspot'
|
||||
|
||||
export interface IntegrationType {
|
||||
id: number
|
||||
kind: IntegrationKind
|
||||
name: string
|
||||
display_name: string
|
||||
icon_url: string
|
||||
config: any
|
||||
created_by?: UserBasicType | null
|
||||
created_at: string
|
||||
errors?: string
|
||||
}
|
||||
|
||||
export interface SlackChannelType {
|
||||
|
@ -5,7 +5,7 @@ contenttypes: 0002_remove_content_type_name
|
||||
ee: 0016_rolemembership_organization_member
|
||||
otp_static: 0002_throttling
|
||||
otp_totp: 0002_auto_20190420_0723
|
||||
posthog: 0446_annotation_dashboard_alter_annotation_scope
|
||||
posthog: 0447_alter_integration_kind
|
||||
sessions: 0001_initial
|
||||
social_django: 0010_uid_db_index
|
||||
two_factor: 0007_auto_20201201_1019
|
||||
|
@ -28,6 +28,11 @@ export class HogFunctionManager {
|
||||
const { hogFunctionIds, teamId } = JSON.parse(message)
|
||||
await this.reloadHogFunctions(teamId, hogFunctionIds)
|
||||
},
|
||||
|
||||
'reload-integrations': async (message) => {
|
||||
const { integrationIds, teamId } = JSON.parse(message)
|
||||
await this.reloadIntegrations(teamId, integrationIds)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@ -149,6 +154,15 @@ export class HogFunctionManager {
|
||||
return items[0] ?? null
|
||||
}
|
||||
|
||||
public reloadIntegrations(teamId: Team['id'], ids: IntegrationType['id'][]): Promise<void> {
|
||||
// We need to find all hog functions that depend on these integrations and re-enrich them
|
||||
|
||||
const items: HogFunctionType[] = Object.values(this.cache[teamId] || {})
|
||||
const itemsToReload = items.filter((item) => ids.some((id) => item.depends_on_integration_ids?.has(id)))
|
||||
|
||||
return this.enrichWithIntegrations(itemsToReload)
|
||||
}
|
||||
|
||||
public async enrichWithIntegrations(items: HogFunctionType[]): Promise<void> {
|
||||
const integrationIds: number[] = []
|
||||
|
||||
@ -158,6 +172,8 @@ export class HogFunctionManager {
|
||||
const input = item.inputs?.[schema.key]
|
||||
if (input && typeof input.value === 'number') {
|
||||
integrationIds.push(input.value)
|
||||
item.depends_on_integration_ids = item.depends_on_integration_ids || new Set()
|
||||
item.depends_on_integration_ids.add(input.value)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -211,6 +211,7 @@ export type HogFunctionType = {
|
||||
inputs_schema?: HogFunctionInputSchemaType[]
|
||||
inputs?: Record<string, HogFunctionInputType>
|
||||
filters?: HogFunctionFilters | null
|
||||
depends_on_integration_ids?: Set<IntegrationType['id']>
|
||||
}
|
||||
|
||||
export type HogFunctionInputType = {
|
||||
|
@ -113,6 +113,7 @@ describe('HogFunctionManager', () => {
|
||||
value: integrations[0].id,
|
||||
},
|
||||
},
|
||||
depends_on_integration_ids: new Set([integrations[0].id]),
|
||||
},
|
||||
])
|
||||
|
||||
|
@ -20,15 +20,8 @@ class IntegrationSerializer(serializers.ModelSerializer):
|
||||
|
||||
class Meta:
|
||||
model = Integration
|
||||
fields = [
|
||||
"id",
|
||||
"kind",
|
||||
"config",
|
||||
"created_at",
|
||||
"created_by",
|
||||
"errors",
|
||||
]
|
||||
read_only_fields = ["id", "created_at", "created_by", "errors"]
|
||||
fields = ["id", "kind", "config", "created_at", "created_by", "errors", "display_name"]
|
||||
read_only_fields = ["id", "created_at", "created_by", "errors", "display_name"]
|
||||
|
||||
def create(self, validated_data: Any) -> Any:
|
||||
request = self.context["request"]
|
||||
|
@ -21,7 +21,7 @@ let body := {
|
||||
}
|
||||
|
||||
let headers := {
|
||||
'Authorization': f'Bearer {inputs.access_token}',
|
||||
'Authorization': f'Bearer {inputs.oauth.access_token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
@ -51,16 +51,14 @@ if (res.status == 409) {
|
||||
} else {
|
||||
print('Contact created successfully!')
|
||||
}
|
||||
|
||||
|
||||
""".strip(),
|
||||
inputs_schema=[
|
||||
{
|
||||
"key": "access_token",
|
||||
"type": "string",
|
||||
"label": "Access token",
|
||||
"description": "Can be acquired under Profile Preferences -> Integrations -> Private Apps",
|
||||
"secret": True,
|
||||
"key": "oauth",
|
||||
"type": "integration",
|
||||
"integration": "hubspot",
|
||||
"label": "Hubspot connection",
|
||||
"secret": False,
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
|
@ -7,7 +7,7 @@ class TestTemplateHubspot(BaseHogFunctionTemplateTest):
|
||||
|
||||
def _inputs(self, **kwargs):
|
||||
inputs = {
|
||||
"access_token": "TOKEN",
|
||||
"oauth": {"access_token": "TOKEN"},
|
||||
"email": "example@posthog.com",
|
||||
"properties": {
|
||||
"company": "PostHog",
|
||||
|
19
posthog/migrations/0447_alter_integration_kind.py
Normal file
19
posthog/migrations/0447_alter_integration_kind.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Generated by Django 4.2.14 on 2024-07-18 14:26
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("posthog", "0446_annotation_dashboard_alter_annotation_scope"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="integration",
|
||||
name="kind",
|
||||
field=models.CharField(
|
||||
choices=[("slack", "Slack"), ("salesforce", "Salesforce"), ("hubspot", "Hubspot")], max_length=10
|
||||
),
|
||||
),
|
||||
]
|
@ -3,7 +3,7 @@ import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.db import models
|
||||
@ -15,12 +15,29 @@ from django.conf import settings
|
||||
from posthog.cache_utils import cache_for
|
||||
from posthog.models.instance_setting import get_instance_settings
|
||||
from posthog.models.user import User
|
||||
import structlog
|
||||
|
||||
from posthog.plugins.plugin_server_api import reload_integrations_on_workers
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
def dot_get(d: Any, path: str, default: Any = None) -> Any:
|
||||
for key in path.split("."):
|
||||
if not isinstance(d, dict):
|
||||
return default
|
||||
d = d.get(key, default)
|
||||
return d
|
||||
|
||||
|
||||
ERROR_TOKEN_REFRESH_FAILED = "TOKEN_REFRESH_FAILED"
|
||||
|
||||
|
||||
class Integration(models.Model):
|
||||
class IntegrationKind(models.TextChoices):
|
||||
SLACK = "slack"
|
||||
SALESFORCE = "salesforce"
|
||||
HUBSPOT = "hubspot"
|
||||
|
||||
class Meta:
|
||||
constraints = [
|
||||
@ -46,6 +63,14 @@ class Integration(models.Model):
|
||||
created_at: models.DateTimeField = models.DateTimeField(auto_now_add=True, blank=True)
|
||||
created_by: models.ForeignKey = models.ForeignKey("User", on_delete=models.SET_NULL, null=True, blank=True)
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
if self.kind in OauthIntegration.supported_kinds:
|
||||
oauth_config = OauthIntegration.oauth_config_for_kind(self.kind)
|
||||
return dot_get(self.config, oauth_config.name_path, self.integration_id)
|
||||
|
||||
return f"ID: {self.integration_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OauthConfig:
|
||||
@ -55,16 +80,17 @@ class OauthConfig:
|
||||
client_secret: str
|
||||
scope: str
|
||||
id_path: str
|
||||
name_path: str
|
||||
token_info_url: Optional[str] = None
|
||||
token_info_config_fields: Optional[list[str]] = None
|
||||
|
||||
|
||||
class OauthIntegration:
|
||||
supported_kinds = ["slack", "salesforce"]
|
||||
supported_kinds = ["slack", "salesforce", "hubspot"]
|
||||
integration: Integration
|
||||
kind: str
|
||||
|
||||
def __init__(self, integration: Integration, kind: str) -> None:
|
||||
def __init__(self, integration: Integration) -> None:
|
||||
self.integration = integration
|
||||
self.kind = kind
|
||||
|
||||
@classmethod
|
||||
@cache_for(timedelta(minutes=5))
|
||||
@ -88,6 +114,7 @@ class OauthIntegration:
|
||||
client_secret=from_settings["SLACK_APP_CLIENT_SECRET"],
|
||||
scope="channels:read,groups:read,chat:write,chat:write.customize",
|
||||
id_path="team.id",
|
||||
name_path="team.name",
|
||||
)
|
||||
elif kind == "salesforce":
|
||||
if not settings.SALESFORCE_CONSUMER_KEY or not settings.SALESFORCE_CONSUMER_SECRET:
|
||||
@ -100,6 +127,22 @@ class OauthIntegration:
|
||||
client_secret=settings.SALESFORCE_CONSUMER_SECRET,
|
||||
scope="full",
|
||||
id_path="instance_url",
|
||||
name_path="instance_url",
|
||||
)
|
||||
elif kind == "hubspot":
|
||||
if not settings.HUBSPOT_APP_CLIENT_ID or not settings.HUBSPOT_APP_CLIENT_SECRET:
|
||||
raise NotImplementedError("Hubspot app not configured")
|
||||
|
||||
return OauthConfig(
|
||||
authorize_url="https://app.hubspot.com/oauth/authorize",
|
||||
token_url="https://api.hubapi.com/oauth/v1/token",
|
||||
token_info_url="https://api.hubapi.com/oauth/v1/access-tokens/:access_token",
|
||||
token_info_config_fields=["hub_id", "hub_domain", "user", "user_id"],
|
||||
client_id=settings.HUBSPOT_APP_CLIENT_ID,
|
||||
client_secret=settings.HUBSPOT_APP_CLIENT_SECRET,
|
||||
scope="tickets crm.objects.contacts.write sales-email-read crm.objects.companies.read crm.objects.deals.read crm.objects.contacts.read crm.objects.quotes.read",
|
||||
id_path="hub_id",
|
||||
name_path="hub_domain",
|
||||
)
|
||||
|
||||
raise NotImplementedError(f"Oauth config for kind {kind} not implemented")
|
||||
@ -125,7 +168,7 @@ class OauthIntegration:
|
||||
|
||||
@classmethod
|
||||
def integration_from_oauth_response(
|
||||
cls, kind: str, team_id: str, created_by: User, params: dict[str, str]
|
||||
cls, kind: str, team_id: int, created_by: User, params: dict[str, str]
|
||||
) -> Integration:
|
||||
oauth_config = cls.oauth_config_for_kind(kind)
|
||||
|
||||
@ -145,10 +188,23 @@ class OauthIntegration:
|
||||
if res.status_code != 200 or not config.get("access_token"):
|
||||
raise Exception("Oauth error")
|
||||
|
||||
integration_id: Any = config
|
||||
if oauth_config.token_info_url:
|
||||
# If token info url is given we call it and check the integration id from there
|
||||
token_info_res = requests.get(
|
||||
oauth_config.token_info_url.replace(":access_token", config["access_token"]),
|
||||
headers={"Authorization": f"Bearer {config['access_token']}"},
|
||||
)
|
||||
|
||||
for key in oauth_config.id_path.split("."):
|
||||
integration_id = integration_id.get(key)
|
||||
if token_info_res.status_code == 200:
|
||||
data = token_info_res.json()
|
||||
if oauth_config.token_info_config_fields:
|
||||
for field in oauth_config.token_info_config_fields:
|
||||
config[field] = dot_get(data, field)
|
||||
|
||||
integration_id = dot_get(config, oauth_config.id_path)
|
||||
|
||||
if isinstance(integration_id, int):
|
||||
integration_id = str(integration_id)
|
||||
|
||||
if not isinstance(integration_id, str):
|
||||
raise Exception("Oauth error")
|
||||
@ -161,6 +217,8 @@ class OauthIntegration:
|
||||
"id_token": config.pop("id_token", None),
|
||||
}
|
||||
|
||||
config["refreshed_at"] = int(time.time())
|
||||
|
||||
integration, created = Integration.objects.update_or_create(
|
||||
team_id=team_id,
|
||||
kind=kind,
|
||||
@ -172,8 +230,56 @@ class OauthIntegration:
|
||||
},
|
||||
)
|
||||
|
||||
if integration.errors:
|
||||
integration.errors = ""
|
||||
integration.save()
|
||||
|
||||
return integration
|
||||
|
||||
def access_token_expired(self, time_threshold: Optional[timedelta] = None) -> bool:
|
||||
# Not all integrations have refresh tokens or expiries, so we just return False if we can't check
|
||||
|
||||
refresh_token = self.integration.sensitive_config.get("refresh_token")
|
||||
expires_in = self.integration.config.get("expires_in")
|
||||
refreshed_at = self.integration.config.get("refreshed_at")
|
||||
if not refresh_token or not expires_in or not refreshed_at:
|
||||
return False
|
||||
|
||||
# To be really safe we refresh if its half way through the expiry
|
||||
time_threshold = time_threshold or timedelta(seconds=expires_in / 2)
|
||||
|
||||
return time.time() > refreshed_at + expires_in - time_threshold.total_seconds()
|
||||
|
||||
def refresh_access_token(self):
|
||||
"""
|
||||
Refresh the access token for the integration if necessary
|
||||
"""
|
||||
|
||||
oauth_config = self.oauth_config_for_kind(self.integration.kind)
|
||||
|
||||
res = requests.post(
|
||||
oauth_config.token_url,
|
||||
data={
|
||||
"client_id": oauth_config.client_id,
|
||||
"client_secret": oauth_config.client_secret,
|
||||
"refresh_token": self.integration.sensitive_config["refresh_token"],
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
)
|
||||
|
||||
config: dict = res.json()
|
||||
|
||||
if res.status_code != 200 or not config.get("access_token"):
|
||||
logger.warning(f"Failed to refresh token for {self}", response=res.text)
|
||||
self.integration.errors = ERROR_TOKEN_REFRESH_FAILED
|
||||
else:
|
||||
logger.info(f"Refreshed access token for {self}")
|
||||
self.integration.sensitive_config["access_token"] = config["access_token"]
|
||||
self.integration.config["expires_in"] = config.get("expires_in")
|
||||
self.integration.config["refreshed_at"] = int(time.time())
|
||||
reload_integrations_on_workers(self.integration.team_id, [self.integration.id])
|
||||
self.integration.save()
|
||||
|
||||
|
||||
class SlackIntegrationError(Exception):
|
||||
pass
|
||||
|
@ -1,5 +1,12 @@
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
from freezegun import freeze_time
|
||||
import pytest
|
||||
from posthog.models.instance_setting import set_instance_setting
|
||||
from posthog.models.integration import SlackIntegration
|
||||
from posthog.models.integration import Integration, OauthIntegration, SlackIntegration
|
||||
from posthog.test.base import BaseTest
|
||||
|
||||
|
||||
@ -20,3 +27,207 @@ class TestIntegrationModel(BaseTest):
|
||||
"SLACK_APP_CLIENT_SECRET": "client-secret",
|
||||
"SLACK_APP_SIGNING_SECRET": "not-so-secret",
|
||||
}
|
||||
|
||||
|
||||
class TestOauthIntegrationModel(BaseTest):
|
||||
mock_settings = {
|
||||
"SALESFORCE_CONSUMER_KEY": "salesforce-client-id",
|
||||
"SALESFORCE_CONSUMER_SECRET": "salesforce-client-secret",
|
||||
"HUBSPOT_APP_CLIENT_ID": "hubspot-client-id",
|
||||
"HUBSPOT_APP_CLIENT_SECRET": "hubspot-client-secret",
|
||||
}
|
||||
|
||||
def create_integration(
|
||||
self, kind: str, config: Optional[dict] = None, sensitive_config: Optional[dict] = None
|
||||
) -> Integration:
|
||||
_config = {"refreshed_at": int(time.time()), "expires_in": 3600}
|
||||
_sensitive_config = {"refresh_token": "REFRESH"}
|
||||
_config.update(config or {})
|
||||
_sensitive_config.update(sensitive_config or {})
|
||||
|
||||
return Integration.objects.create(team=self.team, kind=kind, config=_config, sensitive_config=_sensitive_config)
|
||||
|
||||
def test_authorize_url_raises_if_not_configured(self):
|
||||
with pytest.raises(NotImplementedError):
|
||||
OauthIntegration.authorize_url("salesforce", next="/projects/test")
|
||||
|
||||
def test_authorize_url(self):
|
||||
with self.settings(**self.mock_settings):
|
||||
url = OauthIntegration.authorize_url("salesforce", next="/projects/test")
|
||||
assert (
|
||||
url
|
||||
== "https://login.salesforce.com/services/oauth2/authorize?client_id=salesforce-client-id&scope=full&redirect_uri=https%3A%2F%2Flocalhost%3A8000%2Fintegrations%2Fsalesforce%2Fcallback&response_type=code&state=next%3D%252Fprojects%252Ftest"
|
||||
)
|
||||
|
||||
@patch("posthog.models.integration.requests.post")
|
||||
def test_integration_from_oauth_response(self, mock_post):
|
||||
with self.settings(**self.mock_settings):
|
||||
mock_post.return_value.status_code = 200
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "FAKES_ACCESS_TOKEN",
|
||||
"refresh_token": "FAKE_REFRESH_TOKEN",
|
||||
"instance_url": "https://fake.salesforce.com",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
with freeze_time("2024-01-01T12:00:00Z"):
|
||||
integration = OauthIntegration.integration_from_oauth_response(
|
||||
"salesforce",
|
||||
self.team.id,
|
||||
self.user,
|
||||
{
|
||||
"code": "code",
|
||||
"state": "next=/projects/test",
|
||||
},
|
||||
)
|
||||
|
||||
assert integration.team == self.team
|
||||
assert integration.created_by == self.user
|
||||
|
||||
assert integration.config == {
|
||||
"instance_url": "https://fake.salesforce.com",
|
||||
"refreshed_at": 1704110400,
|
||||
"expires_in": 3600,
|
||||
}
|
||||
assert integration.sensitive_config == {
|
||||
"access_token": "FAKES_ACCESS_TOKEN",
|
||||
"refresh_token": "FAKE_REFRESH_TOKEN",
|
||||
"id_token": None,
|
||||
}
|
||||
|
||||
@patch("posthog.models.integration.requests.post")
|
||||
def test_integration_errors_if_id_cannot_be_generated(self, mock_post):
|
||||
with self.settings(**self.mock_settings):
|
||||
mock_post.return_value.status_code = 200
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "FAKES_ACCESS_TOKEN",
|
||||
"refresh_token": "FAKE_REFRESH_TOKEN",
|
||||
"not_instance_url": "https://fake.salesforce.com",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
OauthIntegration.integration_from_oauth_response(
|
||||
"salesforce",
|
||||
self.team.id,
|
||||
self.user,
|
||||
{
|
||||
"code": "code",
|
||||
"state": "next=/projects/test",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("posthog.models.integration.requests.post")
|
||||
@patch("posthog.models.integration.requests.get")
|
||||
def test_integration_fetches_info_from_token_info_url(self, mock_get, mock_post):
|
||||
with self.settings(**self.mock_settings):
|
||||
mock_post.return_value.status_code = 200
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "FAKES_ACCESS_TOKEN",
|
||||
"refresh_token": "FAKE_REFRESH_TOKEN",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
mock_get.return_value.status_code = 200
|
||||
mock_get.return_value.json.return_value = {
|
||||
"hub_id": "hub_id",
|
||||
"hub_domain": "hub_domain",
|
||||
"user": "user",
|
||||
"user_id": "user_id",
|
||||
"should_not": "be_saved",
|
||||
}
|
||||
|
||||
with freeze_time("2024-01-01T12:00:00Z"):
|
||||
integration = OauthIntegration.integration_from_oauth_response(
|
||||
"hubspot",
|
||||
self.team.id,
|
||||
self.user,
|
||||
{
|
||||
"code": "code",
|
||||
"state": "next=/projects/test",
|
||||
},
|
||||
)
|
||||
|
||||
assert integration.config == {
|
||||
"expires_in": 3600,
|
||||
"hub_id": "hub_id",
|
||||
"hub_domain": "hub_domain",
|
||||
"user": "user",
|
||||
"user_id": "user_id",
|
||||
"refreshed_at": 1704110400,
|
||||
}
|
||||
assert integration.sensitive_config == {
|
||||
"access_token": "FAKES_ACCESS_TOKEN",
|
||||
"refresh_token": "FAKE_REFRESH_TOKEN",
|
||||
"id_token": None,
|
||||
}
|
||||
|
||||
def test_integration_access_token_expired(self):
|
||||
now = datetime.now()
|
||||
with freeze_time(now):
|
||||
integration = self.create_integration(kind="hubspot", config={"expires_in": 1000})
|
||||
|
||||
with freeze_time(now):
|
||||
# Access token is not expired
|
||||
assert not OauthIntegration(integration).access_token_expired()
|
||||
|
||||
with freeze_time(now + timedelta(seconds=1000) - timedelta(seconds=501)):
|
||||
# After the expiry but before the threshold it is not expired
|
||||
assert not OauthIntegration(integration).access_token_expired()
|
||||
|
||||
with freeze_time(now + timedelta(seconds=1000) - timedelta(seconds=499)):
|
||||
# After the threshold it is expired
|
||||
assert OauthIntegration(integration).access_token_expired()
|
||||
|
||||
with freeze_time(now + timedelta(seconds=1000)):
|
||||
# After the threshold it is expired
|
||||
assert OauthIntegration(integration).access_token_expired()
|
||||
|
||||
@patch("posthog.models.integration.reload_integrations_on_workers")
|
||||
@patch("posthog.models.integration.requests.post")
|
||||
def test_refresh_access_token(self, mock_post, mock_reload):
|
||||
mock_post.return_value.status_code = 200
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "REFRESHED_ACCESS_TOKEN",
|
||||
"expires_in": 1000,
|
||||
}
|
||||
|
||||
integration = self.create_integration(kind="hubspot", config={"expires_in": 1000})
|
||||
|
||||
with freeze_time("2024-01-01T14:00:00Z"):
|
||||
with self.settings(**self.mock_settings):
|
||||
OauthIntegration(integration).refresh_access_token()
|
||||
|
||||
mock_post.assert_called_with(
|
||||
"https://api.hubapi.com/oauth/v1/token",
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": "hubspot-client-id",
|
||||
"client_secret": "hubspot-client-secret",
|
||||
"refresh_token": "REFRESH",
|
||||
},
|
||||
)
|
||||
|
||||
assert integration.config["expires_in"] == 1000
|
||||
assert integration.config["refreshed_at"] == 1704117600
|
||||
assert integration.sensitive_config["access_token"] == "REFRESHED_ACCESS_TOKEN"
|
||||
|
||||
mock_reload.assert_called_once_with(self.team.id, [integration.id])
|
||||
|
||||
@patch("posthog.models.integration.reload_integrations_on_workers")
|
||||
@patch("posthog.models.integration.requests.post")
|
||||
def test_refresh_access_token_handles_errors(self, mock_post, mock_reload):
|
||||
mock_post.return_value.status_code = 401
|
||||
mock_post.return_value.json.return_value = {"error": "BROKEN"}
|
||||
|
||||
integration = self.create_integration(kind="hubspot", config={"expires_in": 1000, "refreshed_at": 1700000000})
|
||||
|
||||
with freeze_time("2024-01-01T14:00:00Z"):
|
||||
with self.settings(**self.mock_settings):
|
||||
OauthIntegration(integration).refresh_access_token()
|
||||
|
||||
assert integration.config["expires_in"] == 1000
|
||||
assert integration.config["refreshed_at"] == 1700000000
|
||||
assert integration.errors == "TOKEN_REFRESH_FAILED"
|
||||
|
||||
mock_reload.assert_not_called()
|
||||
|
@ -37,6 +37,11 @@ def reload_hog_functions_on_workers(team_id: int, hog_function_ids: list[str]):
|
||||
publish_message("reload-hog-functions", {"teamId": team_id, "hogFunctionIds": hog_function_ids})
|
||||
|
||||
|
||||
def reload_integrations_on_workers(team_id: int, integration_ids: list[int]):
|
||||
logger.info(f"Reloading integrations {integration_ids} on workers")
|
||||
publish_message("reload-integrations", {"teamId": team_id, "integrationIds": integration_ids})
|
||||
|
||||
|
||||
def reset_available_product_features_cache_on_workers(organization_id: str):
|
||||
logger.info(f"Resetting available product features cache for organization {organization_id} on workers")
|
||||
publish_message(
|
||||
|
@ -9,6 +9,7 @@ from . import (
|
||||
email,
|
||||
exporter,
|
||||
hog_functions,
|
||||
integrations,
|
||||
process_scheduled_changes,
|
||||
split_person,
|
||||
sync_all_organization_available_product_features,
|
||||
@ -28,6 +29,7 @@ __all__ = [
|
||||
"email",
|
||||
"exporter",
|
||||
"hog_functions",
|
||||
"integrations",
|
||||
"process_scheduled_changes",
|
||||
"split_person",
|
||||
"sync_all_organization_available_product_features",
|
||||
|
31
posthog/tasks/integrations.py
Normal file
31
posthog/tasks/integrations.py
Normal file
@ -0,0 +1,31 @@
|
||||
from celery import shared_task
|
||||
|
||||
from posthog.tasks.utils import CeleryQueue
|
||||
|
||||
|
||||
@shared_task(ignore_result=True, queue=CeleryQueue.INTEGRATIONS.value)
|
||||
def refresh_integrations() -> int:
|
||||
from posthog.models.integration import Integration, OauthIntegration
|
||||
|
||||
oauth_integrations = Integration.objects.filter(kind__in=OauthIntegration.supported_kinds).all()
|
||||
|
||||
for integration in oauth_integrations:
|
||||
oauth_integration = OauthIntegration(integration)
|
||||
|
||||
if oauth_integration.access_token_expired():
|
||||
refresh_integration.delay(integration.id)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
@shared_task(ignore_result=True, queue=CeleryQueue.INTEGRATIONS.value)
|
||||
def refresh_integration(id: int) -> int:
|
||||
from posthog.models.integration import Integration, OauthIntegration
|
||||
|
||||
integration = Integration.objects.get(id=id)
|
||||
|
||||
if integration.kind in OauthIntegration.supported_kinds:
|
||||
oauth_integration = OauthIntegration(integration)
|
||||
oauth_integration.refresh_access_token()
|
||||
|
||||
return 0
|
@ -8,6 +8,7 @@ from django.conf import settings
|
||||
|
||||
from posthog.caching.warming import schedule_warming_for_teams_task
|
||||
from posthog.celery import app
|
||||
from posthog.tasks.integrations import refresh_integrations
|
||||
from posthog.tasks.tasks import (
|
||||
calculate_cohort,
|
||||
calculate_decide_usage,
|
||||
@ -330,3 +331,11 @@ def setup_periodic_tasks(sender: Celery, **kwargs: Any) -> None:
|
||||
calculate_external_data_rows_synced.s(),
|
||||
name="calculate external data rows synced",
|
||||
)
|
||||
|
||||
# Check integrations to refresh every minute
|
||||
add_periodic_task_with_expiry(
|
||||
sender,
|
||||
60,
|
||||
refresh_integrations.s(),
|
||||
name="refresh integrations",
|
||||
)
|
||||
|
38
posthog/tasks/test/test_integrations.py
Normal file
38
posthog/tasks/test/test_integrations.py
Normal file
@ -0,0 +1,38 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
from posthog.models.integration import Integration
|
||||
from posthog.tasks.integrations import refresh_integrations
|
||||
from posthog.test.base import APIBaseTest
|
||||
|
||||
|
||||
class TestIntegrationsTasks(APIBaseTest):
|
||||
integrations: list[Integration] = []
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
def create_integration(
|
||||
self, kind: str, config: Optional[dict] = None, sensitive_config: Optional[dict] = None
|
||||
) -> Integration:
|
||||
_config = {"refreshed_at": int(time.time()), "expires_in": 3600}
|
||||
_sensitive_config = {"refresh_token": "REFRESH"}
|
||||
_config.update(config or {})
|
||||
_sensitive_config.update(sensitive_config or {})
|
||||
|
||||
return Integration.objects.create(team=self.team, kind=kind, config=_config, sensitive_config=_sensitive_config)
|
||||
|
||||
def test_refresh_integrations_schedules_refreshes_for_expired(self) -> None:
|
||||
_integration_1 = self.create_integration("other") # not an oauth one
|
||||
_integration_2 = self.create_integration("slack") # not expired
|
||||
integration_3 = self.create_integration("slack", config={"refreshed_at": time.time() - 3600}) # expired
|
||||
integration_4 = self.create_integration(
|
||||
"slack", config={"refreshed_at": time.time() - 3600 + 170}
|
||||
) # expired with buffer
|
||||
|
||||
with patch("posthog.tasks.integrations.refresh_integration.delay") as refresh_integration_mock:
|
||||
refresh_integrations()
|
||||
# Both 3 and 4 should be refreshed
|
||||
assert refresh_integration_mock.call_args_list == [((integration_3.id,),), ((integration_4.id,),)]
|
@ -38,3 +38,4 @@ class CeleryQueue(Enum):
|
||||
SESSION_REPLAY_EMBEDDINGS = "session_replay_embeddings"
|
||||
SESSION_REPLAY_PERSISTENCE = "session_replay_persistence"
|
||||
SESSION_REPLAY_GENERAL = "session_replay_general"
|
||||
INTEGRATIONS = "integrations"
|
||||
|
Loading…
Reference in New Issue
Block a user