-
-
Notifications
You must be signed in to change notification settings - Fork 637
Expand file tree
/
Copy pathclasses.py
More file actions
438 lines (381 loc) · 13.7 KB
/
classes.py
File metadata and controls
438 lines (381 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
import base64
import logging
import traceback
import typing
from abc import ABCMeta, abstractmethod
from pathlib import PosixPath
import requests
from billiard.exceptions import SoftTimeLimitExceeded
from django.conf import settings
from django.core.files import File
from django.utils import timezone
from django.utils.functional import cached_property
from requests import HTTPError
from api_app.decorators import abstractclassproperty, classproperty
from api_app.models import AbstractReport, Job, PythonConfig, PythonModule
from certego_saas.apps.user.models import User
logger = logging.getLogger(__name__)
class Plugin(metaclass=ABCMeta):
"""
Abstract Base class for plugins. Provides a framework for defining and running
plugins within a specified configuration.
Attributes:
config (PythonConfig): Configuration for the plugin.
kwargs: Additional keyword arguments.
"""
def __init__(
self,
config: PythonConfig,
**kwargs,
):
self._config = config
self.kwargs = kwargs
# some post init processing
# monkeypatch if in test suite
if settings.STAGE_CI or settings.MOCK_CONNECTIONS:
self._monkeypatch()
@property
def name(self):
"""
Get the name of the plugin.
Returns:
str: The name of the plugin.
"""
return self._config.name
@abstractclassproperty
def python_base_path(cls) -> PosixPath:
NotImplementedError()
@classmethod
def all_subclasses(cls):
"""
Retrieve all subclasses of the plugin class.
Returns:
list: Sorted list of plugin subclasses.
"""
posix_dir = PosixPath(str(cls.python_base_path).replace(".", "/"))
for plugin in posix_dir.rglob("*.py"):
if plugin.stem == "__init__":
continue
package = f"{str(plugin.parent).replace('/', '.')}.{plugin.stem}"
__import__(package)
classes = cls.__subclasses__()
return sorted(
[class_ for class_ in classes if not class_.__name__.startswith("MockUp")],
key=lambda x: x.__name__,
)
@cached_property
def _job(self) -> "Job":
"""
Get the job associated with the plugin.
Returns:
Job: The job instance.
"""
return Job.objects.get(pk=self.job_id)
@property
def job_id(self) -> int:
"""
Get the job ID.
Returns:
int: The job ID.
"""
return self._job_id
@job_id.setter
def job_id(self, value):
"""
Set the job ID.
Args:
value (int): The job ID.
"""
self._job_id = value
@cached_property
def _user(self):
"""
Get the user associated with the job.
Returns:
User: The user instance.
"""
return self._job.user
def __repr__(self):
"""
Get the string representation of the plugin.
Returns:
str: The string representation of the plugin.
"""
return str(self)
def __str__(self):
"""
Get the string representation of the plugin.
Returns:
str: The string representation of the plugin.
"""
try:
return f"({self.__class__.__name__}, job: #{self.job_id})"
except AttributeError:
return f"{self.__class__.__name__}"
def config(self, runtime_configuration: typing.Dict):
"""
Configure the plugin with runtime parameters.
Args:
runtime_configuration (dict): Runtime configuration parameters.
"""
self.__parameters = self._config.read_configured_params(
self._user, runtime_configuration
)
for parameter in self.__parameters:
attribute_name = (
f"_{parameter.name}" if parameter.is_secret else parameter.name
)
setattr(self, attribute_name, parameter.value)
logger.debug(
f"Adding to {self.__class__.__name__} "
f"param {attribute_name} with value {parameter.value} "
)
def before_run(self):
"""
Function called directly before the run function.
"""
@abstractmethod
def run(self) -> dict:
"""
Called from *start* function and wrapped in a try-catch block.
Should be overwritten in child class.
Returns:
dict: Report generated by the plugin.
"""
def after_run(self):
"""
Function called after the run function.
"""
self.report.end_time = timezone.now()
self.report.save()
def after_run_success(self, content: typing.Any):
"""
Handle the successful completion of the run function.
Args:
content (Any): Content generated by the plugin.
"""
# avoiding JSON serialization errors for types: File and bytes
report_content = content
if isinstance(report_content, typing.List):
report_content = []
for n in content:
if isinstance(n, File):
report_content.append(base64.b64encode(n.read()).decode("utf-8"))
elif isinstance(n, bytes):
report_content.append(base64.b64encode(n).decode("utf-8"))
else:
report_content.append(n)
self.report.report = report_content
self.report.status = self.report.STATUSES.SUCCESS.value
self.report.save(update_fields=["status", "report"])
def log_error(self, e):
"""
Log an error encountered during the run function.
Args:
e (Exception): The exception to log.
"""
if isinstance(
e, (*self.get_exceptions_to_catch(), SoftTimeLimitExceeded, HTTPError)
):
error_message = self.get_error_message(e)
logger.error(error_message)
else:
traceback.print_exc()
error_message = self.get_error_message(e, is_base_err=True)
logger.exception(error_message)
def after_run_failed(self, e: Exception):
"""
Handle the failure of the run function.
Args:
e (Exception): The exception that caused the failure.
"""
self.report.errors.append(str(e))
self.report.status = self.report.STATUSES.FAILED
self.report.save(update_fields=["status", "errors"])
if isinstance(e, HTTPError) and (
hasattr(e, "response")
and hasattr(e.response, "status_code")
and e.response.status_code == 429
):
self.disable_for_rate_limit()
else:
self.log_error(e)
if settings.STAGE_CI:
raise e
@abstractclassproperty
def report_model(cls) -> typing.Type[AbstractReport]:
"""
Returns Model to be used for *init_report_object*
"""
raise NotImplementedError()
@abstractclassproperty
def config_model(cls) -> typing.Type[PythonConfig]:
"""
Returns Model to be used for *init_report_object*
"""
raise NotImplementedError()
@abstractmethod
def get_exceptions_to_catch(self) -> list:
"""
Returns list of `Exception`'s to handle.
"""
raise NotImplementedError()
def get_error_message(self, err, is_base_err=False):
"""
Returns error message for
*_handle_analyzer_exception* fn
"""
return (
f"{self}."
f" {'Unexpected error' if is_base_err else f'{self.config_model.__name__} error'}:" # noqa
f" '{err}'"
)
def start(
self, job_id: int, runtime_configuration: dict, task_id: str, *args, **kwargs
):
"""
Entrypoint function to execute the plugin.
calls `before_run`, `run`, `after_run`
in that order with exception handling.
"""
self.job_id = job_id
self.report: AbstractReport = self._config.generate_empty_report(
self._job, task_id, AbstractReport.STATUSES.RUNNING.value
)
try:
self.config(runtime_configuration)
self.before_run()
_result = self.run()
except Exception as e:
self.after_run_failed(e)
else:
self.after_run_success(_result)
finally:
# add end time of process
self.after_run()
def _handle_exception(self, exc, is_base_err: bool = False) -> None:
if not is_base_err:
traceback.print_exc()
error_message = self.get_error_message(exc, is_base_err=is_base_err)
logger.error(error_message)
self.report.errors.append(str(exc))
self.report.status = self.report.STATUSES.FAILED
@classmethod
def _monkeypatch(cls, patches: list = None) -> None:
"""
Hook to monkey-patch class for testing purposes.
"""
if patches is None:
patches = []
for mock_fn in patches:
cls.start = mock_fn(cls.start)
@classproperty
def python_module(cls) -> PythonModule:
"""
Get the Python module associated with the plugin.
Returns:
PythonModule: The Python module instance.
"""
valid_module = cls.__module__.replace(str(cls.python_base_path), "")
# remove the starting dot
valid_module = valid_module[1:]
return PythonModule.objects.get(
module=f"{valid_module}.{cls.__name__}", base_path=cls.python_base_path
)
@classmethod
def update(cls) -> bool:
"""
Update the plugin. Must be implemented by subclasses.
Returns:
bool: Whether the update was successful.
"""
raise NotImplementedError("No update implemented")
def _get_health_check_url(self, user: User = None) -> typing.Optional[str]:
"""
Get the URL for performing a health check.
Args:
user (User): The user instance.
Returns:
typing.Optional[str]: The health check URL.
"""
params = (
self._config.parameters.annotate_configured(self._config, user)
.annotate_value_for_user(self._config, user)
.filter(name__icontains="url")
)
for param in params:
if not param.configured or not param.value:
continue
url = param.value
logger.info(f"Url retrieved to verify is {param.name} for {self}")
return url
if hasattr(self, "url") and self.url:
return self.url
return None
def health_check(self, user: User = None) -> bool:
"""
Perform a health check for the plugin.
Args:
user (User): The user instance.
Returns:
bool: Whether the health check was successful.
"""
url = self._get_health_check_url(user)
if url and url.startswith("http"):
if settings.STAGE_CI or settings.MOCK_CONNECTIONS:
return True
logger.info(f"healthcheck url {url} for {self}")
try:
# momentarily set this to False to
# avoid fails for https services
response = requests.head(url, timeout=10, verify=False)
# This may happen when even the HEAD request is protected by authentication
# We cannot create a generic health check that consider auth too
# because every analyzer has its own way to authenticate
# So, in this case, we will consider it as check passed because we got an answer
# For ex 405 code is when HEADs are not allowed. But it is the same. The service answered.
if 400 <= response.status_code <= 408:
return True
response.raise_for_status()
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.HTTPError,
) as e:
logger.info(f"healthcheck failed: url {url} for {self}. Error: {e}")
return False
else:
return True
raise NotImplementedError()
def disable_for_rate_limit(self):
"""
Disable the plugin due to rate limiting.
"""
logger.info(f"Trying to disable for rate limit {self}")
if self._user.has_membership():
org_configuration = self._config.get_or_create_org_configuration(
self._user.membership.organization
)
if org_configuration.rate_limit_timeout is not None:
api_key_parameter = self.__parameters.filter(
name__contains="api_key"
).first()
# if we do not have api keys OR the api key was org based
# OR if the api key is not actually required and we do not have it set
if (
not api_key_parameter
or api_key_parameter.is_from_org
or (not api_key_parameter.required and not api_key_parameter.value)
):
org_configuration.disable_for_rate_limit()
else:
logger.warning(
f"Not disabling {self} because api key used is personal"
)
else:
logger.warning(
f"You are trying to disable {self}"
" for rate limit without specifying a timeout."
)
else:
logger.info(f"User {self._user.username} is not in organization.")