Skip to content

Commit

Permalink
[Feature] Remove and check for prints in codebase using flake8-print (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jan 4, 2024
1 parent 4151a83 commit 750a114
Show file tree
Hide file tree
Showing 19 changed files with 76 additions and 58 deletions.
3 changes: 2 additions & 1 deletion .github/unittest/linux/scripts/run-clang-format.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import argparse
import difflib
import fnmatch
import logging
import multiprocessing
import os
import signal
Expand Down Expand Up @@ -216,7 +217,7 @@ def print_trouble(prog, message, use_colors):
error_text = "error:"
if use_colors:
error_text = bold_red(error_text)
print(f"{prog}: {error_text} {message}", file=sys.stderr)
logging.error(f"{prog}: {error_text} {message}")


def main():
Expand Down
3 changes: 2 additions & 1 deletion .github/unittest/linux_torchrec/scripts/run-clang-format.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import argparse
import difflib
import fnmatch
import logging
import multiprocessing
import os
import signal
Expand Down Expand Up @@ -216,7 +217,7 @@ def print_trouble(prog, message, use_colors):
error_text = "error:"
if use_colors:
error_text = bold_red(error_text)
print(f"{prog}: {error_text} {message}", file=sys.stderr)
logging.error(f"{prog}: {error_text} {message}")


def main():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import argparse
import difflib
import fnmatch
import logging
import multiprocessing
import os
import signal
Expand Down Expand Up @@ -216,7 +217,7 @@ def print_trouble(prog, message, use_colors):
error_text = "error:"
if use_colors:
error_text = bold_red(error_text)
print(f"{prog}: {error_text} {message}", file=sys.stderr)
logging.error(f"{prog}: {error_text} {message}")


def main():
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ repos:
- flake8-bugbear==22.10.27
- flake8-comprehensions==3.10.1
- torchfix==0.0.2
- flake8-print==5.0.0

- repo: https://github.com/PyCQA/pydocstyle
rev: 6.1.1
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import time
from collections import defaultdict
Expand Down Expand Up @@ -31,7 +31,7 @@ def pytest_sessionfinish(maxprint=50):
out_str += f"\t{key}{spaces}{item: 4.4f}s\n"
if i == maxprint - 1:
break
print(out_str)
logging.info(out_str)


@pytest.fixture(autouse=True)
Expand Down
38 changes: 20 additions & 18 deletions benchmarks/distributed/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import argparse
import collections

import logging
import os
import time
from functools import wraps
Expand Down Expand Up @@ -148,12 +150,12 @@ def load(cls, dataset, path):
# locks the tensorclass and ensures that is_memmap will return True.
data.memmap_()
t0 = time.time()
print("loading...", end="\t")
logging.info("loading...", end="\t")
snapshot = torchsnapshot.Snapshot(path=path)
sd = dict(data.state_dict())
app_state = {"state": torchsnapshot.StateDict(data=sd)}
snapshot.restore(app_state=app_state)
print(f"done! Took: {time.time()-t0:4.4f}s")
logging.info(f"done! Took: {time.time()-t0:4.4f}s")
return data

def save(self, path):
Expand Down Expand Up @@ -299,7 +301,7 @@ def init(self, train_data_tc):
)

def train(self) -> None:
print("train")
logging.info("train")
len_data = self.data.shape[0]
pbar = tqdm.tqdm(total=len_data)

Expand Down Expand Up @@ -344,7 +346,7 @@ def train(self) -> None:
if iteration >= self.world_size:
total += batch.shape[0]
t = time.time() - t0
print(f"time spent: {t:4.4f}s, Rate: {total/t} fps")
logging.info(f"time spent: {t:4.4f}s, Rate: {total/t} fps")
return {"time": t, "rate": total / t}

def create_data_nodes(self, data):
Expand All @@ -358,14 +360,14 @@ def create_data_nodes(self, data):
reraise=True,
)
def create_data_node(self, node, local_transform) -> rpc.RRef:
print(f"Creating DataNode object on remote node {node}")
logging.info(f"Creating DataNode object on remote node {node}")
data_info = rpc.get_worker_info(f"{DATA_NODE}_{node}")
data_rref = rpc.remote(
data_info,
DataNode,
args=(node, BATCH_SIZE, self.single_gpu, local_transform),
)
print(f"Connected to data node {data_info}")
logging.info(f"Connected to data node {data_info}")
time.sleep(5)
self.datanodes.append(data_rref)

Expand All @@ -389,7 +391,7 @@ def __init__(
single_gpu: bool = False,
make_transform: bool = True,
):
print("Creating DataNode object")
logging.info("Creating DataNode object")
self.rank = rank
self.id = rpc.get_worker_info().id
self.single_gpu = single_gpu
Expand Down Expand Up @@ -422,10 +424,10 @@ def __init__(
self.collate = Collate(self.collate_transform, device=device)
self.initialized = False
self.count = 0
print("done!")
logging.info("done!")

def set_data(self, data):
print("initializing")
logging.info("initializing")
self.initialized = True
self.data: ImageNetData = data

Expand Down Expand Up @@ -471,7 +473,7 @@ def init_rpc(
rpc_backend_options=options,
)

print(f"Initialised {name}")
logging.info(f"Initialised {name}")


def shutdown():
Expand All @@ -491,7 +493,7 @@ def func(rank, world_size, args, train_data_tc, single_gpu, trainer_transform):
import wandb

if not args.wandb_key:
print("no wandb key provided, using it offline")
logging.info("no wandb key provided, using it offline")
mode = "offline"
else:
mode = "online"
Expand All @@ -512,14 +514,14 @@ def func(rank, world_size, args, train_data_tc, single_gpu, trainer_transform):
rate = stats["rate"]
wandb.log(stats, step=i)
wandb.log({"min time": min_time, "max_rate": rate})
print(f"FINAL: time spent: {min_time:4.4f}s, Rate: {rate} fps")
logging.info(f"FINAL: time spent: {min_time:4.4f}s, Rate: {rate} fps")


if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except Exception as err:
print(f"Could not start mp with spawn method. Error: {err}")
logging.info(f"Could not start mp with spawn method. Error: {err}")

args = parser.parse_args()
world_size = args.world_size
Expand All @@ -532,7 +534,7 @@ def func(rank, world_size, args, train_data_tc, single_gpu, trainer_transform):

names = [TRAINER_NODE, *[f"{DATA_NODE}_{rank}" for rank in range(1, world_size)]]

print("preparing data")
logging.info("preparing data")
data_dir = Path("/datasets01_ontap/imagenet_full_size/061417/")
train_data_raw = datasets.ImageFolder(
root=data_dir / "train",
Expand All @@ -545,15 +547,15 @@ def func(rank, world_size, args, train_data_tc, single_gpu, trainer_transform):
]

if load_path:
print("loading...", end="\t")
logging.info("loading...", end="\t")
train_data_tc = ImageNetData.load(train_data_raw, load_path)
print("done")
logging.info("done")
else:
train_data_tc = ImageNetData.from_dataset(train_data_raw)
if save_path:
print("saving...", end="\t")
logging.info("saving...", end="\t")
train_data_tc.save(save_path)
print("done")
logging.info("done")

with mp.Pool(world_size) as pool:
pool.starmap(
Expand Down
7 changes: 5 additions & 2 deletions benchmarks/distributed/distributed_benchmark_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import time

Expand Down Expand Up @@ -72,8 +77,6 @@ def exec_distributed_test(rank_node):
break
except RuntimeError:
time.sleep(0.1)
print("-", end="")
print("")

def fill_tensordict(tensordict, idx):
tensordict[idx] = TensorDict(
Expand Down
14 changes: 10 additions & 4 deletions benchmarks/fx_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import timeit

import torch
Expand Down Expand Up @@ -66,7 +72,7 @@ def forward(self, x):
batch_size=[32],
)

print(
logging.info(
"forward, TensorDictSequential",
timeit.timeit(
"module(tensordict)",
Expand All @@ -75,7 +81,7 @@ def forward(self, x):
),
)

print(
logging.info(
"forward, GraphModule",
timeit.timeit(
"module(tensordict)",
Expand All @@ -94,7 +100,7 @@ def forward(self, x):
nested_graph_module = symbolic_trace(nested_tdmodule)
tensordict = TensorDict({"input": torch.rand(32, 100)}, [32])

print(
logging.info(
"nested_forward, TensorDictSequential",
timeit.timeit(
"module(tensordict)",
Expand All @@ -103,7 +109,7 @@ def forward(self, x):
),
)

print(
logging.info(
"nested_forward, GraphModule",
timeit.timeit(
"module(tensordict)",
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ per-file-ignores =
test/smoke_test.py: F401
test/smoke_test_deps.py: F401
test_*.py: E731, E266, TOR101
tutorials/*/**.py: T201
exclude = venv
extend-select = B901, C401, C408, C409

Expand Down
13 changes: 7 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import argparse
import distutils.command.clean
import glob
import logging
import os
import shutil
import subprocess
Expand Down Expand Up @@ -86,13 +87,13 @@ def run(self):

# Remove tensordict extension
for path in (ROOT_DIR / "tensordict").glob("**/*.so"):
print(f"removing '{path}'")
logging.info(f"removing '{path}'")
path.unlink()
# Remove build directory
build_dirs = [ROOT_DIR / "build"]
for path in build_dirs:
if path.exists():
print(f"removing '{path}' (and everything under it)")
logging.info(f"removing '{path}' (and everything under it)")
shutil.rmtree(str(path), ignore_errors=True)


Expand All @@ -109,7 +110,7 @@ def get_extensions():
}
debug_mode = os.getenv("DEBUG", "0") == "1"
if debug_mode:
print("Compiling in debug mode")
logging.info("Compiling in debug mode")
extra_compile_args = {
"cxx": [
"-O0",
Expand Down Expand Up @@ -151,11 +152,11 @@ def _main(argv):
version = get_nightly_version() if is_nightly else get_version()

write_version_file(version)
print(f"Building wheel {package_name}-{version}")
print(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}")
logging.info(f"Building wheel {package_name}-{version}")
logging.info(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}")

pytorch_package_dep = _get_pytorch_version(is_nightly)
print("-- PyTorch dependency:", pytorch_package_dep)
logging.info("-- PyTorch dependency:", pytorch_package_dep)

long_description = (ROOT_DIR / "README.md").read_text()
sys.argv = [sys.argv[0], *unknown]
Expand Down
4 changes: 3 additions & 1 deletion tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from __future__ import annotations

import logging

from copy import deepcopy
from typing import Any, Iterable

Expand All @@ -16,7 +18,7 @@

_has_functorch = True
except ImportError:
print(
logging.info(
"failed to import functorch. TensorDict's features that do not require "
"functional programming should work, but functionality and performance "
"may be affected. Consider installing functorch and/or upgrating pytorch."
Expand Down
6 changes: 4 additions & 2 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import concurrent.futures
import dataclasses
import inspect
import logging

import math
import os

Expand Down Expand Up @@ -691,7 +693,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
val[2] = N

@staticmethod
def print(prefix=None):
def print(prefix=None): # noqa: T202
keys = list(timeit._REG)
keys.sort()
for name in keys:
Expand All @@ -701,7 +703,7 @@ def print(prefix=None):
strings.append(
f"{name} took {timeit._REG[name][0] * 1000:4.4} msec (total = {timeit._REG[name][1]} sec)"
)
print(" -- ".join(strings))
logging.info(" -- ".join(strings))

@staticmethod
def erase():
Expand Down
1 change: 0 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def pytest_sessionfinish(maxprint=50):
out_str += f"\t{key}{spaces}{item: 4.4f}s\n"
if i == maxprint - 1:
break
print(out_str)


@pytest.fixture(autouse=True)
Expand Down
Loading

0 comments on commit 750a114

Please sign in to comment.