|
22 | 22 | import datetime |
23 | 23 | import importlib.resources as importlib_resources |
24 | 24 | import os |
25 | | -from typing import Any |
| 25 | +from typing import Any, cast |
26 | 26 | from unittest.mock import AsyncMock |
27 | 27 |
|
28 | 28 | from twisted.internet.testing import MemoryReactor |
@@ -66,6 +66,17 @@ def make_homeserver( |
66 | 66 | hs.get_send_email_handler()._sendmail = AsyncMock() |
67 | 67 | return hs |
68 | 68 |
|
| 69 | + def _get_sendmail_mock(self) -> AsyncMock: |
| 70 | + """ |
| 71 | + Cast the homeserver's `_sendmail` object as an `AsyncMock`. |
| 72 | +
|
| 73 | + `_sendmail` is an `AsyncMock` (see `make_homeserver`) but this type |
| 74 | + information doesn't make it through the test harness. Thus we need to |
| 75 | + cast the object again. |
| 76 | + """ |
| 77 | + sendmail = self.hs.get_send_email_handler()._sendmail |
| 78 | + return cast(AsyncMock, sendmail) |
| 79 | + |
69 | 80 | def test_POST_appservice_registration_valid(self) -> None: |
70 | 81 | user_id = "@as_user_kermit:test" |
71 | 82 | as_token = "i_am_an_app_service" |
@@ -747,6 +758,33 @@ def test_request_token_existing_email_inhibit_error(self) -> None: |
747 | 758 |
|
748 | 759 | self.assertIsNotNone(channel.json_body.get("sid")) |
749 | 760 |
|
| 761 | + @unittest.override_config( |
| 762 | + { |
| 763 | + "public_baseurl": "https://test_server", |
| 764 | + "email": { |
| 765 | + "smtp_host": "mail_server", |
| 766 | + "smtp_port": 2525, |
| 767 | + "notif_from": "sender@host", |
| 768 | + }, |
| 769 | + } |
| 770 | + ) |
| 771 | + def test_request_token_allowed_when_email_flow_is_advertised(self) -> None: |
| 772 | + sendmail = self._get_sendmail_mock() |
| 773 | + sendmail.reset_mock() |
| 774 | + |
| 775 | + channel = self.make_request( |
| 776 | + "POST", |
| 777 | + b"register/email/requestToken", |
| 778 | + { |
| 779 | + "client_secret": "foobar", |
| 780 | + "email": "test@example.com", |
| 781 | + "send_attempt": 1, |
| 782 | + }, |
| 783 | + ) |
| 784 | + self.assertEqual(200, channel.code, channel.result) |
| 785 | + self.assertIsNotNone(channel.json_body.get("sid")) |
| 786 | + sendmail.assert_awaited_once() |
| 787 | + |
750 | 788 | @unittest.override_config( |
751 | 789 | { |
752 | 790 | "public_baseurl": "https://test_server", |
|
0 commit comments