diff --git a/csp_advanced/csp.py b/csp_advanced/csp.py index 142d6c8..e2007c2 100644 --- a/csp_advanced/csp.py +++ b/csp_advanced/csp.py @@ -1,130 +1,130 @@ -from itertools import chain - - -class InvalidCSPError(ValueError): - pass - - -class CSPCompiler(object): - CSP_LISTS = { - # Fetch directives: - 'connect-src', - 'child-src', - 'default-src', - 'font-src', - 'frame-src', - 'img-src', - 'manifest-src', - 'media-src', - 'object-src', - 'script-src', - 'style-src', - 'worker-src', - - # Navigation directives: - 'form-action', - 'frame-ancestors', - - # Document directives: - 'base-uri', - 'plugin-types', - } - - CSP_BOOLEAN = { - 'upgrade-insecure-requests', - 'block-all-mixed-content', - } - - CSP_FETCH_SPECIAL = { - 'self', - 'none', - 'unsafe-inline', - 'unsafe-eval', - 'strict-dynamic', - } - - CSP_PREFIX_SPECIAL = ( - 'nonce-', - 'sha256-', - 'sha384-', - 'sha512-' - ) - - CSP_SANDBOX_VALID = { - 'allow-forms', - 'allow-modals', - 'allow-orientation-lock', - 'allow-pointer-lock', - 'allow-popups', - 'allow-popups-to-escape-sandbox', - 'allow-presentation', - 'allow-same-origin', - 'allow-scripts', - 'allow-top-navigation', - } - - CSP_REQUIRE_SRI_VALID = { - 'script', - 'style', - 'script style', - } - - def __init__(self, csp_dict): - self.csp = csp_dict - - def compile(self): - pieces = [] - for name, value in self.csp.iteritems(): - if name in self.CSP_LISTS: - if value: - pieces.append(self.compile_list(name, value)) - elif name in self.CSP_BOOLEAN: - if value: - pieces.append(name) - elif name == 'sandbox': - if value: - pieces.append(self.compile_sandbox(value)) - elif name == 'report-uri': - pieces.append(self.compile_report_uri(value)) - elif name == 'require-sri-for': - pieces.append(self.compile_require_sri_for(value)) - else: - raise InvalidCSPError('Unknown directive: %s' % (name,)) - return '; '.join(pieces) - - def compile_list(self, name, value_list): - self.ensure_list(name, value_list) - values = [name] - for value in value_list: - if value in self.CSP_FETCH_SPECIAL or value.startswith(self.CSP_PREFIX_SPECIAL): - values.append("'%s'" % value) - else: - values.append(value) - return ' '.join(values) - - def compile_sandbox(self, values): - self.ensure_list('sandbox', values) - for value in values: - if value not in self.CSP_SANDBOX_VALID: - raise InvalidCSPError('Unknown sandbox value: %s' % (value,)) - return ' '.join(chain(['sandbox'], values)) - - def compile_report_uri(self, value): - self.ensure_str('report-uri', value) - return 'report-uri %s' % value - - def compile_require_sri_for(self, value): - self.ensure_str('require-sri-for', value) - if value not in self.CSP_REQUIRE_SRI_VALID: - raise InvalidCSPError('Unknown require-sri-for value: %s' % (value,)) - return 'require-sri-for %s' % value - - @staticmethod - def ensure_list(name, value): - if not isinstance(value, (list, tuple, set)): - raise InvalidCSPError('Values for %s must be list-like type, not %s', (name, type(value))) - - @staticmethod - def ensure_str(name, value): - if not isinstance(value, basestring): - raise InvalidCSPError('Values for %s must be a string type, not %s', (name, type(value))) +from itertools import chain + + +class InvalidCSPError(ValueError): + pass + + +class CSPCompiler(object): + CSP_LISTS = { + # Fetch directives: + 'connect-src', + 'child-src', + 'default-src', + 'font-src', + 'frame-src', + 'img-src', + 'manifest-src', + 'media-src', + 'object-src', + 'script-src', + 'style-src', + 'worker-src', + + # Navigation directives: + 'form-action', + 'frame-ancestors', + + # Document directives: + 'base-uri', + 'plugin-types', + } + + CSP_BOOLEAN = { + 'upgrade-insecure-requests', + 'block-all-mixed-content', + } + + CSP_FETCH_SPECIAL = { + 'self', + 'none', + 'unsafe-inline', + 'unsafe-eval', + 'strict-dynamic', + } + + CSP_PREFIX_SPECIAL = ( + 'nonce-', + 'sha256-', + 'sha384-', + 'sha512-' + ) + + CSP_SANDBOX_VALID = { + 'allow-forms', + 'allow-modals', + 'allow-orientation-lock', + 'allow-pointer-lock', + 'allow-popups', + 'allow-popups-to-escape-sandbox', + 'allow-presentation', + 'allow-same-origin', + 'allow-scripts', + 'allow-top-navigation', + } + + CSP_REQUIRE_SRI_VALID = { + 'script', + 'style', + 'script style', + } + + def __init__(self, csp_dict): + self.csp = csp_dict + + def compile(self): + pieces = [] + for name, value in self.csp.iteritems(): + if name in self.CSP_LISTS: + if value: + pieces.append(self.compile_list(name, value)) + elif name in self.CSP_BOOLEAN: + if value: + pieces.append(name) + elif name == 'sandbox': + if value: + pieces.append(self.compile_sandbox(value)) + elif name == 'report-uri': + pieces.append(self.compile_report_uri(value)) + elif name == 'require-sri-for': + pieces.append(self.compile_require_sri_for(value)) + else: + raise InvalidCSPError('Unknown directive: %s' % (name,)) + return '; '.join(pieces) + + def compile_list(self, name, value_list): + self.ensure_list(name, value_list) + values = [name] + for value in value_list: + if value in self.CSP_FETCH_SPECIAL or value.startswith(self.CSP_PREFIX_SPECIAL): + values.append("'%s'" % value) + else: + values.append(value) + return ' '.join(values) + + def compile_sandbox(self, values): + self.ensure_list('sandbox', values) + for value in values: + if value not in self.CSP_SANDBOX_VALID: + raise InvalidCSPError('Unknown sandbox value: %s' % (value,)) + return ' '.join(chain(['sandbox'], values)) + + def compile_report_uri(self, value): + self.ensure_str('report-uri', value) + return 'report-uri %s' % value + + def compile_require_sri_for(self, value): + self.ensure_str('require-sri-for', value) + if value not in self.CSP_REQUIRE_SRI_VALID: + raise InvalidCSPError('Unknown require-sri-for value: %s' % (value,)) + return 'require-sri-for %s' % value + + @staticmethod + def ensure_list(name, value): + if not isinstance(value, (list, tuple, set)): + raise InvalidCSPError('Values for %s must be list-like type, not %s', (name, type(value))) + + @staticmethod + def ensure_str(name, value): + if not isinstance(value, basestring): + raise InvalidCSPError('Values for %s must be a string type, not %s', (name, type(value))) diff --git a/csp_advanced/middleware.py b/csp_advanced/middleware.py index e69de29..76b9870 100644 --- a/csp_advanced/middleware.py +++ b/csp_advanced/middleware.py @@ -0,0 +1,53 @@ +from django.conf import settings +from django.core.exceptions import MiddlewareNotUsed + +from csp_advanced.csp import CSPCompiler +from csp_advanced.utils import is_callable_csp_dict, call_csp_dict, merge_csp_dict + + +class AdvancedCSPMiddleware(object): + def __init__(self, get_response=None): + self.get_response = get_response + self.enforced_csp = getattr(settings, 'ADVANCED_CSP', None) or {} + self.enforced_csp_is_str = isinstance(self.enforced_csp, basestring) + self.enforced_csp_callable = is_callable_csp_dict(self.enforced_csp) + self.report_csp = getattr(settings, 'ADVANCED_CSP_REPORT_ONLY', None) or {} + self.report_csp_callable = is_callable_csp_dict(self.report_csp) + self.report_csp_is_str = isinstance(self.enforced_csp, basestring) + self.report_only_csp = not self.enforced_csp + + if not self.enforced_csp and not self.report_csp: + raise MiddlewareNotUsed() + + def add_csp_header(self, request, response, header, base, can_call, is_str, attrs): + if header in response: + return + if is_str: + response[header] = base + return + csp = call_csp_dict(base, request, response) if can_call else base + + for attr in attrs: + update = getattr(response, attr, None) + if update is not None: + if update.pop('override', False): + csp = update + else: + csp = merge_csp_dict(csp, update) + break + + if csp: + response[header] = CSPCompiler(csp).compile() + + def process_response(self, request, response): + if self.enforced_csp: + self.add_csp_header(request, response, 'Content-Security-Policy', self.enforced_csp, + self.enforced_csp_callable, self.enforced_csp_is_str, ('csp',)) + if self.report_csp: + self.add_csp_header(request, response, 'Content-Security-Policy-Report-Only', + self.report_csp, self.report_csp_callable, self.report_csp_is_str, + ('csp_report',) if self.enforced_csp else ('csp_report', 'csp')) + return response + + def __call__(self, request): + return self.process_response(request, self.get_response(request)) diff --git a/csp_advanced/tests.py b/csp_advanced/tests.py index 9b1f3b3..8f9c3c9 100644 --- a/csp_advanced/tests.py +++ b/csp_advanced/tests.py @@ -1,9 +1,13 @@ from collections import OrderedDict -from django.test import SimpleTestCase +from django.core.exceptions import MiddlewareNotUsed +from django.http import HttpResponse +from django.test import SimpleTestCase, RequestFactory, override_settings +from django.utils.decorators import decorator_from_middleware_with_args from csp import CSPCompiler, InvalidCSPError -from utils import callable_csp_dict, merge_csp_dict +from csp_advanced.middleware import AdvancedCSPMiddleware +from utils import call_csp_dict, merge_csp_dict, is_callable_csp_dict class CSPCompileTest(SimpleTestCase): @@ -88,20 +92,20 @@ class CallableCSPDictTest(SimpleTestCase): return func def test_callable(self): - self.assertEqual(callable_csp_dict( + self.assertEqual(call_csp_dict( self.make_request_taker({'key': 'value'}), self.request, self.response ), {'key': 'value'}) def test_normal_dict(self): - self.assertEqual(callable_csp_dict({'key': 'value'}, None, None), {'key': 'value'}) + self.assertEqual(call_csp_dict({'key': 'value'}, None, None), {'key': 'value'}) def test_callable_entry(self): - self.assertEqual(callable_csp_dict( + self.assertEqual(call_csp_dict( {'key': self.make_request_taker('value')}, self.request, self.response ), {'key': 'value'}) def test_mixed_entry(self): - self.assertEqual(callable_csp_dict({ + self.assertEqual(call_csp_dict({ 'key': self.make_request_taker('value'), 'name': 'mixed', }, self.request, self.response), { @@ -109,6 +113,13 @@ class CallableCSPDictTest(SimpleTestCase): 'name': 'mixed' }) + def test_is_callable(self): + self.assertTrue(is_callable_csp_dict(self.make_request_taker({}))) + self.assertTrue(is_callable_csp_dict({'key': self.make_request_taker('value')})) + self.assertFalse(is_callable_csp_dict({})) + self.assertFalse(is_callable_csp_dict({'key': 'value'})) + self.assertFalse(is_callable_csp_dict(None)) + class MergeCSPDictTest(SimpleTestCase): def test_null(self): @@ -133,3 +144,100 @@ class MergeCSPDictTest(SimpleTestCase): def test_tuple_override(self): self.assertEqual(merge_csp_dict({'spam': (1,)}, {'spam': (2,)}), {'spam': (1, 2)}) + + +class TestMiddleware(SimpleTestCase): + decorator_factory = decorator_from_middleware_with_args(AdvancedCSPMiddleware) + + def setUp(self): + self.factory = RequestFactory() + + def make_ok_view(self): + @self.decorator_factory() + def view(request): + return HttpResponse('ok') + return view + + def get_request(self): + return self.factory.get('/') + + def test_no_csp(self): + self.assertRaises(MiddlewareNotUsed, self.decorator_factory) + + @override_settings(ADVANCED_CSP={'script-src': ['self']}) + def test_setting_csp(self): + self.assertEqual(self.make_ok_view()(self.get_request())['Content-Security-Policy'], "script-src 'self'") + + @override_settings(ADVANCED_CSP_REPORT_ONLY={'default-src': ['http://dmoj.ca']}) + def test_setting_csp_report(self): + self.assertEqual(self.make_ok_view()(self.get_request())['Content-Security-Policy-Report-Only'], + "default-src http://dmoj.ca") + + @override_settings(ADVANCED_CSP={'script-src': ['self']}, + ADVANCED_CSP_REPORT_ONLY={'default-src': ['http://dmoj.ca']}) + def test_setting_both(self): + response = self.make_ok_view()(self.get_request()) + self.assertEqual(response['Content-Security-Policy'], "script-src 'self'") + self.assertEqual(response['Content-Security-Policy-Report-Only'], 'default-src http://dmoj.ca') + + @override_settings(ADVANCED_CSP={'script-src': ['self']}) + def test_merge_csp_same(self): + @self.decorator_factory() + def view(request): + response = HttpResponse() + response.csp = {'script-src': ['https://dmoj.ca']} + return response + self.assertEqual(view(self.get_request())['Content-Security-Policy'], "script-src 'self' https://dmoj.ca") + + @override_settings(ADVANCED_CSP={'script-src': ['self']}) + def test_merge_csp_different(self): + @self.decorator_factory() + def view(request): + response = HttpResponse() + response.csp = {'style-src': ['https://dmoj.ca']} + return response + self.assertEqual(view(self.get_request())['Content-Security-Policy'], + "script-src 'self'; style-src https://dmoj.ca") + + @override_settings(ADVANCED_CSP={'script-src': ['self']}) + def test_override_csp_explicit(self): + @self.decorator_factory() + def view(request): + response = HttpResponse() + response.csp = {'style-src': ['none'], 'override': True} + return response + self.assertEqual(view(self.get_request())['Content-Security-Policy'], "style-src 'none'") + + @override_settings(ADVANCED_CSP_REPORT_ONLY={'script-src': ['self']}) + def test_override_csp_to_report_explicit(self): + @self.decorator_factory() + def view(request): + response = HttpResponse() + response.csp = {'style-src': ['none'], 'override': True} + return response + self.assertEqual(view(self.get_request())['Content-Security-Policy-Report-Only'], "style-src 'none'") + + @override_settings(ADVANCED_CSP_REPORT_ONLY={'script-src': ['self']}) + def test_override_csp_report_both_explicit(self): + @self.decorator_factory() + def view(request): + response = HttpResponse() + response.csp = {'style-src': ['none'], 'override': True} + response.csp_report = {'script-src': ['none'], 'override': True} + return response + + response = view(self.get_request()) + self.assertEqual(response['Content-Security-Policy-Report-Only'], "script-src 'none'") + self.assertTrue('Content-Security-Policy' not in response) + + @override_settings(ADVANCED_CSP_REPORT_ONLY={'script-src': ['self']}) + def test_override_csp_report_only_explicit(self): + @self.decorator_factory() + def view(request): + response = HttpResponse() + response.csp_report = {'script-src': ['none'], 'override': True} + return response + + response = view(self.get_request()) + self.assertEqual(response['Content-Security-Policy-Report-Only'], "script-src 'none'") + self.assertTrue('Content-Security-Policy' not in response) diff --git a/csp_advanced/utils.py b/csp_advanced/utils.py index a4b7a98..ae04f4f 100644 --- a/csp_advanced/utils.py +++ b/csp_advanced/utils.py @@ -1,34 +1,43 @@ -def callable_csp_dict(data, request, response): - if callable(data): - return data(request, response) - result = {} - for key, value in data.iteritems(): - if callable(value): - result[key] = value(request, response) - else: - result[key] = value - return result - - -def merge_csp_dict(template, override): - result = template.copy() - for key, value in override.iteritems(): - if key not in result: - result[key] = value - continue - orig = result[key] - if isinstance(orig, list): - if orig == template[key]: - result[key] = orig + list(value) - else: - orig += value - elif isinstance(orig, set): - if orig == template[key]: - result[key] = orig.union(value) - else: - orig.update(value) - elif isinstance(orig, tuple): - result[key] = orig + tuple(value) - else: - result[key] = value - return result +def is_callable_csp_dict(data): + if callable(data): + return True + if not isinstance(data, dict): + return False + return any(callable(value) for value in data.itervalues()) + + +def call_csp_dict(data, request, response): + if callable(data): + return data(request, response) + + result = {} + for key, value in data.iteritems(): + if callable(value): + result[key] = value(request, response) + else: + result[key] = value + return result + + +def merge_csp_dict(template, override): + result = template.copy() + for key, value in override.iteritems(): + if key not in result: + result[key] = value + continue + orig = result[key] + if isinstance(orig, list): + if orig == template[key]: + result[key] = orig + list(value) + else: + orig += value + elif isinstance(orig, set): + if orig == template[key]: + result[key] = orig.union(value) + else: + orig.update(value) + elif isinstance(orig, tuple): + result[key] = orig + tuple(value) + else: + result[key] = value + return result