Commit 32178db

Commit message: continue to flow
1 parent d6fb9bf commit 32178db

4 files changed: 158 additions & 27 deletions


src/unitxt/ccc_inference.py

Lines changed: 40 additions & 4 deletions
@@ -17,11 +17,25 @@
 class ServerManager:
     def __init__(self):
         self.shutdown_flag = False
-        self.inactivity_timeout = 60
+        self.inactivity_timeout = 6000
         self.monitor_thread = threading.Thread(target=self.monitor_activity, daemon=True)
-        self.monitor_thread.start()
+
         self.last_request_time = time.time()
         self.shutdown_flag = False
+        self.configuration = None
+        self.workers_status = {}
+
+    def set_configuration(self, configuration):
+        self.configuration = configuration
+
+    def get_configuration(self):
+        return self.configuration
+
+    def register_worker(self, id):
+        self.workers_status[id] = {"status": "registered"}
+
+    def start_monitoring(self):
+        self.monitor_thread.start()
 
     def update_last_request_time(self):
         self.last_request_time = time.time()
@@ -48,32 +62,54 @@ def shutdown_server(self):
         time.sleep(1)
         os._exit(0)  # This immediately stops the program
 
+
 server_manager = ServerManager()
 
+
 @app.before_request
 def update_activity():
     server_manager.update_last_request_time()
 
+
 @app.route("/isup", methods=["GET"])
 def isup():
     return jsonify({"status": "up"}), 200
 
+
 @app.route("/version", methods=["GET"])
 def version():
     return jsonify({"version": "1.0.0"}), 200
 
+
 @app.route("/infer", methods=["POST"])
 def infer():
-    data = request.json.get("dataset", [])
-    predictions = [f"Processed: {item}" for item in data]
+    data = request.json
+    predictions = [0.202 for item in data]
     return jsonify(predictions)
 
+
+@app.route("/set_configuration", methods=["POST"])
+def set_configuration():
+    configuration = request.json
+    server_manager.set_configuration(configuration)
+    return jsonify({"message": "configuration has been set"})
+
+
+@app.route("/register", methods=["POST"])
+def register():
+    id = request.json
+    server_manager.register_worker(id)
+    return jsonify(server_manager.get_configuration())
+
+
 @app.route("/shutdown", methods=["POST"])
 def shutdown():
     app.logger.info("Received shutdown request")
     server_manager.shutdown_server()
     return jsonify({"message": "Shutting down server..."}), 200
 
+
 if __name__ == "__main__":
+    server_manager.start_monitoring()
    app.logger.info("Server started on port {PORT}")
    app.run(host="0.0.0.0", port=PORT)
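
Taken together, the additions turn the Flask stub into a small coordination server: a client stores generation parameters via /set_configuration, workers POST their id to /register and get that configuration back, and /infer is still a placeholder that returns the constant 0.202 once per posted item. A few behavioral details also change: the activity monitor is now started explicitly from the __main__ block via start_monitoring() rather than inside __init__, the inactivity timeout grows from 60 to 6000 seconds, and the startup log line is a plain string rather than an f-string, so it logs the literal text {PORT}. Below is a minimal client-side sketch of the new round trip; the localhost:8000 base URL is an assumption for illustration (the server's actual PORT is parsed elsewhere in the script):

import requests

BASE = "http://localhost:8000"  # assumed host/port, for illustration only

# Store generation parameters for workers to pick up later.
requests.post(f"{BASE}/set_configuration", json={"max_new_tokens": 13}, timeout=5)

# A worker announces itself and receives the stored configuration back.
config = requests.post(f"{BASE}/register", json="worker-1", timeout=5).json()
print(config)  # {'max_new_tokens': 13}

# The /infer stub returns 0.202 once per item in the posted list.
preds = requests.post(f"{BASE}/infer", json=["q1", "q2"], timeout=5).json()
print(preds)  # [0.202, 0.202]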

src/unitxt/ccc_worker.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import argparse
+
+import requests
+
+server_url = "http://localhost/"
+
+
+def post(endpoint, data):
+    # print(f"{server_url}{endpoint}")
+    response = requests.post(f"{server_url}{endpoint}", json=data)
+    if response.status_code == 200:
+        return response.json()
+    raise RuntimeError("Failed to post from to server:", response.status_code, response.text)
+
+
+def main(**kwargs):
+    worker_id = kwargs["id"]
+    # get configuration and create actual inference engine
+    configuration = post("register", worker_id)
+    configuration["1"] = 2
+    #print(configuration)
+    finish = False
+    while not finish:
+        # get batch from server
+        batch = None
+        if not batch:
+            finish = True
+            continue
+        # create predictions for batch
+        # return predictions to server
+
+
+
+
+# should be --kwargs key1=value1 key2=value2 key3=value3
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generic argument parser")
+
+    # Accept arbitrary key-value pairs as arguments
+    parser.add_argument("--kwargs", nargs="+", help="Pass key=value pairs", default=[])
+
+    args = parser.parse_args()
+
+    # Convert key=value pairs to a dictionary
+    kwargs_dict = {}
+    for item in args.kwargs:
+        if "=" in item:
+            key, value = item.split("=", 1)
+            kwargs_dict[key] = value
+
+    main(**kwargs_dict)
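
The new worker registers with the server, receives the shared configuration, and then loops over batches; batch retrieval is still stubbed (batch is always None, so the loop exits on its first pass). Its generic --kwargs interface converts key=value tokens into a dictionary. A self-contained sketch of that parsing, using a hypothetical worker id and batch size:

import argparse

parser = argparse.ArgumentParser(description="Generic argument parser")
parser.add_argument("--kwargs", nargs="+", help="Pass key=value pairs", default=[])

# Equivalent to invoking: python ccc_worker.py --kwargs id=worker-1 batch_size=8
args = parser.parse_args(["--kwargs", "id=worker-1", "batch_size=8"])

# Mirrors the worker's parsing loop; values remain strings.
kwargs_dict = dict(item.split("=", 1) for item in args.kwargs if "=" in item)
print(kwargs_dict)  # {'id': 'worker-1', 'batch_size': '8'}

One thing to watch: server_url is hardcoded to http://localhost/ (implicitly port 80), while the engine launches the server with a --port argument (default "8000"), so the worker will only reach the server once the URL includes the actual port.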

src/unitxt/inference.py

Lines changed: 27 additions & 10 deletions
@@ -3358,12 +3358,18 @@ class CCCInferenceEngine(
     ccc_path: str
     ccc_python: str
     server_port: str = "8000"
+    num_of_workers: int = 10
 
-    def prepare_engine(self):
+    def post(self, endpoint, data=None):
+        response = requests.post(f"{self.server_url}/{endpoint}", json=data, timeout=5)
+        if response.status_code == 200:
+            return response.json()
+        raise RuntimeError("Failed to post from to server:", response.status_code, response.text)
+
+    def start_ccc_server(self):
         import paramiko
         server_file = "ccc_inference.py"
         local_server_path = os.path.dirname(os.path.abspath(__file__))
-        self.server_url = f"http://{self.ccc_host}:{self.server_port}"
 
         ssh = paramiko.SSHClient()
         ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
@@ -3372,8 +3378,9 @@ def prepare_engine(self):
         sftp = ssh.open_sftp()
         sftp.put(os.path.join(local_server_path, server_file), os.path.join(self.ccc_path, server_file))
         sftp.close()
-        ssh.exec_command(f"cd {self.ccc_path} && nohup {self.ccc_python} {server_file} --port {self.server_port} > server.log 2>&1 &")
-        time.sleep(2)  # Wait 2 seconds before checking
+        ssh.exec_command(
+            f"cd {self.ccc_path} && nohup {self.ccc_python} {server_file} --port {self.server_port} > server.log 2>&1 &")
+        time.sleep(1)
         try:
             response = requests.get(f"{self.server_url}/isup", timeout=5)
             if response.status_code == 200:
@@ -3394,14 +3401,23 @@ def prepare_engine(self):
         except requests.RequestException as err:
             raise RuntimeError(f"Failed to start ccc server. Response: {server_log_content}") from err
 
-        get_logger().info("OK")
-        get_logger().info(111)
-        self.shutdown_server()
-        get_logger().info(222)
+    def prepare_engine(self):
+        self.server_url = f"http://{self.ccc_host}:{self.server_port}"
+        if "localhost" in self.ccc_host:
+            response = requests.get(f"{self.server_url}/isup", timeout=5)
+            if response.status_code == 200:
+                get_logger().info("Server is up and running!")
+            else:
+                raise RuntimeError("server is down!")
+        else:
+            self.start_ccc_server()
+        self.post("set_configuration", data=self.to_dict([HFGenerationParamsMixin]))
 
     def shutdown_server(self):
+        if "localhost" in self.ccc_host:
+            return
         try:
-            requests.post(f"{self.server_url}/shutdown", timeout=5)
+            self.post("shutdown")
         except:
             pass
 
@@ -3414,4 +3430,5 @@ def _infer(
         dataset: Union[List[Dict[str, Any]], Dataset],
         return_meta_data: bool = False,
     ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-        pass
+        messages = [self.to_messages(instance) for instance in dataset]
+        return self.post("infer", data=messages)
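
After this refactor, prepare_engine only computes the server URL and branches: when ccc_host contains "localhost" it expects an already-running server and merely checks /isup, otherwise start_ccc_server copies ccc_inference.py over SFTP and launches it through SSH. Either way, the engine then pushes its HF generation parameters to /set_configuration, and _infer now POSTs the converted messages to /infer (shutdown_server likewise becomes a no-op for a localhost server, so a shared local server is not killed by one client). A sketch of the localhost path; the ccc_user/ccc_path/ccc_python values below are placeholders, unused on this branch but still required fields, and a ccc_inference.py server is assumed to be listening on localhost:8000:

from unitxt import load_dataset
from unitxt.inference import CCCInferenceEngine

dataset = load_dataset(card="cards.openbook_qa", split="test")

# Placeholder ccc_* values; the localhost branch only needs a server
# already listening on http://localhost:8000.
engine = CCCInferenceEngine(
    max_new_tokens=13,
    ccc_host="localhost",
    ccc_user="someuser",
    ccc_path="/tmp/inference_server",
    ccc_python="python",
)
predictions = engine.infer(dataset)  # per-instance messages POSTed to /infer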

test_ccc.py

Lines changed: 40 additions & 13 deletions
@@ -1,7 +1,38 @@
+import hashlib
+import json
+import os
+import time
+
+import joblib
 import unitxt
+from unitxt import load_dataset
 from unitxt.inference import CCCInferenceEngine
 from unitxt.logging_utils import set_verbosity
 
+
+def get_cache_filename(cache_dir="cache", **kwargs):
+    """Generate a unique filename for caching based on function arguments."""
+    os.makedirs(cache_dir, exist_ok=True)
+    hash_key = hashlib.md5(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()
+    return os.path.join(cache_dir, f"dataset_{hash_key}.pkl")
+
+
+def load_dataset_cached(**kwargs):
+    """Load dataset with disk caching."""
+    cache_file = get_cache_filename(**kwargs)
+
+    if os.path.exists(cache_file):
+        # print("Loading from cache...")
+        return joblib.load(cache_file)
+
+    # print("Loading dataset from source...")
+    data = load_dataset(**kwargs)  # Your actual function call
+    # print("Saving to cache...")
+    joblib.dump(data, cache_file)
+
+    return data
+
+
 if __name__ == "__main__":
     set_verbosity("debug")
     unitxt.settings.allow_unverified_code = True
@@ -12,20 +43,15 @@
         "metrics.llm_as_judge.direct.rits.llama3_1_70b[context_fields=[question],"
         f"criteria=metrics.llm_as_judge.direct.criteria.{criterion}]"
     ]
-    # dataset = load_dataset(card="cards.openbook_qa",
-    #                        metrics=metrics,
-    #                        split='test')
-    # #dataset = dataset.select(range(10))
+    dataset = load_dataset_cached(card="cards.openbook_qa", metrics=metrics, split="test")
+    #dataset = dataset.select(range(10))
     inference_model = CCCInferenceEngine(max_new_tokens=13,
                                          ccc_host="cccxl013.pok.ibm.com",
+                                         #ccc_host="localhost",
                                          ccc_user="eladv",
                                          ccc_path="/u/eladv/fusion/inference_server",
                                          ccc_python="/dccstor/fuse/envs/fm-eval/bin/python")
-    # model="watsonx/meta-llama/llama-3-2-1b-instruct",
-    # max_tokens=256,
-    # use_cache=True
-    # )
-    #
+
     # def my_wrapper(original_method):
    #     random.seed(int(time.time()))
    #     async def wrapped(*args, **kwargs):
@@ -45,10 +71,11 @@
     #
     # inference_model._infer_instance = my_wrapper(inference_model._infer_instance)
 
-    # start_time = time.time()
-    # predictions = inference_model.infer(dataset)
-    # end_time = time.time()
-    #
+    start_time = time.time()
+    predictions = inference_model.infer(dataset)
+    end_time = time.time()
+
+    # print(f"len predictions: {len(predictions)} first 10 predictions: {predictions[:10]}")
    # print(f"predictions contains {predictions.count(None)} Nones")
    #
    # mode = 'validate'
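
The test script now wraps unitxt's load_dataset in a small disk cache: the call's kwargs are JSON-serialized with sorted keys, MD5-hashed into a filename, and the loaded dataset is persisted with joblib so repeated runs skip the expensive load. A minimal sketch of the key derivation (the kwargs shown are illustrative, and they must be JSON-serializable for the hashing to work):

import hashlib
import json
import os

# Same derivation as get_cache_filename: identical kwargs -> identical file.
kwargs = {"card": "cards.openbook_qa", "split": "test"}
hash_key = hashlib.md5(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()
print(os.path.join("cache", f"dataset_{hash_key}.pkl"))

Since the key is derived only from the kwargs, changes elsewhere (a new unitxt version, an edited card definition) do not invalidate old entries; stale files under cache/ have to be removed by hand.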
