diff --git a/django/utils/crypto.py b/django/utils/crypto.py index 4ec1cfcf77..edeb336f34 100644 --- a/django/utils/crypto.py +++ b/django/utils/crypto.py @@ -9,10 +9,16 @@ from django.conf import settings from django.utils.encoding import force_bytes -def salted_hmac(key_salt, value, secret=None): +class InvalidAlgorithm(ValueError): + """Algorithm is not supported by hashlib.""" + pass + + +def salted_hmac(key_salt, value, secret=None, *, algorithm='sha1'): """ - Return the HMAC-SHA1 of 'value', using a key generated from key_salt and a - secret (which defaults to settings.SECRET_KEY). + Return the HMAC of 'value', using a key generated from key_salt and a + secret (which defaults to settings.SECRET_KEY). Default algorithm is SHA1, + but any algorithm name supported by hashlib.new() can be passed. A different key_salt should be passed in for every application of HMAC. """ @@ -21,16 +27,21 @@ def salted_hmac(key_salt, value, secret=None): key_salt = force_bytes(key_salt) secret = force_bytes(secret) - + try: + hasher = getattr(hashlib, algorithm) + except AttributeError as e: + raise InvalidAlgorithm( + '%r is not an algorithm accepted by the hashlib module.' + % algorithm + ) from e # We need to generate a derived key from our base key. We can do this by - # passing the key_salt and our base key through a pseudo-random function and - # SHA1 works nicely. - key = hashlib.sha1(key_salt + secret).digest() + # passing the key_salt and our base key through a pseudo-random function. + key = hasher(key_salt + secret).digest() # If len(key_salt + secret) > block size of the hash algorithm, the above # line is redundant and could be replaced by key = key_salt + secret, since # the hmac module does the same thing for keys longer than the block size. # However, we need to ensure that we *always* do this. - return hmac.new(key, msg=force_bytes(value), digestmod=hashlib.sha1) + return hmac.new(key, msg=force_bytes(value), digestmod=hasher) def get_random_string(length=12, diff --git a/tests/utils_tests/test_crypto.py b/tests/utils_tests/test_crypto.py index 1c3868f4ca..9dbfd9fe57 100644 --- a/tests/utils_tests/test_crypto.py +++ b/tests/utils_tests/test_crypto.py @@ -1,10 +1,13 @@ import hashlib import unittest -from django.utils.crypto import constant_time_compare, pbkdf2, salted_hmac +from django.test import SimpleTestCase +from django.utils.crypto import ( + InvalidAlgorithm, constant_time_compare, pbkdf2, salted_hmac, +) -class TestUtilsCryptoMisc(unittest.TestCase): +class TestUtilsCryptoMisc(SimpleTestCase): def test_constant_time_compare(self): # It's hard to test for constant time, just test the result. @@ -27,11 +30,31 @@ class TestUtilsCryptoMisc(unittest.TestCase): {'secret': 'x' * hashlib.sha1().block_size}, 'bd3749347b412b1b0a9ea65220e55767ac8e96b0', ), + ( + ('salt', 'value'), + {'algorithm': 'sha256'}, + 'ee0bf789e4e009371a5372c90f73fcf17695a8439c9108b0480f14e347b3f9ec', + ), + ( + ('salt', 'value'), + { + 'algorithm': 'blake2b', + 'secret': 'x' * hashlib.blake2b().block_size, + }, + 'fc6b9800a584d40732a07fa33fb69c35211269441823bca431a143853c32f' + 'e836cf19ab881689528ede647dac412170cd5d3407b44c6d0f44630690c54' + 'ad3d58', + ), ] for args, kwargs, digest in tests: with self.subTest(args=args, kwargs=kwargs): self.assertEqual(salted_hmac(*args, **kwargs).hexdigest(), digest) + def test_invalid_algorithm(self): + msg = "'whatever' is not an algorithm accepted by the hashlib module." + with self.assertRaisesMessage(InvalidAlgorithm, msg): + salted_hmac('salt', 'value', algorithm='whatever') + class TestUtilsCryptoPBKDF2(unittest.TestCase):