0
0
mirror of https://github.com/PostHog/posthog.git synced 2024-11-21 21:49:51 +01:00

chore(data-warehouse): refactor salesforce integration oauth (#24378)

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Eric Duong 2024-08-15 10:15:38 -04:00 committed by GitHub
parent ce89f5969d
commit 2d9956cbcd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 234 additions and 157 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 38 KiB

After

Width:  |  Height:  |  Size: 38 KiB

View File

@ -0,0 +1,23 @@
import {
IntegrationChoice,
IntegrationConfigureProps,
} from 'scenes/pipeline/hogfunctions/integrations/IntegrationChoice'
import { SourceConfig } from '~/types'
export type DataWarehouseIntegrationChoice = IntegrationConfigureProps & {
sourceConfig: SourceConfig
}
export function DataWarehouseIntegrationChoice({
sourceConfig,
...props
}: DataWarehouseIntegrationChoice): JSX.Element {
return (
<IntegrationChoice
{...props}
integration={sourceConfig.name.toLowerCase()}
redirectUrl={`/data-warehouse/new?kind=${sourceConfig.name.toLowerCase()}`}
/>
)
}

View File

@ -6,6 +6,7 @@ import { LemonField } from 'lib/lemon-ui/LemonField'
import { SourceConfig, SourceFieldConfig } from '~/types'
import { SOURCE_DETAILS, sourceWizardLogic } from '../../new/sourceWizardLogic'
import { DataWarehouseIntegrationChoice } from './DataWarehouseIntegrationChoice'
interface SourceFormProps {
sourceConfig: SourceConfig
@ -13,14 +14,18 @@ interface SourceFormProps {
showSourceFields?: boolean
}
const sourceFieldToElement = (field: SourceFieldConfig): JSX.Element => {
const sourceFieldToElement = (field: SourceFieldConfig, sourceConfig: SourceConfig): JSX.Element => {
if (field.type === 'switch-group') {
return (
<LemonField key={field.name} name={[field.name, 'enabled']} label={field.label}>
{({ value, onChange }) => (
<>
<LemonSwitch checked={value} onChange={onChange} />
{value && <Group name={field.name}>{field.fields.map(sourceFieldToElement)}</Group>}
{value && (
<Group name={field.name}>
{field.fields.map((field) => sourceFieldToElement(field, sourceConfig))}
</Group>
)}
</>
)}
</LemonField>
@ -42,7 +47,7 @@ const sourceFieldToElement = (field: SourceFieldConfig): JSX.Element => {
<Group name={field.name}>
{field.options
.find((n) => n.value === (value ?? field.defaultValue))
?.fields?.map(sourceFieldToElement)}
?.fields?.map((field) => sourceFieldToElement(field, sourceConfig))}
</Group>
</>
)}
@ -63,6 +68,21 @@ const sourceFieldToElement = (field: SourceFieldConfig): JSX.Element => {
)
}
if (field.type === 'oauth') {
return (
<LemonField key={field.name} name={field.name} label={field.label}>
{({ value, onChange }) => (
<DataWarehouseIntegrationChoice
key={field.name}
sourceConfig={sourceConfig}
value={value}
onChange={onChange}
/>
)}
</LemonField>
)
}
return (
<LemonField key={field.name} name={field.name} label={field.label}>
<LemonInput
@ -88,7 +108,7 @@ export default function SourceForm({ sourceConfig }: SourceFormProps): JSX.Eleme
<Form logic={sourceWizardLogic} formKey="sourceConnectionDetails" className="space-y-4" enableFormOnSubmit>
{showSourceFields && (
<Group name="payload">
{SOURCE_DETAILS[sourceConfig.name].fields.map((field) => sourceFieldToElement(field))}
{SOURCE_DETAILS[sourceConfig.name].fields.map((field) => sourceFieldToElement(field, sourceConfig))}
</Group>
)}
{showPrefix && (

View File

@ -41,7 +41,6 @@ const Caption = (): JSX.Element => (
)
export const getHubspotRedirectUri = (): string => `${window.location.origin}/data-warehouse/hubspot/redirect`
export const getSalesforceRedirectUri = (): string => `${window.location.origin}/data-warehouse/salesforce/redirect`
export const SOURCE_DETAILS: Record<ExternalDataSourceType, SourceConfig> = {
Stripe: {
@ -424,6 +423,18 @@ export const SOURCE_DETAILS: Record<ExternalDataSourceType, SourceConfig> = {
},
],
},
Salesforce: {
name: 'Salesforce',
fields: [
{
name: 'integration_id',
label: 'Salesforce account',
type: 'oauth',
required: true,
},
],
caption: 'Select an existing Salesforce account to link to PostHog or create a new connection',
},
}
export const buildKeaFormDefaultFromSourceDetails = (
@ -750,27 +761,6 @@ export const sourceWizardLogic = kea<sourceWizardLogicType>([
}
},
],
addToSalesforceButtonUrl: [
(s) => [s.preflight],
(preflight) => {
return (subdomain: string) => {
const clientId = preflight?.data_warehouse_integrations?.salesforce.client_id
if (!clientId) {
return null
}
const params = new URLSearchParams()
params.set('client_id', clientId)
params.set('redirect_uri', `${window.location.origin}/data-warehouse/salesforce/redirect`)
params.set('response_type', 'code')
params.set('scope', 'refresh_token api')
params.set('state', subdomain)
return `https://${subdomain}.my.salesforce.com/services/oauth2/authorize?${params.toString()}`
}
},
],
modalTitle: [
(s) => [s.currentStep],
(currentStep) => {
@ -908,6 +898,12 @@ export const sourceWizardLogic = kea<sourceWizardLogicType>([
})
return
}
case 'salesforce': {
actions.updateSource({
source_type: 'Salesforce',
})
break
}
default:
lemonToast.error(`Something went wrong.`)
}
@ -951,8 +947,6 @@ export const sourceWizardLogic = kea<sourceWizardLogicType>([
if (kind === 'salesforce') {
router.actions.push(urls.dataWarehouseTable(), {
kind,
code: searchParams.code,
subdomain: searchParams.state,
})
}
},
@ -964,21 +958,17 @@ export const sourceWizardLogic = kea<sourceWizardLogicType>([
})
actions.setStep(2)
}
if (searchParams.kind == 'salesforce') {
actions.selectConnector(SOURCE_DETAILS['Salesforce'])
actions.handleRedirect(searchParams.kind, {})
actions.setStep(2)
}
},
})),
forms(({ actions, values }) => ({
sourceConnectionDetails: {
defaults: buildKeaFormDefaultFromSourceDetails(SOURCE_DETAILS),
errors: (sourceValues) => {
if (
values.selectedConnector &&
SOURCE_DETAILS[values.selectedConnector?.name].oauthPayload &&
SOURCE_DETAILS[values.selectedConnector.name].oauthPayload?.every(
(element) => values.source.payload[element]
)
) {
return {}
}
return getErrorsForFields(values.selectedConnector?.fields ?? [], sourceValues as any)
},
submit: async (sourceValues) => {

View File

@ -1,99 +1,17 @@
import { IconExternal, IconX } from '@posthog/icons'
import { LemonButton, LemonMenu, LemonSkeleton } from '@posthog/lemon-ui'
import { useValues } from 'kea'
import api from 'lib/api'
import { integrationsLogic } from 'lib/integrations/integrationsLogic'
import { IntegrationView } from 'lib/integrations/IntegrationView'
import { capitalizeFirstLetter } from 'lib/utils'
import { urls } from 'scenes/urls'
import { HogFunctionInputSchemaType } from '~/types'
type HogFunctionInputIntegrationConfigureProps = {
value?: number
onChange?: (value: number | null) => void
}
import { IntegrationChoice, IntegrationConfigureProps } from './IntegrationChoice'
export type HogFunctionInputIntegrationProps = HogFunctionInputIntegrationConfigureProps & {
export type HogFunctionInputIntegrationProps = IntegrationConfigureProps & {
schema: HogFunctionInputSchemaType
}
export function HogFunctionInputIntegration({ schema, ...props }: HogFunctionInputIntegrationProps): JSX.Element {
return <HogFunctionIntegrationChoice {...props} schema={schema} />
}
function HogFunctionIntegrationChoice({
onChange,
value,
schema,
}: HogFunctionInputIntegrationProps): JSX.Element | null {
const { integrationsLoading, integrations } = useValues(integrationsLogic)
const kind = schema.integration
const integrationsOfKind = integrations?.filter((x) => x.kind === kind)
const integration = integrationsOfKind?.find((integration) => integration.id === value)
if (!kind) {
return null
}
if (integrationsLoading) {
return <LemonSkeleton className="h-10" />
}
const button = (
<LemonMenu
items={[
integrationsOfKind?.length
? {
items: [
...(integrationsOfKind?.map((integration) => ({
icon: <img src={integration.icon_url} className="w-6 h-6 rounded" />,
onClick: () => onChange?.(integration.id),
active: integration.id === value,
label: integration.display_name,
})) || []),
],
}
: null,
{
items: [
{
to: api.integrations.authorizeUrl({
kind,
next: `${window.location.pathname}?integration_target=${schema.key}`,
}),
disableClientSideRouting: true,
label: integrationsOfKind?.length
? `Connect to a different ${kind} integration`
: `Connect to ${kind}`,
},
],
},
{
items: [
{
to: urls.settings('project-integrations'),
label: 'Manage integrations',
sideIcon: <IconExternal />,
},
value
? {
onClick: () => onChange?.(null),
label: 'Clear',
sideIcon: <IconX />,
}
: null,
],
},
]}
>
{integration ? (
<LemonButton type="secondary">Change</LemonButton>
) : (
<LemonButton type="secondary">Choose {capitalizeFirstLetter(kind)} connection</LemonButton>
)}
</LemonMenu>
return (
<IntegrationChoice
{...props}
integration={schema.integration}
redirectUrl={`${window.location.pathname}?integration_target=${schema.key}`}
/>
)
return <>{integration ? <IntegrationView integration={integration} suffix={button} /> : button}</>
}

View File

@ -0,0 +1,92 @@
import { IconExternal, IconX } from '@posthog/icons'
import { LemonButton, LemonMenu, LemonSkeleton } from '@posthog/lemon-ui'
import { useValues } from 'kea'
import api from 'lib/api'
import { integrationsLogic } from 'lib/integrations/integrationsLogic'
import { IntegrationView } from 'lib/integrations/IntegrationView'
import { capitalizeFirstLetter } from 'lib/utils'
import { urls } from 'scenes/urls'
export type IntegrationConfigureProps = {
value?: number
onChange?: (value: number | null) => void
redirectUrl?: string
integration?: string
}
export function IntegrationChoice({
onChange,
value,
integration,
redirectUrl,
}: IntegrationConfigureProps): JSX.Element | null {
const { integrationsLoading, integrations } = useValues(integrationsLogic)
const kind = integration
const integrationsOfKind = integrations?.filter((x) => x.kind === kind)
const integrationKind = integrationsOfKind?.find((integration) => integration.id === value)
if (!kind) {
return null
}
if (integrationsLoading) {
return <LemonSkeleton className="h-10" />
}
const button = (
<LemonMenu
items={[
integrationsOfKind?.length
? {
items: [
...(integrationsOfKind?.map((integration) => ({
icon: <img src={integration.icon_url} className="w-6 h-6 rounded" />,
onClick: () => onChange?.(integration.id),
active: integration.id === value,
label: integration.display_name,
})) || []),
],
}
: null,
{
items: [
{
to: api.integrations.authorizeUrl({
kind,
next: redirectUrl,
}),
disableClientSideRouting: true,
label: integrationsOfKind?.length
? `Connect to a different ${kind} integration`
: `Connect to ${kind}`,
},
],
},
{
items: [
{
to: urls.settings('project-integrations'),
label: 'Manage integrations',
sideIcon: <IconExternal />,
},
value
? {
onClick: () => onChange?.(null),
label: 'Clear',
sideIcon: <IconX />,
}
: null,
],
},
]}
>
{integrationKind ? (
<LemonButton type="secondary">Change</LemonButton>
) : (
<LemonButton type="secondary">Choose {capitalizeFirstLetter(kind)} connection</LemonButton>
)}
</LemonMenu>
)
return <>{integrationKind ? <IntegrationView integration={integrationKind} suffix={button} /> : button}</>
}

View File

@ -3834,7 +3834,15 @@ export enum DataWarehouseSettingsTab {
SelfManaged = 'self-managed',
}
export const externalDataSources = ['Stripe', 'Hubspot', 'Postgres', 'MySQL', 'Zendesk', 'Snowflake'] as const
export const externalDataSources = [
'Stripe',
'Hubspot',
'Postgres',
'MySQL',
'Zendesk',
'Snowflake',
'Salesforce',
] as const
export type ExternalDataSourceType = (typeof externalDataSources)[number]
@ -4186,6 +4194,13 @@ export enum SidePanelTab {
Exports = 'exports',
}
export interface SourceFieldOauthConfig {
type: 'oauth'
name: string
label: string
required: boolean
}
export interface SourceFieldInputConfig {
type: LemonInputProps['type'] | 'textarea'
name: string
@ -4211,7 +4226,11 @@ export interface SourceFieldSwitchGroupConfig {
fields: SourceFieldConfig[]
}
export type SourceFieldConfig = SourceFieldInputConfig | SourceFieldSwitchGroupConfig | SourceFieldSelectConfig
export type SourceFieldConfig =
| SourceFieldInputConfig
| SourceFieldSwitchGroupConfig
| SourceFieldSelectConfig
| SourceFieldOauthConfig
export interface SourceConfig {
name: ExternalDataSourceType

View File

@ -18,6 +18,7 @@ from posthog.models.user import User
import structlog
from posthog.plugins.plugin_server_api import reload_integrations_on_workers
from posthog.warehouse.util import database_sync_to_async
logger = structlog.get_logger(__name__)
@ -71,6 +72,19 @@ class Integration(models.Model):
return f"ID: {self.integration_id}"
@property
def access_token(self) -> Optional[str]:
return self.sensitive_config.get("access_token")
@property
def refresh_token(self) -> Optional[str]:
return self.sensitive_config.get("refresh_token")
@database_sync_to_async
def aget_integration_by_id(integration_id: str, team_id: int) -> Integration | None:
return Integration.objects.get(id=integration_id, team_id=team_id)
@dataclass
class OauthConfig:
@ -125,7 +139,7 @@ class OauthIntegration:
token_url="https://login.salesforce.com/services/oauth2/token",
client_id=settings.SALESFORCE_CONSUMER_KEY,
client_secret=settings.SALESFORCE_CONSUMER_SECRET,
scope="full",
scope="full refresh_token",
id_path="instance_url",
name_path="instance_url",
)

View File

@ -56,7 +56,7 @@ class TestOauthIntegrationModel(BaseTest):
url = OauthIntegration.authorize_url("salesforce", next="/projects/test")
assert (
url
== "https://login.salesforce.com/services/oauth2/authorize?client_id=salesforce-client-id&scope=full&redirect_uri=https%3A%2F%2Flocalhost%3A8000%2Fintegrations%2Fsalesforce%2Fcallback&response_type=code&state=next%3D%252Fprojects%252Ftest"
== "https://login.salesforce.com/services/oauth2/authorize?client_id=salesforce-client-id&scope=full+refresh_token&redirect_uri=https%3A%2F%2Flocalhost%3A8000%2Fintegrations%2Fsalesforce%2Fcallback&response_type=code&state=next%3D%252Fprojects%252Ftest"
)
@patch("posthog.models.integration.requests.post")

View File

@ -6,7 +6,7 @@ from posthog.temporal.data_imports.pipelines.rest_source.typing import EndpointR
from posthog.temporal.data_imports.pipelines.salesforce.auth import SalseforceAuth
def get_resource(name: str, is_incremental: bool, subdomain: str) -> EndpointResource:
def get_resource(name: str, is_incremental: bool) -> EndpointResource:
resources: dict[str, EndpointResource] = {
"User": {
"name": "User",
@ -153,9 +153,9 @@ def get_resource(name: str, is_incremental: bool, subdomain: str) -> EndpointRes
class SalesforceEndpointPaginator(BasePaginator):
def __init__(self, subdomain):
def __init__(self, instance_url):
super().__init__()
self.subdomain = subdomain
self.instance_url = instance_url
def update_state(self, response: Response) -> None:
res = response.json()
@ -173,12 +173,12 @@ class SalesforceEndpointPaginator(BasePaginator):
self._has_next_page = False
def update_request(self, request: Request) -> None:
request.url = f"https://{self.subdomain}.my.salesforce.com{self._next_page}"
request.url = f"{self.instance_url}{self._next_page}"
@dlt.source(max_table_nesting=0)
def salesforce_source(
subdomain: str,
instance_url: str,
access_token: str,
refresh_token: str,
endpoint: str,
@ -188,14 +188,14 @@ def salesforce_source(
):
config: RESTAPIConfig = {
"client": {
"base_url": f"https://{subdomain}.my.salesforce.com",
"base_url": instance_url,
"auth": SalseforceAuth(refresh_token, access_token),
"paginator": SalesforceEndpointPaginator(subdomain=subdomain),
"paginator": SalesforceEndpointPaginator(instance_url=instance_url),
},
"resource_defaults": {
"primary_key": "id",
},
"resources": [get_resource(endpoint, is_incremental, subdomain)],
"resources": [get_resource(endpoint, is_incremental)],
}
yield from rest_api_resources(config, team_id, job_id)

View File

@ -227,20 +227,30 @@ async def import_data_activity(inputs: ImportDataActivityInputs):
elif model.pipeline.source_type == ExternalDataSource.Type.SALESFORCE:
from posthog.temporal.data_imports.pipelines.salesforce.auth import salesforce_refresh_access_token
from posthog.temporal.data_imports.pipelines.salesforce import salesforce_source
from posthog.models.integration import aget_integration_by_id
subdomain = model.pipeline.job_inputs.get("salesforce_subdomain")
salesforce_access_token = model.pipeline.job_inputs.get("salesforce_access_token", None)
refresh_token = model.pipeline.job_inputs.get("salesforce_refresh_token", None)
if not refresh_token:
salesforce_integration_id = model.pipeline.job_inputs.get("salesforce_integration_id", None)
if not salesforce_integration_id:
raise ValueError(f"Salesforce integration not found for job {model.id}")
integration = await aget_integration_by_id(integration_id=salesforce_integration_id, team_id=inputs.team_id)
salesforce_refresh_token = integration.refresh_token
if not salesforce_refresh_token:
raise ValueError(f"Salesforce refresh token not found for job {model.id}")
salesforce_access_token = integration.access_token
if not salesforce_access_token:
salesforce_access_token = salesforce_refresh_access_token(refresh_token)
salesforce_access_token = salesforce_refresh_access_token(salesforce_refresh_token)
salesforce_instance_url = integration.config.get("instance_url")
source = salesforce_source(
subdomain=subdomain,
instance_url=salesforce_instance_url,
access_token=salesforce_access_token,
refresh_token=refresh_token,
refresh_token=salesforce_refresh_token,
endpoint=schema.name,
team_id=inputs.team_id,
job_id=inputs.run_id,

View File

@ -32,9 +32,6 @@ from posthog.temporal.data_imports.pipelines.schemas import (
from posthog.temporal.data_imports.pipelines.hubspot.auth import (
get_hubspot_access_token_from_code,
)
from posthog.temporal.data_imports.pipelines.salesforce.auth import (
get_salesforce_access_token_from_code,
)
from posthog.warehouse.models.external_data_schema import (
filter_postgres_incremental_fields,
filter_snowflake_incremental_fields,
@ -399,13 +396,9 @@ class ExternalDataSourceViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
def _handle_salesforce_source(self, request: Request, *args: Any, **kwargs: Any) -> ExternalDataSource:
payload = request.data["payload"]
code = payload.get("code")
redirect_uri = payload.get("redirect_uri")
prefix = request.data.get("prefix", None)
source_type = request.data["source_type"]
subdomain = payload.get("subdomain")
access_token, refresh_token = get_salesforce_access_token_from_code(code, redirect_uri=redirect_uri)
integration_id = payload.get("integration_id")
new_source_model = ExternalDataSource.objects.create(
source_id=str(uuid.uuid4()),
@ -415,9 +408,7 @@ class ExternalDataSourceViewSet(TeamAndOrgViewSetMixin, viewsets.ModelViewSet):
status="Running",
source_type=source_type,
job_inputs={
"salesforce_access_token": access_token,
"salesforce_refresh_token": refresh_token,
"salesforce_subdomain": subdomain,
"salesforce_integration_id": integration_id,
},
prefix=prefix,
)