
Chronos-2 fine-tuning does not respect the GPU index #457

@abdulfatir

Description


Regardless of which GPU a `Chronos2Pipeline` is loaded on, fine-tuning always happens on GPU 0 (`cuda:0`), and the fine-tuned model ends up there too. This is caused by the hack we use to disable data parallelism.
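To make the failure mode concrete, here is a toy, pure-Python illustration. It assumes (this is not the actual Chronos or Hugging Face code) that the training loop picks its own device, namely the first visible GPU once data parallelism is disabled, instead of reusing the device the model was loaded on. `ToyModel`, `trainer_device`, `fit_buggy`, and `fit_respecting_device` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class ToyModel:
    device: str  # e.g. "cuda:5"


def trainer_device(num_visible_gpus: int) -> str:
    # Hypothetical stand-in for a trainer that chooses its own device:
    # with data parallelism disabled it simply uses the first visible GPU.
    return "cuda:0" if num_visible_gpus > 0 else "cpu"


def fit_buggy(model: ToyModel, num_visible_gpus: int) -> ToyModel:
    # The model is moved to the trainer's device; its original placement is ignored.
    return ToyModel(device=trainer_device(num_visible_gpus))


def fit_respecting_device(model: ToyModel) -> ToyModel:
    # One possible direction for a fix: train on the device the model is already on.
    return ToyModel(device=model.device)


model = ToyModel(device="cuda:5")
print(fit_buggy(model, num_visible_gpus=8).device)  # cuda:0
print(fit_respecting_device(model).device)  # cuda:5
```

Reusing the model's existing device is only one possible shape of a fix; this sketch does not model how it would interact with the data-parallel hack itself.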

To reproduce:

```python
import numpy as np

from chronos import Chronos2Pipeline


def generate_data(num_items: int = 10_000):
    rng = np.random.default_rng(seed=42)
    train_data = [{"target": rng.normal(size=2048)} for _ in range(num_items)]

    return train_data


def main():
    train_data = generate_data()
    pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map="cuda:5")
    print(pipeline.model.device)  # cuda:5
    ft_pipeline = pipeline.fit(train_data, context_length=512, prediction_length=64, num_steps=10)
    print(ft_pipeline.model.device)  # cuda:0 (expected: cuda:5)


if __name__ == "__main__":
    main()
```
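Until this is fixed, one possible workaround (an assumption, not something verified in this issue) is to expose only the desired GPU to the process through the standard `CUDA_VISIBLE_DEVICES` environment variable. The selected GPU is then renumbered as `cuda:0`, so "always GPU 0" lands on the intended device. `pin_process_to_gpu` is a hypothetical helper; the variable has to be set before torch initializes CUDA, so call it before importing torch or chronos.

```python
import os


def pin_process_to_gpu(physical_index: int) -> str:
    """Expose only one physical GPU to this process (hypothetical helper).

    Call this before importing torch or chronos: CUDA reads
    CUDA_VISIBLE_DEVICES when it is initialized.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(physical_index)
    # The single visible GPU is renumbered, so inside this process it is cuda:0.
    return "cuda:0"


if __name__ == "__main__":
    device = pin_process_to_gpu(5)
    print(os.environ["CUDA_VISIBLE_DEVICES"], device)  # 5 cuda:0
    # Then, for example:
    # from chronos import Chronos2Pipeline
    # pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map=device)
```

The same effect can be had at launch time, e.g. `CUDA_VISIBLE_DEVICES=5 python repro.py` (the script name is illustrative).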

Metadata


Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
