0
0
mirror of https://github.com/wagtail/wagtail.git synced 2024-12-01 11:41:20 +01:00

Merge pull request #1538 from tomchristie/rest-framework

Refactor API to REST framework.
This commit is contained in:
Karl Hobley 2015-08-03 14:33:49 +01:00
commit 94ad915fee
10 changed files with 588 additions and 582 deletions

View File

@ -28,6 +28,7 @@ install_requires = [
"django-modelcluster>=0.6",
"django-taggit>=0.13.0",
"django-treebeard==3.0",
"djangorestframework==3.1.3",
"Pillow>=2.6.1",
"beautifulsoup4>=4.3.2",
"html5lib==0.999",

View File

@ -22,6 +22,7 @@ deps =
django-taggit==0.13.0
django-treebeard==3.0
django-sendfile==0.3.6
djangorestframework==3.1.3
Pillow>=2.3.0
beautifulsoup4>=4.3.2
html5lib==0.999

View File

@ -1,95 +0,0 @@
import json
from functools import wraps
from django.conf.urls import url, include
from django.http import HttpResponse, HttpResponseBadRequest, HttpResponseNotFound, Http404
from django.core.serializers.json import DjangoJSONEncoder
from django.core.urlresolvers import reverse
from taggit.managers import _TaggableManager
from taggit.models import Tag
from wagtail.utils.urlpatterns import decorate_urlpatterns
from wagtail.wagtailcore.blocks import StreamValue
from .endpoints import URLPath, ObjectDetailURL, PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint
from .utils import BadRequestError, get_base_url
def get_full_url(request, path):
base_url = get_base_url(request) or ''
return base_url + path
class API(object):
def __init__(self, endpoints):
self.endpoints = endpoints
def find_model_detail_view(self, model):
for endpoint_name, endpoint in self.endpoints.items():
if endpoint.has_model(model):
return 'wagtailapi_v1:%s:detail' % endpoint_name
def make_response(self, request, data, response_cls=HttpResponse):
api = self
class WagtailAPIJSONEncoder(DjangoJSONEncoder):
def default(self, o):
if isinstance(o, _TaggableManager):
return list(o.all())
elif isinstance(o, Tag):
return o.name
elif isinstance(o, URLPath):
return get_full_url(request, o.path)
elif isinstance(o, ObjectDetailURL):
view = api.find_model_detail_view(o.model)
if view:
return get_full_url(request, reverse(view, args=(o.pk, )))
else:
return None
elif isinstance(o, StreamValue):
return o.stream_block.get_prep_value(o)
else:
return super(WagtailAPIJSONEncoder, self).default(o)
return response_cls(
json.dumps(data, indent=4, cls=WagtailAPIJSONEncoder),
content_type='application/json'
)
def api_view(self, view):
"""
This is a decorator that is applied to all API views.
It is responsible for serialising the responses from the endpoints
and handling errors.
"""
@wraps(view)
def wrapper(request, *args, **kwargs):
# Catch exceptions and format them as JSON documents
try:
return self.make_response(request, view(request, *args, **kwargs))
except Http404 as e:
return self.make_response(request, {
'message': str(e)
}, response_cls=HttpResponseNotFound)
except BadRequestError as e:
return self.make_response(request, {
'message': str(e)
}, response_cls=HttpResponseBadRequest)
return wrapper
def get_urlpatterns(self):
return decorate_urlpatterns([
url(r'^%s/' % name, include(endpoint.get_urlpatterns(), namespace=name))
for name, endpoint in self.endpoints.items()
], self.api_view)
v1 = API({
'pages': PagesAPIEndpoint(),
'images': ImagesAPIEndpoint(),
'documents': DocumentsAPIEndpoint(),
})

View File

@ -1,100 +1,34 @@
from __future__ import absolute_import
from collections import OrderedDict
from modelcluster.models import get_all_child_relations
from taggit.managers import _TaggableManager
from django.db import models
from django.utils.encoding import force_text
from django.shortcuts import get_object_or_404
from django.conf.urls import url
from django.conf import settings
from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
from wagtail.wagtailcore.models import Page
from wagtail.wagtailimages.models import get_image_model
from wagtail.wagtaildocs.models import Document
from wagtail.wagtailcore.utils import resolve_model_string
from wagtail.wagtailsearch.backends import get_search_backend
from wagtail.utils.compat import get_related_model
from .filters import (
FieldsFilter, OrderingFilter, SearchFilter,
ChildOfFilter, DescendantOfFilter
)
from .renderers import WagtailJSONRenderer
from .pagination import WagtailPagination
from .serializers import WagtailSerializer, PageSerializer, DocumentSerializer
from .utils import BadRequestError
class URLPath(object):
"""
This class represents a URL path that should be converted to a full URL.
class BaseAPIEndpoint(GenericViewSet):
renderer_classes = [WagtailJSONRenderer]
pagination_class = WagtailPagination
serializer_class = WagtailSerializer
filter_classes = []
queryset = None # Set on subclasses or implement `get_queryset()`.
It is used when the domain that should be used is not known at the time
the URL was generated. It will get resolved to a full URL during
serialisation in api.py.
One example use case is the documents endpoint adding download URLs into
the JSON. The endpoint does not know the domain name to use at the time so
returns one of these instead.
"""
def __init__(self, path):
self.path = path
class ObjectDetailURL(object):
def __init__(self, model, pk):
self.model = model
self.pk = pk
def get_api_data(obj, fields):
# Find any child relations (pages only)
child_relations = {}
if isinstance(obj, Page):
child_relations = {
child_relation.field.rel.related_name: get_related_model(child_relation)
for child_relation in get_all_child_relations(type(obj))
}
# Loop through fields
for field_name in fields:
# Check child relations
if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'):
yield field_name, [
dict(get_api_data(child_object, child_relations[field_name].api_fields))
for child_object in getattr(obj, field_name).all()
]
continue
# Check django fields
try:
field = obj._meta.get_field(field_name)
if field.rel and isinstance(field.rel, models.ManyToOneRel):
# Foreign key
val = field._get_val_from_obj(obj)
if val:
yield field_name, OrderedDict([
('id', field._get_val_from_obj(obj)),
('meta', OrderedDict([
('type', field.rel.to._meta.app_label + '.' + field.rel.to.__name__),
('detail_url', ObjectDetailURL(field.rel.to, val)),
])),
])
else:
yield field_name, None
else:
yield field_name, field._get_val_from_obj(obj)
continue
except models.fields.FieldDoesNotExist:
pass
# Check attributes
if hasattr(obj, field_name):
value = getattr(obj, field_name)
yield field_name, force_text(value, strings_only=True)
continue
class BaseAPIEndpoint(object):
known_query_parameters = frozenset([
'limit',
'offset',
@ -102,76 +36,48 @@ class BaseAPIEndpoint(object):
'order',
'search',
])
extra_api_fields = []
name = None # Set on subclass.
def listing_view(self, request):
return NotImplemented
queryset = self.get_queryset()
self.check_query_parameters(queryset)
queryset = self.filter_queryset(queryset)
queryset = self.paginate_queryset(queryset)
serializer = self.get_serializer(queryset, many=True)
return self.get_paginated_response(serializer.data)
def detail_view(self, request, pk):
return NotImplemented
instance = self.get_object()
serializer = self.get_serializer(instance)
return Response(serializer.data)
def handle_exception(self, exc):
if isinstance(exc, Http404):
data = {'message': str(exc)}
return Response(data, status=status.HTTP_404_NOT_FOUND)
elif isinstance(exc, BadRequestError):
data = {'message': str(exc)}
return Response(data, status=status.HTTP_400_BAD_REQUEST)
return super(BaseAPIEndpoint, self).handle_exception(exc)
def get_api_fields(self, model):
"""
This returns a list of field names that are allowed to
be used in the API (excluding the id field).
"""
api_fields = []
api_fields = self.extra_api_fields[:]
if hasattr(model, 'api_fields'):
api_fields.extend(model.api_fields)
return api_fields
def serialize_object_metadata(self, request, obj, show_details=False):
def check_query_parameters(self, queryset):
"""
This returns a JSON-serialisable dict to use for the "meta"
section of a particlular object.
Ensure that only valid query paramters are included in the URL.
"""
data = OrderedDict()
# Add type
data['type'] = type(obj)._meta.app_label + '.' + type(obj).__name__
data['detail_url'] = ObjectDetailURL(type(obj), obj.pk)
return data
def serialize_object(self, request, obj, fields=frozenset(), extra_data=(), all_fields=False, show_details=False):
"""
This converts an object into JSON-serialisable dict so it can
be used in the API.
"""
data = [
('id', obj.id),
]
# Add meta
metadata = self.serialize_object_metadata(request, obj, show_details=show_details)
if metadata:
data.append(('meta', metadata))
# Add extra data
data.extend(extra_data)
# Add other fields
api_fields = self.get_api_fields(type(obj))
api_fields = list(OrderedDict.fromkeys(api_fields)) # Removes any duplicates in case the user put "title" in api_fields
if all_fields:
fields = api_fields
else:
unknown_fields = fields - set(api_fields)
if unknown_fields:
raise BadRequestError("unknown fields: %s" % ', '.join(sorted(unknown_fields)))
# Reorder fields so it matches the order of api_fields
fields = [field for field in api_fields if field in fields]
data.extend(get_api_data(obj, fields))
return OrderedDict(data)
def check_query_parameters(self, request, queryset):
query_parameters = set(request.GET.keys())
query_parameters = set(self.request.GET.keys())
# All query paramters must be either a field or an operation
allowed_query_parameters = set(self.get_api_fields(queryset.model)).union(self.known_query_parameters).union({'id'})
@ -179,147 +85,87 @@ class BaseAPIEndpoint(object):
if unknown_parameters:
raise BadRequestError("query parameter is not an operation or a recognised field: %s" % ', '.join(sorted(unknown_parameters)))
def do_field_filtering(self, request, queryset):
def get_serializer_context(self):
"""
This performs field level filtering on the result set
Eg: ?title=James Joyce
The serialization context differs between listing and detail views.
"""
fields = set(self.get_api_fields(queryset.model)).union({'id'})
request = self.request
if self.action == 'listing_view':
for field_name, value in request.GET.items():
if field_name in fields:
field = getattr(queryset.model, field_name, None)
if isinstance(field, _TaggableManager):
for tag in value.split(','):
queryset = queryset.filter(**{field_name + '__name': tag})
# Stick a message on the queryset to indicate that tag filtering has been performed
# This will let the do_search method know that it must raise an error as searching
# and tag filtering at the same time is not supported
queryset._filtered_by_tag = True
else:
queryset = queryset.filter(**{field_name: value})
return queryset
def do_ordering(self, request, queryset):
"""
This applies ordering to the result set
Eg: ?order=title
It also supports reverse ordering
Eg: ?order=-title
And random ordering
Eg: ?order=random
"""
if 'order' in request.GET:
# Prevent ordering while searching
if 'search' in request.GET:
raise BadRequestError("ordering with a search query is not supported")
order_by = request.GET['order']
# Random ordering
if order_by == 'random':
# Prevent ordering by random with offset
if 'offset' in request.GET:
raise BadRequestError("random ordering with offset is not supported")
return queryset.order_by('?')
# Check if reverse ordering is set
if order_by.startswith('-'):
reverse_order = True
order_by = order_by[1:]
if 'fields' in request.GET:
fields = set(request.GET['fields'].split(','))
else:
reverse_order = False
fields = {'title'}
# Add ordering
if order_by == 'id' or order_by in self.get_api_fields(queryset.model):
queryset = queryset.order_by(order_by)
else:
# Unknown field
raise BadRequestError("cannot order by '%s' (unknown field)" % order_by)
return {
'request': request,
'view': self,
'fields': fields
}
# Reverse order
if reverse_order:
queryset = queryset.reverse()
return {
'request': request,
'view': self,
'all_fields': True,
'show_details': True
}
return queryset
def get_renderer_context(self):
context = super(BaseAPIEndpoint, self).get_renderer_context()
context['endpoints'] = [
PagesAPIEndpoint,
ImagesAPIEndpoint,
DocumentsAPIEndpoint
]
return context
def do_search(self, request, queryset):
"""
This performs a full-text search on the result set
Eg: ?search=James Joyce
"""
search_enabled = getattr(settings, 'WAGTAILAPI_SEARCH_ENABLED', True)
if 'search' in request.GET:
if not search_enabled:
raise BadRequestError("search is disabled")
# Searching and filtering by tag at the same time is not supported
if getattr(queryset, '_filtered_by_tag', False):
raise BadRequestError("filtering by tag with a search query is not supported")
search_query = request.GET['search']
sb = get_search_backend()
queryset = sb.search(search_query, queryset)
return queryset
def do_pagination(self, request, queryset):
"""
This performs limit/offset based pagination on the result set
Eg: ?limit=10&offset=20 -- Returns 10 items starting at item 20
"""
limit_max = getattr(settings, 'WAGTAILAPI_LIMIT_MAX', 20)
try:
offset = int(request.GET.get('offset', 0))
assert offset >= 0
except (ValueError, AssertionError):
raise BadRequestError("offset must be a positive integer")
try:
limit = int(request.GET.get('limit', min(20, limit_max)))
if limit > limit_max:
raise BadRequestError("limit cannot be higher than %d" % limit_max)
assert limit >= 0
except (ValueError, AssertionError):
raise BadRequestError("limit must be a positive integer")
start = offset
stop = offset + limit
return queryset[start:stop]
def get_urlpatterns(self):
@classmethod
def get_urlpatterns(cls):
"""
This returns a list of URL patterns for the endpoint
"""
return [
url(r'^$', self.listing_view, name='listing'),
url(r'^(\d+)/$', self.detail_view, name='detail'),
url(r'^$', cls.as_view({'get': 'listing_view'}), name='listing'),
url(r'^(?P<pk>\d+)/$', cls.as_view({'get': 'detail_view'}), name='detail'),
]
def has_model(self, model):
return False
@classmethod
def has_model(cls, model):
return NotImplemented
class PagesAPIEndpoint(BaseAPIEndpoint):
serializer_class = PageSerializer
filter_backends = [
FieldsFilter,
ChildOfFilter,
DescendantOfFilter,
OrderingFilter,
SearchFilter
]
known_query_parameters = BaseAPIEndpoint.known_query_parameters.union([
'type',
'child_of',
'descendant_of',
])
extra_api_fields = ['title']
name = 'pages'
def get_queryset(self):
request = self.request
# Allow pages to be filtered to a specific type
if 'type' not in request.GET:
model = Page
else:
model_name = request.GET['type']
try:
model = resolve_model_string(model_name)
except LookupError:
raise BadRequestError("type doesn't exist")
if not issubclass(model, Page):
raise BadRequestError("type doesn't exist")
def get_queryset(self, request, model=Page):
# Get live pages that are not in a private section
queryset = model.objects.public().live()
@ -328,245 +174,33 @@ class PagesAPIEndpoint(BaseAPIEndpoint):
return queryset
def get_api_fields(self, model):
api_fields = ['title']
api_fields.extend(super(PagesAPIEndpoint, self).get_api_fields(model))
return api_fields
def get_object(self):
base = super(PagesAPIEndpoint, self).get_object()
return base.specific
def serialize_object_metadata(self, request, page, show_details=False):
data = super(PagesAPIEndpoint, self).serialize_object_metadata(request, page, show_details=show_details)
# Add type
data['type'] = page.specific_class._meta.app_label + '.' + page.specific_class.__name__
return data
def serialize_object(self, request, page, fields=frozenset(), extra_data=(), all_fields=False, show_details=False):
# Add parent
if show_details:
parent = page.get_parent()
# Make sure the parent is visible in the API
if self.get_queryset(request).filter(id=parent.id).exists():
parent_class = parent.specific_class
extra_data += (
('parent', OrderedDict([
('id', parent.id),
('meta', OrderedDict([
('type', parent_class._meta.app_label + '.' + parent_class.__name__),
('detail_url', ObjectDetailURL(parent_class, parent.id)),
])),
])),
)
return super(PagesAPIEndpoint, self).serialize_object(request, page, fields=fields, extra_data=extra_data, all_fields=all_fields, show_details=show_details)
def get_model(self, request):
if 'type' not in request.GET:
return Page
model_name = request.GET['type']
try:
model = resolve_model_string(model_name)
if not issubclass(model, Page):
raise BadRequestError("type doesn't exist")
return model
except LookupError:
raise BadRequestError("type doesn't exist")
def do_child_of_filter(self, request, queryset):
if 'child_of' in request.GET:
try:
parent_page_id = int(request.GET['child_of'])
assert parent_page_id >= 0
except (ValueError, AssertionError):
raise BadRequestError("child_of must be a positive integer")
try:
parent_page = self.get_queryset(request).get(id=parent_page_id)
queryset = queryset.child_of(parent_page)
queryset._filtered_by_child_of = True
return queryset
except Page.DoesNotExist:
raise BadRequestError("parent page doesn't exist")
return queryset
def do_descendant_of_filter(self, request, queryset):
if 'descendant_of' in request.GET:
if getattr(queryset, '_filtered_by_child_of', False):
raise BadRequestError("filtering by descendant_of with child_of is not supported")
try:
ancestor_page_id = int(request.GET['descendant_of'])
assert ancestor_page_id >= 0
except (ValueError, AssertionError):
raise BadRequestError("descendant_of must be a positive integer")
try:
ancestor_page = self.get_queryset(request).get(id=ancestor_page_id)
return queryset.descendant_of(ancestor_page)
except Page.DoesNotExist:
raise BadRequestError("ancestor page doesn't exist")
return queryset
def listing_view(self, request):
# Get model and queryset
model = self.get_model(request)
queryset = self.get_queryset(request, model=model)
# Check query paramters
self.check_query_parameters(request, queryset)
# Filtering
queryset = self.do_field_filtering(request, queryset)
queryset = self.do_child_of_filter(request, queryset)
queryset = self.do_descendant_of_filter(request, queryset)
# Ordering
queryset = self.do_ordering(request, queryset)
# Search
queryset = self.do_search(request, queryset)
# Pagination
total_count = queryset.count()
queryset = self.do_pagination(request, queryset)
# Get list of fields to show in results
if 'fields' in request.GET:
fields = set(request.GET['fields'].split(','))
else:
fields = {'title'}
return OrderedDict([
('meta', OrderedDict([
('total_count', total_count),
])),
('pages', [
self.serialize_object(request, page, fields=fields)
for page in queryset
]),
])
def detail_view(self, request, pk):
page = get_object_or_404(self.get_queryset(request), pk=pk).specific
return self.serialize_object(request, page, all_fields=True, show_details=True)
def has_model(self, model):
@classmethod
def has_model(cls, model):
return issubclass(model, Page)
class ImagesAPIEndpoint(BaseAPIEndpoint):
model = get_image_model()
queryset = get_image_model().objects.all().order_by('id')
filter_backends = [FieldsFilter, OrderingFilter, SearchFilter]
extra_api_fields = ['title', 'tags', 'width', 'height']
name = 'images'
def get_queryset(self, request):
return self.model.objects.all().order_by('id')
def get_api_fields(self, model):
api_fields = ['title', 'tags', 'width', 'height']
api_fields.extend(super(ImagesAPIEndpoint, self).get_api_fields(model))
return api_fields
def listing_view(self, request):
queryset = self.get_queryset(request)
# Check query paramters
self.check_query_parameters(request, queryset)
# Filtering
queryset = self.do_field_filtering(request, queryset)
# Ordering
queryset = self.do_ordering(request, queryset)
# Search
queryset = self.do_search(request, queryset)
# Pagination
total_count = queryset.count()
queryset = self.do_pagination(request, queryset)
# Get list of fields to show in results
if 'fields' in request.GET:
fields = set(request.GET['fields'].split(','))
else:
fields = {'title'}
return OrderedDict([
('meta', OrderedDict([
('total_count', total_count),
])),
('images', [
self.serialize_object(request, image, fields=fields)
for image in queryset
]),
])
def detail_view(self, request, pk):
image = get_object_or_404(self.get_queryset(request), pk=pk)
return self.serialize_object(request, image, all_fields=True)
def has_model(self, model):
return model == self.model
@classmethod
def has_model(cls, model):
return model == get_image_model()
class DocumentsAPIEndpoint(BaseAPIEndpoint):
def get_api_fields(self, model):
api_fields = ['title', 'tags']
api_fields.extend(super(DocumentsAPIEndpoint, self).get_api_fields(model))
return api_fields
queryset = Document.objects.all().order_by('id')
serializer_class = DocumentSerializer
filter_backends = [FieldsFilter, OrderingFilter, SearchFilter]
extra_api_fields = ['title', 'tags']
name = 'documents'
def serialize_object_metadata(self, request, document, show_details=False):
data = super(DocumentsAPIEndpoint, self).serialize_object_metadata(request, document, show_details=show_details)
# Download URL
if show_details:
data['download_url'] = URLPath(document.url)
return data
def listing_view(self, request):
queryset = Document.objects.all().order_by('id')
# Check query paramters
self.check_query_parameters(request, queryset)
# Filtering
queryset = self.do_field_filtering(request, queryset)
# Ordering
queryset = self.do_ordering(request, queryset)
# Search
queryset = self.do_search(request, queryset)
# Pagination
total_count = queryset.count()
queryset = self.do_pagination(request, queryset)
# Get list of fields to show in results
if 'fields' in request.GET:
fields = set(request.GET['fields'].split(','))
else:
fields = {'title'}
return OrderedDict([
('meta', OrderedDict([
('total_count', total_count),
])),
('documents', [
self.serialize_object(request, document, fields=fields)
for document in queryset
]),
])
def detail_view(self, request, pk):
document = get_object_or_404(Document, pk=pk)
return self.serialize_object(request, document, all_fields=True, show_details=True)
def has_model(self, model):
@classmethod
def has_model(cls, model):
return model == Document

View File

@ -0,0 +1,150 @@
from django.conf import settings
from rest_framework.filters import BaseFilterBackend
from taggit.managers import _TaggableManager
from wagtail.wagtailcore.models import Page
from wagtail.wagtailsearch.backends import get_search_backend
from .utils import BadRequestError, pages_for_site
class FieldsFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
"""
This performs field level filtering on the result set
Eg: ?title=James Joyce
"""
fields = set(view.get_api_fields(queryset.model)).union({'id'})
for field_name, value in request.GET.items():
if field_name in fields:
field = getattr(queryset.model, field_name, None)
if isinstance(field, _TaggableManager):
for tag in value.split(','):
queryset = queryset.filter(**{field_name + '__name': tag})
# Stick a message on the queryset to indicate that tag filtering has been performed
# This will let the do_search method know that it must raise an error as searching
# and tag filtering at the same time is not supported
queryset._filtered_by_tag = True
else:
queryset = queryset.filter(**{field_name: value})
return queryset
class OrderingFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
"""
This applies ordering to the result set
Eg: ?order=title
It also supports reverse ordering
Eg: ?order=-title
And random ordering
Eg: ?order=random
"""
if 'order' in request.GET:
# Prevent ordering while searching
if 'search' in request.GET:
raise BadRequestError("ordering with a search query is not supported")
order_by = request.GET['order']
# Random ordering
if order_by == 'random':
# Prevent ordering by random with offset
if 'offset' in request.GET:
raise BadRequestError("random ordering with offset is not supported")
return queryset.order_by('?')
# Check if reverse ordering is set
if order_by.startswith('-'):
reverse_order = True
order_by = order_by[1:]
else:
reverse_order = False
# Add ordering
if order_by == 'id' or order_by in view.get_api_fields(queryset.model):
queryset = queryset.order_by(order_by)
else:
# Unknown field
raise BadRequestError("cannot order by '%s' (unknown field)" % order_by)
# Reverse order
if reverse_order:
queryset = queryset.reverse()
return queryset
class SearchFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
"""
This performs a full-text search on the result set
Eg: ?search=James Joyce
"""
search_enabled = getattr(settings, 'WAGTAILAPI_SEARCH_ENABLED', True)
if 'search' in request.GET:
if not search_enabled:
raise BadRequestError("search is disabled")
# Searching and filtering by tag at the same time is not supported
if getattr(queryset, '_filtered_by_tag', False):
raise BadRequestError("filtering by tag with a search query is not supported")
search_query = request.GET['search']
sb = get_search_backend()
queryset = sb.search(search_query, queryset)
return queryset
class ChildOfFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
if 'child_of' in request.GET:
try:
parent_page_id = int(request.GET['child_of'])
assert parent_page_id >= 0
except (ValueError, AssertionError):
raise BadRequestError("child_of must be a positive integer")
site_pages = pages_for_site(request.site)
try:
parent_page = site_pages.get(id=parent_page_id)
queryset = queryset.child_of(parent_page)
queryset._filtered_by_child_of = True
return queryset
except Page.DoesNotExist:
raise BadRequestError("parent page doesn't exist")
return queryset
class DescendantOfFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
if 'descendant_of' in request.GET:
if getattr(queryset, '_filtered_by_child_of', False):
raise BadRequestError("filtering by descendant_of with child_of is not supported")
try:
ancestor_page_id = int(request.GET['descendant_of'])
assert ancestor_page_id >= 0
except (ValueError, AssertionError):
raise BadRequestError("descendant_of must be a positive integer")
site_pages = pages_for_site(request.site)
try:
ancestor_page = site_pages.get(id=ancestor_page_id)
return queryset.descendant_of(ancestor_page)
except Page.DoesNotExist:
raise BadRequestError("ancestor page doesn't exist")
return queryset

View File

@ -0,0 +1,45 @@
from collections import OrderedDict
from django.conf import settings
from rest_framework.pagination import BasePagination
from rest_framework.response import Response
from .utils import BadRequestError
class WagtailPagination(BasePagination):
def paginate_queryset(self, queryset, request, view=None):
limit_max = getattr(settings, 'WAGTAILAPI_LIMIT_MAX', 20)
try:
offset = int(request.GET.get('offset', 0))
assert offset >= 0
except (ValueError, AssertionError):
raise BadRequestError("offset must be a positive integer")
try:
limit = int(request.GET.get('limit', min(20, limit_max)))
if limit > limit_max:
raise BadRequestError("limit cannot be higher than %d" % limit_max)
assert limit >= 0
except (ValueError, AssertionError):
raise BadRequestError("limit must be a positive integer")
start = offset
stop = offset + limit
self.view = view
self.total_count = queryset.count()
return queryset[start:stop]
def get_paginated_response(self, data):
data = OrderedDict([
('meta', OrderedDict([
('total_count', self.total_count),
])),
(self.view.name, data),
])
return Response(data)

View File

@ -0,0 +1,61 @@
import json
from django.core.serializers.json import DjangoJSONEncoder
from django.core.urlresolvers import reverse
from django.utils.six import text_type
from rest_framework import renderers
from taggit.managers import _TaggableManager
from taggit.models import Tag
from wagtail.wagtailcore.blocks import StreamValue
from .utils import URLPath, ObjectDetailURL, get_base_url
def get_full_url(request, path):
base_url = get_base_url(request) or ''
return base_url + path
def find_model_detail_view(model, endpoints):
for endpoint in endpoints:
if endpoint.has_model(model):
return 'wagtailapi_v1:%s:detail' % endpoint.name
class WagtailJSONRenderer(renderers.BaseRenderer):
media_type = 'application/json'
charset = None
def render(self, data, media_type=None, renderer_context=None):
request = renderer_context['request']
endpoints = renderer_context['endpoints']
class WagtailAPIJSONEncoder(DjangoJSONEncoder):
def default(self, o):
if isinstance(o, _TaggableManager):
return list(o.all())
elif isinstance(o, Tag):
return o.name
elif isinstance(o, URLPath):
return get_full_url(request, o.path)
elif isinstance(o, ObjectDetailURL):
detail_view = find_model_detail_view(o.model, endpoints)
if detail_view:
return get_full_url(request, reverse(detail_view, args=(o.pk, )))
else:
return None
elif isinstance(o, StreamValue):
return o.stream_block.get_prep_value(o)
else:
return super(WagtailAPIJSONEncoder, self).default(o)
ret = json.dumps(data, indent=4, cls=WagtailAPIJSONEncoder)
# Deal with inconsistent py2/py3 behavior, and always return bytes.
if isinstance(ret, text_type):
return bytes(ret.encode('utf-8'))
return ret

View File

@ -0,0 +1,172 @@
from __future__ import absolute_import
from collections import OrderedDict
from django.db import models
from django.utils.encoding import force_text
from modelcluster.models import get_all_child_relations
from rest_framework.serializers import BaseSerializer
from wagtail.utils.compat import get_related_model
from wagtail.wagtailcore.models import Page
from .utils import ObjectDetailURL, URLPath, BadRequestError, pages_for_site
def get_api_data(obj, fields):
# Find any child relations (pages only)
child_relations = {}
if isinstance(obj, Page):
child_relations = {
child_relation.field.rel.related_name: get_related_model(child_relation)
for child_relation in get_all_child_relations(type(obj))
}
# Loop through fields
for field_name in fields:
# Check child relations
if field_name in child_relations and hasattr(child_relations[field_name], 'api_fields'):
yield field_name, [
dict(get_api_data(child_object, child_relations[field_name].api_fields))
for child_object in getattr(obj, field_name).all()
]
continue
# Check django fields
try:
field = obj._meta.get_field(field_name)
if field.rel and isinstance(field.rel, models.ManyToOneRel):
# Foreign key
val = field._get_val_from_obj(obj)
if val:
yield field_name, OrderedDict([
('id', field._get_val_from_obj(obj)),
('meta', OrderedDict([
('type', field.rel.to._meta.app_label + '.' + field.rel.to.__name__),
('detail_url', ObjectDetailURL(field.rel.to, val)),
])),
])
else:
yield field_name, None
else:
yield field_name, field._get_val_from_obj(obj)
continue
except models.fields.FieldDoesNotExist:
pass
# Check attributes
if hasattr(obj, field_name):
value = getattr(obj, field_name)
yield field_name, force_text(value, strings_only=True)
continue
class WagtailSerializer(BaseSerializer):
def to_representation(self, instance):
request = self.context['request']
fields = self.context.get('fields', frozenset())
all_fields = self.context.get('all_fields', False)
show_details = self.context.get('show_details', False)
return self.serialize_object(
request,
instance,
fields=fields,
all_fields=all_fields,
show_details=show_details
)
def serialize_object_metadata(self, request, obj, show_details=False):
"""
This returns a JSON-serialisable dict to use for the "meta"
section of a particlular object.
"""
data = OrderedDict()
# Add type
data['type'] = type(obj)._meta.app_label + '.' + type(obj).__name__
data['detail_url'] = ObjectDetailURL(type(obj), obj.pk)
return data
def serialize_object(self, request, obj, fields=frozenset(), extra_data=(), all_fields=False, show_details=False):
"""
This converts an object into JSON-serialisable dict so it can
be used in the API.
"""
data = [
('id', obj.id),
]
# Add meta
metadata = self.serialize_object_metadata(request, obj, show_details=show_details)
if metadata:
data.append(('meta', metadata))
# Add extra data
data.extend(extra_data)
# Add other fields
api_fields = self.context['view'].get_api_fields(type(obj))
api_fields = list(OrderedDict.fromkeys(api_fields)) # Removes any duplicates in case the user put "title" in api_fields
if all_fields:
fields = api_fields
else:
unknown_fields = fields - set(api_fields)
if unknown_fields:
raise BadRequestError("unknown fields: %s" % ', '.join(sorted(unknown_fields)))
# Reorder fields so it matches the order of api_fields
fields = [field for field in api_fields if field in fields]
data.extend(get_api_data(obj, fields))
return OrderedDict(data)
class PageSerializer(WagtailSerializer):
def serialize_object_metadata(self, request, page, show_details=False):
data = super(PageSerializer, self).serialize_object_metadata(request, page, show_details=show_details)
# Add type
data['type'] = page.specific_class._meta.app_label + '.' + page.specific_class.__name__
return data
def serialize_object(self, request, page, fields=frozenset(), extra_data=(), all_fields=False, show_details=False):
# Add parent
if show_details:
parent = page.get_parent()
site_pages = pages_for_site(request.site)
if site_pages.filter(id=parent.id).exists():
parent_class = parent.specific_class
extra_data += (
('parent', OrderedDict([
('id', parent.id),
('meta', OrderedDict([
('type', parent_class._meta.app_label + '.' + parent_class.__name__),
('detail_url', ObjectDetailURL(parent_class, parent.id)),
])),
])),
)
return super(PageSerializer, self).serialize_object(request, page, fields=fields, extra_data=extra_data, all_fields=all_fields, show_details=show_details)
class DocumentSerializer(WagtailSerializer):
def serialize_object_metadata(self, request, document, show_details=False):
data = super(DocumentSerializer, self).serialize_object_metadata(request, document, show_details=show_details)
# Download URL
if show_details:
data['download_url'] = URLPath(document.url)
return data

View File

@ -2,9 +2,16 @@ from __future__ import absolute_import
from django.conf.urls import url, include
from . import api
from .endpoints import PagesAPIEndpoint, ImagesAPIEndpoint, DocumentsAPIEndpoint
v1 = [
url(r'^pages/', include(PagesAPIEndpoint.get_urlpatterns(), namespace='pages')),
url(r'^images/', include(ImagesAPIEndpoint.get_urlpatterns(), namespace='images')),
url(r'^documents/', include(DocumentsAPIEndpoint.get_urlpatterns(), namespace='documents'))
]
urlpatterns = [
url(r'^v1/', include(api.v1.get_urlpatterns(), namespace='wagtailapi_v1')),
url(r'^v1/', include(v1, namespace='wagtailapi_v1')),
]

View File

@ -1,11 +1,35 @@
from django.conf import settings
from django.utils.six.moves.urllib.parse import urlparse
from wagtail.wagtailcore.models import Page
class BadRequestError(Exception):
pass
class URLPath(object):
"""
This class represents a URL path that should be converted to a full URL.
It is used when the domain that should be used is not known at the time
the URL was generated. It will get resolved to a full URL during
serialisation in api.py.
One example use case is the documents endpoint adding download URLs into
the JSON. The endpoint does not know the domain name to use at the time so
returns one of these instead.
"""
def __init__(self, path):
self.path = path
class ObjectDetailURL(object):
def __init__(self, model, pk):
self.model = model
self.pk = pk
def get_base_url(request=None):
base_url = getattr(settings, 'WAGTAILAPI_BASE_URL', request.site.root_url if request else None)
@ -14,3 +38,9 @@ def get_base_url(request=None):
base_url_parsed = urlparse(base_url)
return base_url_parsed.scheme + '://' + base_url_parsed.netloc
def pages_for_site(site):
pages = Page.objects.public().live()
pages = pages.descendant_of(site.root_page, inclusive=True)
return pages