Labels: bug (Something isn't working)
Description
Regardless of which GPU a Chronos2Pipeline is loaded on, fine-tuning always happens on GPU 0. This is because of the hack we use to disable data parallelism.
To reproduce:
```python
import numpy as np

from chronos import Chronos2Pipeline


def generate_data(num_items: int = 10_000):
    rng = np.random.default_rng(seed=42)
    train_data = [{"target": rng.normal(size=2048)} for _ in range(num_items)]
    return train_data


def main():
    train_data = generate_data()
    pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map="cuda:5")
    print(pipeline.model.device)  # cuda:5
    ft_pipeline = pipeline.fit(train_data, context_length=512, prediction_length=64, num_steps=10)
    print(ft_pipeline.model.device)  # cuda:0


if __name__ == "__main__":
    main()
```