Skip to content

Commit b973b70

Browse files
committed
fix multi-threading issue
1 parent ad76ce5 commit b973b70

File tree

4 files changed

+487
-36
lines changed

4 files changed

+487
-36
lines changed

nettacker/core/app.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from nettacker.core.messages import messages as _
2727
from nettacker.core.module import Module
28+
from nettacker.core.queue_manager import initialize_thread_pool, shutdown_thread_pool
2829
from nettacker.core.socks_proxy import set_socks_proxy
2930
from nettacker.core.utils import common as common_utils
3031
from nettacker.core.utils.common import wait_for_threads_to_finish
@@ -245,6 +246,12 @@ def start_scan(self, scan_id):
245246
target_groups.remove([])
246247

247248
log.info(_("start_multi_process").format(len(self.arguments.targets), len(target_groups)))
249+
250+
# Initialize the enhanced thread pool for cross-process sharing
251+
num_processes = len(target_groups)
252+
max_workers_per_process = getattr(self.arguments, "parallel_module_scan", None)
253+
initialize_thread_pool(num_processes, max_workers_per_process)
254+
248255
active_processes = []
249256
for t_id, target_groups in enumerate(target_groups):
250257
process = multiprocess.Process(
@@ -253,7 +260,12 @@ def start_scan(self, scan_id):
253260
process.start()
254261
active_processes.append(process)
255262

256-
return wait_for_threads_to_finish(active_processes, sub_process=True)
263+
result = wait_for_threads_to_finish(active_processes, sub_process=True)
264+
265+
# Shutdown the thread pool after scanning is complete
266+
shutdown_thread_pool()
267+
268+
return result
257269

258270
def scan_target(
259271
self,

nettacker/core/lib/base.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from nettacker.config import Config
1111
from nettacker.core.messages import messages as _
12+
from nettacker.core.queue_manager import dependency_resolver
1213
from nettacker.core.utils.common import merge_logs_to_list, remove_sensitive_header_keys
1314
from nettacker.database.db import find_temp_events, submit_temp_logs_to_db, submit_logs_to_db
1415
from nettacker.logger import get_logger, TerminalCodes
@@ -47,14 +48,40 @@ def filter_large_content(self, content, filter_rate=150):
4748
return content
4849

4950
def get_dependent_results_from_database(self, target, module_name, scan_id, event_names):
    """
    Fetch the results of dependency events for a module step.

    First asks the shared dependency resolver for the results without
    blocking; if it cannot provide them, falls back to polling the
    temp-events table with exponential backoff, bounded by a wall-clock
    deadline.

    Args:
        target: the scan target the events belong to.
        module_name: name of the module that produced the events.
        scan_id: identifier of the current scan.
        event_names: comma-separated event names to wait for.

    Returns:
        The resolver's results when available, otherwise a list with one
        entry per event name: the event's ``conditions_results`` payload,
        or ``None`` when the wait timed out.
    """
    # Fast path: the event-driven resolver avoids busy-waiting entirely.
    results = dependency_resolver.get_dependency_results_efficiently(
        target, module_name, scan_id, event_names, {}, self, ()
    )
    if results is not None:
        return results

    # Fallback: poll the database for backward compatibility, but bound the
    # wait by wall-clock time. (A fixed retry count combined with exponential
    # backoff would stretch the real timeout far beyond the intended 30
    # seconds, so use a monotonic deadline instead.)
    timeout_seconds = 30.0
    events = []
    for event_name in event_names.split(","):
        deadline = time.monotonic() + timeout_seconds
        retry_count = 0
        while True:
            event = find_temp_events(target, module_name, scan_id, event_name)
            if event:
                events.append(json.loads(event.event)["response"]["conditions_results"])
                break
            if time.monotonic() >= deadline:
                # Timeout reached; record a placeholder so callers still get
                # one entry per requested event name.
                log.warn(f"Timeout waiting for dependency: {event_name} for {target}")
                events.append(None)
                break
            retry_count += 1
            # Exponential backoff (0.1 s growing to at most 1.0 s) to keep
            # CPU usage low while waiting.
            time.sleep(min(0.1 * (1.5 ** (retry_count // 10)), 1.0))
    return events
5986

6087
def find_and_replace_dependent_values(self, sub_step, dependent_on_temp_event):
@@ -123,18 +150,26 @@ def process_conditions(
123150
# Remove sensitive keys from headers before submitting to DB
124151
event = remove_sensitive_header_keys(event)
125152
if "save_to_temp_events_only" in event.get("response", ""):
153+
event_name = event["response"]["save_to_temp_events_only"]
154+
155+
# Submit to database
126156
submit_temp_logs_to_db(
127157
{
128158
"date": datetime.now(),
129159
"target": target,
130160
"module_name": module_name,
131161
"scan_id": scan_id,
132-
"event_name": event["response"]["save_to_temp_events_only"],
162+
"event_name": event_name,
133163
"port": event.get("ports", ""),
134164
"event": event,
135165
"data": response,
136166
}
137167
)
168+
169+
# Notify dependency resolver that a dependency is now available
170+
dependency_resolver.notify_dependency_available(
171+
target, module_name, scan_id, event_name, response
172+
)
138173
if event["response"]["conditions_results"] and "save_to_temp_events_only" not in event.get(
139174
"response", ""
140175
):
@@ -279,9 +314,37 @@ def run(
279314
sub_step[attr_name.rstrip("s")] = int(value) if attr_name == "ports" else value
280315

281316
if "dependent_on_temp_event" in backup_response:
282-
temp_event = self.get_dependent_results_from_database(
283-
target, module_name, scan_id, backup_response["dependent_on_temp_event"]
317+
# Try to get dependency results efficiently
318+
temp_event = dependency_resolver.get_dependency_results_efficiently(
319+
target,
320+
module_name,
321+
scan_id,
322+
backup_response["dependent_on_temp_event"],
323+
sub_step,
324+
self,
325+
(
326+
sub_step,
327+
module_name,
328+
target,
329+
scan_id,
330+
options,
331+
process_number,
332+
module_thread_number,
333+
total_module_thread_number,
334+
request_number_counter,
335+
total_number_of_requests,
336+
),
284337
)
338+
339+
# If dependencies are not available yet, the task is queued
340+
# Return early to avoid blocking the thread
341+
if temp_event is None:
342+
log.verbose_event_info(
343+
f"Task queued waiting for dependencies: {target} -> {module_name}"
344+
)
345+
return False
346+
347+
# Dependencies are available, continue with execution
285348
sub_step = self.replace_dependent_values(sub_step, temp_event)
286349

287350
action = getattr(self.library(), backup_method)

nettacker/core/module.py

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from nettacker import logger
99
from nettacker.config import Config
1010
from nettacker.core.messages import messages as _
11+
from nettacker.core.queue_manager import thread_pool
1112
from nettacker.core.template import TemplateLoader
1213
from nettacker.core.utils.common import expand_module_steps, wait_for_threads_to_finish
1314
from nettacker.database.db import find_events
@@ -118,26 +119,44 @@ def generate_loops(self):
118119
self.module_content["payloads"] = expand_module_steps(self.module_content["payloads"])
119120

120121
def sort_loops(self):
    """
    Reorder each payload's steps to optimize dependency resolution:

    1. independent steps first,
    2. then steps that generate dependencies (``save_to_temp_events_only``),
    3. then steps that only consume dependencies (``dependent_on_temp_event``).

    Steps that both generate and consume a dependency are grouped with the
    generators so their output is available before pure consumers run.
    """
    for payload in self.module_content["payloads"]:
        independent_steps = []
        dependency_generators = []
        dependency_consumers = []

        # Work on a deep copy so classification cannot mutate the original
        # step definitions while we rebuild the list.
        for step in copy.deepcopy(payload["steps"]):
            # A step is a list of sub-steps; classify by the first sub-step's
            # response section. Use .get so a sub-step without a "response"
            # key cannot raise KeyError.
            step_response = step[0].get("response", {}) if step else {}

            has_dependency = "dependent_on_temp_event" in step_response
            generates_dependency = "save_to_temp_events_only" in step_response

            if has_dependency and not generates_dependency:
                dependency_consumers.append(step)
            elif generates_dependency:
                # Pure generators and generator+consumer steps both run
                # before pure consumers.
                dependency_generators.append(step)
            else:
                independent_steps.append(step)

        # Combine in optimal order and write back in place.
        sorted_steps = independent_steps + dependency_generators + dependency_consumers
        payload["steps"] = sorted_steps

        log.verbose_info(
            f"Sorted {len(sorted_steps)} steps: "
            f"{len(independent_steps)} independent, "
            f"{len(dependency_generators)} generators, "
            f"{len(dependency_consumers)} consumers"
        )
141160

142161
def start(self):
143162
active_threads = []
@@ -158,11 +177,14 @@ def start(self):
158177
importlib.import_module(f"nettacker.core.lib.{library.lower()}"),
159178
f"{library.capitalize()}Engine",
160179
)()
180+
161181
for step in payload["steps"]:
162182
for sub_step in step:
163-
thread = Thread(
164-
target=engine.run,
165-
args=(
183+
# Try to use shared thread pool if available, otherwise use local threads
184+
if thread_pool and hasattr(thread_pool, "submit_task"):
185+
# Submit to shared thread pool
186+
thread_pool.submit_task(
187+
engine.run,
166188
sub_step,
167189
self.module_name,
168190
self.target,
@@ -173,9 +195,35 @@ def start(self):
173195
self.total_module_thread_number,
174196
request_number_counter,
175197
total_number_of_requests,
176-
),
177-
)
178-
thread.name = f"{self.target} -> {self.module_name} -> {sub_step}"
198+
)
199+
else:
200+
# Use local thread (fallback to original behavior)
201+
thread = Thread(
202+
target=engine.run,
203+
args=(
204+
sub_step,
205+
self.module_name,
206+
self.target,
207+
self.scan_id,
208+
self.module_inputs,
209+
self.process_number,
210+
self.module_thread_number,
211+
self.total_module_thread_number,
212+
request_number_counter,
213+
total_number_of_requests,
214+
),
215+
)
216+
thread.name = f"{self.target} -> {self.module_name} -> {sub_step}"
217+
thread.start()
218+
active_threads.append(thread)
219+
220+
# Manage local thread pool size
221+
wait_for_threads_to_finish(
222+
active_threads,
223+
maximum=self.module_inputs["thread_per_host"],
224+
terminable=True,
225+
)
226+
179227
request_number_counter += 1
180228
log.verbose_event_info(
181229
_("sending_module_request").format(
@@ -188,13 +236,8 @@ def start(self):
188236
total_number_of_requests,
189237
)
190238
)
191-
thread.start()
192239
time.sleep(self.module_inputs["time_sleep_between_requests"])
193-
active_threads.append(thread)
194-
wait_for_threads_to_finish(
195-
active_threads,
196-
maximum=self.module_inputs["thread_per_host"],
197-
terminable=True,
198-
)
199240

200-
wait_for_threads_to_finish(active_threads, maximum=None, terminable=True)
241+
# Wait for any remaining local threads to finish
242+
if active_threads:
243+
wait_for_threads_to_finish(active_threads, maximum=None, terminable=True)

0 commit comments

Comments
 (0)