
Commit 216409c

Simplify spyre_setup.py and fix distributed setup (#190)
Most of the env vars can be handled automatically in the Sentient backend now, but it seems that flex requires `torchrun`-style RANK and WORLD_SIZE env vars for configuring distributed serving. Credit to @tdoublep for the insight of setting RANK and WORLD_SIZE for flex.

---------

Signed-off-by: Travis Johnson <[email protected]>
1 parent 324f252 commit 216409c
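For illustration (not part of the commit itself), here is a minimal sketch of the mechanism the commit message describes: flex reads torchrun-style RANK and WORLD_SIZE environment variables, so the reworked spyre_dist_setup publishes them via setdefault when the process was not launched by torchrun. The rank and world size values below are made-up example values.

# Minimal sketch of the fix described above; example values only.
import os

rank, world_size = 0, 2  # illustrative; normally supplied by the launcher

# flex expects torchrun-style env vars, so publish them if torchrun did not.
# setdefault keeps any values that torchrun already exported.
os.environ.setdefault("RANK", str(rank))
os.environ.setdefault("WORLD_SIZE", str(world_size))

print(os.environ["RANK"], os.environ["WORLD_SIZE"])  # -> "0 2" when run standalone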

File tree

3 files changed: +11 / -142 lines changed

spyre_setup.py

Lines changed: 9 additions & 140 deletions
@@ -1,147 +1,16 @@
-import json
 import os
-import sys
-import tempfile

-import torch

-# ==============================================================
-# Common utilities
-# ==============================================================
-#-------------
-# Discover the world size and my rank (envars set by torchrun)
-# https://pytorch.org/docs/stable/elastic/run.html#environment-variables
-#-------------
-local_rank = int(os.getenv("LOCAL_RANK", 0))
-rank = int(os.getenv("RANK", 0))
-world_rank = rank
-world_size = int(os.getenv("WORLD_SIZE", 1))
-
-def dprint(text):
-    print(f"[{rank:2d}/{world_size:2d}]: {text}")
-
-# ==============================================================
-# Common setup
-# ==============================================================
-def spyre_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
-    # -------------
-    # Envar setup for backend
-    # -------------
-    # Environment variable created by the runtime to identify the specific Spyre card that is assigned to this rank
-    spyre_config_file_envar = "AIU_CONFIG_FILE_" + str(rank)
-
-    # Default to senulator backend unless user specified otherwise
+def spyre_setup():
+    # default to senulator backend unless user specified otherwise
     os.environ.setdefault("FLEX_COMPUTE", "SENULATOR")
-    os.environ.setdefault("FLEX_DEVICE", "MOCK")
-
-    # Each rank needs a unique space to write its binaries
-    # For both 'export' and '__pycache'
-    # https://docs.python.org/3/library/sys.html#sys.pycache_prefix
-    with tempfile.TemporaryDirectory() as exportdir:
-        os.environ.setdefault("DEEPRT_EXPORT_DIR", exportdir)
-        os.environ.setdefault("DTCOMPILER_EXPORT_DIR", exportdir)
-    if world_size > 1:
-        os.environ["DEEPRT_EXPORT_DIR"] += f"/{rank}"
-        os.environ["DTCOMPILER_EXPORT_DIR"] += f"/{rank}"
-    sys.pycache_prefix=os.getenv("DEEPRT_EXPORT_DIR")+"/py-" + str(rank)
-    os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "0")
-
-    # Inform Flex of the size of this job
-    os.environ.setdefault("FLEX_RDMA_WORLD_SIZE", str(world_size))
-    os.environ.setdefault("FLEX_RDMA_WORLD_RANK", str(rank))
-    os.environ.setdefault("FLEX_RDMA_LOCAL_SIZE", str(world_size))
-    os.environ.setdefault("FLEX_RDMA_LOCAL_RANK", str(rank))
-    for peer_rank in range(world_size):
-        pcie_env_str="AIU_WORLD_RANK_"+str(peer_rank)
-        flex_env_str="FLEX_RDMA_PCI_BUS_ADDR_"+str(peer_rank)
-        if os.getenv(pcie_env_str) is None:
-            raise RuntimeError(f"Error: The environment variable {pcie_env_str} is not defined")
-        if os.getenv(flex_env_str) is None:
-            raise RuntimeError(f"Error: The environment variable {flex_env_str} is not defined")
-    if os.getenv("DUMP_MEMMAP") is not None:
-        if os.getenv("SDSC_REF_DIR") is None:
-            os.environ["SDSC_REF_DIR"] = os.environ["DEEPRT_EXPORT_DIR"]
-        else:
-            os.environ["SDSC_REF_DIR"] += f"/{rank}"
-        assert (
-            os.getenv("DUMP_MEMMAP_DIR") is not None
-        ), "DUMP_MEMMAP_DIR not set while DUMP_MEMMAP set"
-        os.environ["DUMP_MEMMAP_DIR"] += f"/{rank}"
-        os.makedirs(
-            os.environ["DUMP_MEMMAP_DIR"], exist_ok=True
-        ) # directory needs to exist
-
-    for peer_rank in range(world_size):
-        pcie_env_str = "AIU_WORLD_RANK_" + str(peer_rank)
-        flex_env_str = "FLEX_RDMA_PCI_BUS_ADDR_" + str(peer_rank)
-        if os.getenv("FLEX_COMPUTE") == "SENULATOR":
-            if os.getenv(pcie_env_str) is not None:
-                os.environ[flex_env_str] = os.getenv(pcie_env_str)
-            else:
-                os.environ[pcie_env_str] = f"0000:{rank:02x}:01.0"
-                os.environ[flex_env_str] = f"0000:{rank:02x}:01.0"
-        else:
-            if os.getenv(flex_env_str) is None:
-                if os.getenv("PCIDEVICE_IBM_COM_SENTIENT_PF") is not None:
-                    os.environ[pcie_env_str] = os.getenv(
-                        "PCIDEVICE_IBM_COM_SENTIENT_PF"
-                    )
-
-                if os.getenv(pcie_env_str) is not None:
-                    os.environ[flex_env_str] = os.getenv(pcie_env_str)
-                else:
-                    raise RuntimeError(
-                        f"[{rank}/{world_size}]: ERROR: {flex_env_str} and {pcie_env_str} were not set for peer {peer_rank}."
-                    )
-        if rank == 0 and verbose:
-            dprint(f"PCI Addr Rank {peer_rank} {pcie_env_str}={os.environ[pcie_env_str]}")
-            dprint(f"PCI Addr Rank {peer_rank} {flex_env_str}={os.environ[flex_env_str]}")
-
-    if rank == 0 and verbose:
-        dprint(f"FLEX_COMPUTE=" + os.getenv("FLEX_COMPUTE"))
-        dprint(f"FLEX_DEVICE=" + os.getenv("FLEX_DEVICE"))
-        dprint(f"DEEPRT_EXPORT_DIR=" + os.getenv("DEEPRT_EXPORT_DIR"))
-        dprint(f"DTCOMPILER_EXPORT_DIR=" + os.getenv("DTCOMPILER_EXPORT_DIR"))
-        if os.getenv(spyre_config_file_envar) is not None:
-            dprint(f"{spyre_config_file_envar}=" + os.environ[spyre_config_file_envar])
-        if os.getenv("SENLIB_DEVEL_CONFIG_FILE") is not None:
-            dprint(f"SENLIB_DEVEL_CONFIG_FILE=" + os.environ["SENLIB_DEVEL_CONFIG_FILE"])
-        if os.getenv(flex_env_str) is not None:
-            dprint(f"{flex_env_str}=" + os.environ[flex_env_str])
-        dprint(f"FLEX_RDMA_LOCAL_RANK=" + os.getenv("FLEX_RDMA_LOCAL_RANK"))
-        dprint(f"FLEX_RDMA_LOCAL_SIZE=" + os.getenv("FLEX_RDMA_LOCAL_SIZE"))
-        dprint(f"FLEX_RDMA_WORLD_RANK=" + os.getenv("FLEX_RDMA_WORLD_RANK"))
-        dprint(f"FLEX_RDMA_WORLD_SIZE=" + os.getenv("FLEX_RDMA_WORLD_SIZE"))
-
-    if os.getenv("FLEX_COMPUTE") == "SENTIENT":
-        pcie_env_str = "AIU_WORLD_RANK_" + str(rank)
-        if os.getenv(pcie_env_str) is not None:
-            device_id = os.getenv(pcie_env_str)
-        else:
-            with open(os.getenv(spyre_config_file_envar)) as fd:
-                data = json.load(fd)
-                device_id = data["GENERAL"]["sen_bus_id"]
-        dprint(f"Spyre: Enabled ({device_id})")
-    else:
-        dprint(f"Spyre: Disabled (Senulator)")
-
-
-# ==============================================================
-# Distributed setup
-# ==============================================================
-def spyre_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False):
-    if local_rank < 0:
-        local_rank = rank
-    if local_size < 0:
-        local_size = world_size

-    if os.getenv("TORCHELASTIC_RUN_ID") is None:
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "12355"
-    elif rank == 0 or verbose:
-        dprint(f"Detected running via torchrun")
+def spyre_dist_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
+    # make sure to have torchrun env vars for flex
+    os.environ.setdefault("RANK", str(rank))
+    os.environ.setdefault("WORLD_SIZE", str(world_size))

-    if rank == 0 or verbose:
-        dprint(f"Parallel Backend: {torch.distributed.get_backend()}")
+    if verbose:
+        print(f"Distributed rank {os.environ['RANK']} / {os.environ['WORLD_SIZE']}")

-    spyre_setup(rank, world_size)
+    spyre_setup()
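A short usage sketch of the simplified entry points, assuming the module is importable as `spyre_setup` (the worker diffs below use that name; the real package path may differ), with illustrative distributed values:

import spyre_setup  # assumed import name for this sketch

# Single-device (or senulator) setup: no arguments are needed any more.
spyre_setup.spyre_setup()

# Distributed setup: pass rank/world_size so RANK and WORLD_SIZE exist for flex;
# spyre_dist_setup() then calls spyre_setup() itself, per the diff above.
spyre_setup.spyre_dist_setup(rank=0, world_size=2, verbose=True)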

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def init_device(self) -> None:
         elif envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in [
                 "sendnn", "sendnn_decoder"
         ]:
-            spyre_setup.spyre_setup(rank=0, world_size=1, verbose=True)
+            spyre_setup.spyre_setup()

         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,

vllm_spyre/worker/spyre_worker.py

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ def init_device(self) -> None:
         elif envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in [
                 "sendnn", "sendnn_decoder"
         ]:
-            spyre_setup.spyre_setup(rank=0, world_size=1, verbose=True)
+            spyre_setup.spyre_setup()

         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,

0 commit comments