diff --git a/csp_advanced/tests.py b/csp_advanced/tests.py index 9002a4a..80ab2fc 100644 --- a/csp_advanced/tests.py +++ b/csp_advanced/tests.py @@ -3,6 +3,7 @@ from collections import OrderedDict from django.test import SimpleTestCase from csp import CSPCompiler, InvalidCSPError +from utils import callable_csp_dict, merge_csp_dict class CSPCompileTest(SimpleTestCase): @@ -73,3 +74,48 @@ class CSPCompileTest(SimpleTestCase): "style-src 'self'; script-src 'self' https://dmoj.ca; frame-src 'none'; " "plugin-types application/pdf; block-all-mixed-content; sandbox allow-scripts; " "report-uri /dev/null") + + +class CallableCSPDictTest(SimpleTestCase): + def test_callable(self): + self.assertEqual(callable_csp_dict(lambda: {'key': 'value'}), {'key': 'value'}) + + def test_normal_dict(self): + self.assertEqual(callable_csp_dict({'key': 'value'}), {'key': 'value'}) + + def test_callable_entry(self): + self.assertEqual(callable_csp_dict({'key': lambda: 'value'}), {'key': 'value'}) + + def test_mixed_entry(self): + self.assertEqual(callable_csp_dict({ + 'key': lambda: 'value', + 'name': 'mixed', + }), { + 'key': 'value', + 'name': 'mixed' + }) + + +class MergeCSPDictTest(SimpleTestCase): + def test_null(self): + test = {'key': 'value'} + self.assertEqual(merge_csp_dict(test, {}), test) + + def test_distinct_key(self): + self.assertEqual(merge_csp_dict({'spam': 1}, {'ham': 2}), {'spam': 1, 'ham': 2}) + + def test_scalar_override(self): + self.assertEqual(merge_csp_dict({'spam': 1}, {'spam': 2}), {'spam': 2}) + + def test_list_override(self): + orig = [1] + self.assertEqual(merge_csp_dict({'spam': orig}, {'spam': [2]}), {'spam': [1, 2]}) + self.assertEqual(orig, [1]) + + def test_set_override(self): + orig = {1} + self.assertEqual(merge_csp_dict({'spam': orig}, {'spam': [2]}), {'spam': {1, 2}}) + self.assertEqual(orig, {1}) + + def test_tuple_override(self): + self.assertEqual(merge_csp_dict({'spam': (1,)}, {'spam': (2,)}), {'spam': (1, 2)}) diff --git a/csp_advanced/utils.py b/csp_advanced/utils.py new file mode 100644 index 0000000..49590d3 --- /dev/null +++ b/csp_advanced/utils.py @@ -0,0 +1,34 @@ +def callable_csp_dict(data): + if callable(data): + return data() + result = {} + for key, value in data.iteritems(): + if callable(value): + result[key] = value() + else: + result[key] = value + return result + + +def merge_csp_dict(template, override): + result = template.copy() + for key, value in override.iteritems(): + if key not in result: + result[key] = value + continue + orig = result[key] + if isinstance(orig, list): + if orig == template[key]: + result[key] = orig + list(value) + else: + orig += value + elif isinstance(orig, set): + if orig == template[key]: + result[key] = orig.union(value) + else: + orig.update(value) + elif isinstance(orig, tuple): + result[key] = orig + tuple(value) + else: + result[key] = value + return result