diff --git a/django/core/mail/backends/base.py b/django/core/mail/backends/base.py index 484d2a8b57..4914f479b9 100644 --- a/django/core/mail/backends/base.py +++ b/django/core/mail/backends/base.py @@ -40,7 +40,11 @@ class BaseEmailBackend(object): pass def __enter__(self): - self.open() + try: + self.open() + except Exception: + self.close() + raise return self def __exit__(self, exc_type, exc_value, traceback): diff --git a/tests/mail/tests.py b/tests/mail/tests.py index 662cd7c6d6..e053a65d6a 100644 --- a/tests/mail/tests.py +++ b/tests/mail/tests.py @@ -1289,11 +1289,9 @@ class SMTPBackendTests(BaseEmailBackendTests, SMTPBackendTestsBase): """ backend = smtp.EmailBackend( username='not empty username', password='not empty password') - try: - with self.assertRaisesMessage(SMTPException, 'SMTP AUTH extension not supported by server.'): - backend.open() - finally: - backend.close() + with self.assertRaisesMessage(SMTPException, 'SMTP AUTH extension not supported by server.'): + with backend: + pass def test_server_open(self): """ @@ -1315,7 +1313,8 @@ class SMTPBackendTests(BaseEmailBackendTests, SMTPBackendTestsBase): backend = CustomEmailBackend(username='username', password='password') with self.assertRaises(SMTPAuthenticationError): - backend.open() + with backend: + pass @override_settings(EMAIL_USE_TLS=True) def test_email_tls_use_settings(self): @@ -1377,21 +1376,17 @@ class SMTPBackendTests(BaseEmailBackendTests, SMTPBackendTestsBase): def test_email_tls_attempts_starttls(self): backend = smtp.EmailBackend() self.assertTrue(backend.use_tls) - try: - with self.assertRaisesMessage(SMTPException, 'STARTTLS extension not supported by server.'): - backend.open() - finally: - backend.close() + with self.assertRaisesMessage(SMTPException, 'STARTTLS extension not supported by server.'): + with backend: + pass @override_settings(EMAIL_USE_SSL=True) def test_email_ssl_attempts_ssl_connection(self): backend = smtp.EmailBackend() self.assertTrue(backend.use_ssl) - try: - with self.assertRaises(SSLError): - backend.open() - finally: - backend.close() + with self.assertRaises(SSLError): + with backend: + pass def test_connection_timeout_default(self): """Test that the connection's timeout value is None by default."""