diff --git a/setup.py b/setup.py index f3d35d6f..3cfee814 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="pythae", - version="0.0.4", + version="0.0.5", author="Clement Chadebec (HekA team INRIA)", author_email="clement.chadebec@inria.fr", description="Unifying Generative Autoencoders in Python", diff --git a/src/pythae/models/base/base_model.py b/src/pythae/models/base/base_model.py index 308dc027..99c5a3f1 100644 --- a/src/pythae/models/base/base_model.py +++ b/src/pythae/models/base/base_model.py @@ -19,7 +19,7 @@ from ..nn import BaseDecoder, BaseEncoder from ..nn.default_architectures import Decoder_AE_MLP from .base_config import BaseAEConfig, EnvironmentConfig -from .base_utils import CPU_Unpickler, ModelOutput, hf_hub_is_available +from .base_utils import CPU_Unpickler, ModelOutput, hf_hub_is_available, model_card_template logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -188,15 +188,6 @@ def push_to_hf_hub(self, hf_hub_path: str): # pragma: no cover api = HfApi() hf_operations = [] - hf_operations.append( - CommitOperationAdd( - path_in_repo="README.md", - path_or_fileobj=os.path.join( - os.path.dirname(os.path.abspath(__file__)), "model_card.md" - ), - ) - ) - for file in model_files: hf_operations.append( CommitOperationAdd( @@ -205,6 +196,16 @@ def push_to_hf_hub(self, hf_hub_path: str): # pragma: no cover ) ) + with open(os.path.join(tempdir, "model_card.md"), "w") as f: + f.write(model_card_template) + + hf_operations.append( + CommitOperationAdd( + path_in_repo="README.md", + path_or_fileobj=os.path.join(tempdir, "model_card.md"), + ) + ) + try: api.create_commit( commit_message=f"Uploading {self.model_name} in {hf_hub_path}", diff --git a/src/pythae/models/base/base_utils.py b/src/pythae/models/base/base_utils.py index a53165b1..817cadf6 100644 --- a/src/pythae/models/base/base_utils.py +++ b/src/pythae/models/base/base_utils.py @@ -12,6 +12,21 @@ logger.addHandler(console) logger.setLevel(logging.INFO) +model_card_template = """--- +language: en +tags: +- pythae +license: apache-2.0 +--- + +### Downloading this model from the Hub +This model was trained with pythae. It can be downloaded or reloaded using the method `load_from_hf_hub` +```python +>>> from pythae.models import AutoModel +>>> model = AutoModel.load_from_hf_hub(hf_hub_path="your_hf_username/repo_name") +``` +""" + def hf_hub_is_available(): return importlib.util.find_spec("huggingface_hub") is not None diff --git a/src/pythae/models/base/model_card.md b/src/pythae/models/base/model_card.md deleted file mode 100644 index f9269de4..00000000 --- a/src/pythae/models/base/model_card.md +++ /dev/null @@ -1,13 +0,0 @@ ---- -language: en -tags: -- pythae -license: apache-2.0 ---- - -### Downloading this model from the Hub -This model was trained with pythae. It can be downloaded or reloaded using the method `load_from_hf_hub` -```python ->>> from pythae.models import AutoModel ->>> model = AutoModel.load_from_hf_hub(hf_hub_path="your_hf_username/repo_name") -``` \ No newline at end of file