Skip to content

Commit b3d6560

Browse files
authored
Use Futures to deduplicate incoming requests (#37)
Fixes #23 As of #19, we cache (most) scan results in memory to avoid having to download and scan the same file twice (with the caveat that we don't cache the contents of big files, but it still saves scanning them). However, it does not prevent duplicate work if multiple requests for the same file happen simultaneously, since we only update this cache when a scan completes. This is a legitimate concern, since we can easily imagine a situation where a user posts a large file into a room and multiple users try to download it at the same time. As of this change, * when a request for a given file with given encryption and thumbnailing parameters comes in, and we're not already running a download/scan for it (with those parameters), we create a `Future` that we store into a new `_current_scans` cache, and run the scan. Once the scan completes, we resolve the `Future` with the result, and remove it from the cache. * when a request for the same file comes in _before_ the previous request completed, we don't download and/or scan the file - instead we look up the matching `Future` from the `_current_scans` cache and `await` it. We then return its result.
1 parent c8b7496 commit b3d6560

2 files changed

Lines changed: 106 additions & 2 deletions

File tree

matrix_content_scanner/scanner/scanner.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
import hashlib
1516
import logging
1617
import os
1718
import subprocess
19+
from asyncio import Future
1820
from pathlib import Path
1921
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
2022

@@ -85,6 +87,10 @@ def __init__(self, mcs: "MatrixContentScanner"):
8587
# for unencrypted files).
8688
self._allowed_mimetypes = mcs.config.scan.allowed_mimetypes
8789

90+
# Cache of futures for files that are currently scanning and downloading, so that
91+
# concurrent requests don't cause a file to be downloaded and scanned twice.
92+
self._current_scans: Dict[str, Future[MediaDescription]] = {}
93+
8894
async def scan_file(
8995
self,
9096
media_path: str,
@@ -100,6 +106,9 @@ async def scan_file(
100106
downloading the file again (unless we purposefully did not cache the file's
101107
content to save up on memory).
102108
109+
If a file is currently already being downloaded or scanned as a result of another
110+
request, don't download it again and use the result from the first request.
111+
103112
Args:
104113
media_path: The `server_name/media_id` path for the media.
105114
metadata: The metadata attached to the file (e.g. decryption key), or None if
@@ -115,9 +124,72 @@ async def scan_file(
115124
FileDirtyError if the result of the scan said that the file is dirty, or if
116125
the media path is malformed.
117126
"""
118-
# Compute the cache key for the media.
127+
# Compute the key to use when caching, both in the current scans cache and in the
128+
# results cache.
119129
cache_key = self._get_cache_key_for_file(media_path, metadata, thumbnail_params)
130+
if cache_key not in self._current_scans:
131+
# Create a future in the context of the current event loop.
132+
loop = asyncio.get_event_loop()
133+
f = loop.create_future()
134+
# Register the future in the current scans cache so that subsequent queries
135+
# can use it.
136+
self._current_scans[cache_key] = f
137+
# Try to download and scan the file.
138+
try:
139+
res = await self._scan_file(
140+
cache_key, media_path, metadata, thumbnail_params
141+
)
142+
# Set the future's result, and mark it as done.
143+
f.set_result(res)
144+
# Return the result.
145+
return res
146+
except Exception as e:
147+
# If there's an exception, catch it, pass it on to the future, and raise
148+
# it.
149+
f.set_exception(e)
150+
# We retrieve the exception from the future, because if we don't and no
151+
# other request is awaiting on the future, asyncio complains about "Future
152+
# exception was never retrieved".
153+
f.exception()
154+
raise
155+
finally:
156+
# Remove the future from the cache.
157+
del self._current_scans[cache_key]
158+
159+
return await self._current_scans[cache_key]
160+
161+
async def _scan_file(
162+
self,
163+
cache_key: str,
164+
media_path: str,
165+
metadata: Optional[JsonDict] = None,
166+
thumbnail_params: Optional[MultiDictProxy[str]] = None,
167+
) -> MediaDescription:
168+
"""Download and scan the given media.
169+
170+
Unless the scan fails with one of the codes listed in `do_not_cache_exit_codes`,
171+
also cache the result.
172+
173+
If the file already has an entry in the result cache, return this value without
174+
downloading the file again (unless we purposefully did not cache the file's
175+
content to save up on memory).
120176
177+
Args:
178+
cache_key: The key to use to cache the result of the scan in the result cache.
179+
media_path: The `server_name/media_id` path for the media.
180+
metadata: The metadata attached to the file (e.g. decryption key), or None if
181+
the file isn't encrypted.
182+
thumbnail_params: If present, then we want to request and scan a thumbnail
183+
generated with the provided parameters instead of the full media.
184+
185+
Returns:
186+
A description of the media.
187+
188+
Raises:
189+
ContentScannerRestError if the file could not be downloaded.
190+
FileDirtyError if the result of the scan said that the file is dirty, or if
191+
the media path is malformed.
192+
"""
121193
# The media to scan.
122194
media: Optional[MediaDescription] = None
123195

tests/scanner/test_scanner.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
import copy
15-
from typing import Dict, List, Optional
16+
from typing import Any, Dict, List, Optional
1617
from unittest import IsolatedAsyncioTestCase
1718
from unittest.mock import Mock
1819

@@ -303,6 +304,37 @@ async def test_invalid_media_path(self) -> None:
303304
with self.assertRaises(FileDirtyError):
304305
await self.scanner.scan_file(MEDIA_PATH + "/baz")
305306

307+
async def test_deduplicate_scans(self) -> None:
308+
"""Tests that if two scan requests come in for the same file and with the same
309+
parameter, only one download/scan happens.
310+
"""
311+
312+
# Change the Mock's side effect to introduce some delay, to simulate a long
313+
# download time. We sleep asynchronously to allow additional scans requests to be
314+
# processed.
315+
async def _scan_file(*args: Any) -> MediaDescription:
316+
await asyncio.sleep(0.2)
317+
318+
return self.downloader_res
319+
320+
scan_mock = Mock(side_effect=_scan_file)
321+
self.scanner._scan_file = scan_mock # type: ignore[assignment]
322+
323+
# Request two scans of the same file at the same time.
324+
results = await asyncio.gather(
325+
asyncio.create_task(self.scanner.scan_file(MEDIA_PATH)),
326+
asyncio.create_task(self.scanner.scan_file(MEDIA_PATH)),
327+
)
328+
329+
# Check that the scanner has only been called once, meaning that the second
330+
# call did not trigger a scan.
331+
scan_mock.assert_called_once()
332+
333+
# Check that we got two results, and that we actually got the correct media
334+
# description in the second scan.
335+
self.assertEqual(len(results), 2, results)
336+
self.assertEqual(results[0].content, results[1].content, results)
337+
306338
def _setup_encrypted(self) -> None:
307339
"""Sets up class properties to make the downloader return an encrypted file
308340
instead of a plain text one.

0 commit comments

Comments
 (0)