diff --git a/README.md b/README.md index 311d22c..2536c93 100644 --- a/README.md +++ b/README.md @@ -11,34 +11,66 @@ [![Gitter](https://badges.gitter.im/DeepLabCut/community.svg)](https://gitter.im/DeepLabCut/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![Twitter Follow](https://img.shields.io/twitter/follow/DeepLabCut.svg?label=DeepLabCut&style=social)](https://twitter.com/DeepLabCut) -This package contains a [DeepLabCut](http://www.mousemotorlab.org/deeplabcut) inference pipeline for real-time applications that has minimal (software) dependencies. Thus, it is as easy to install as possible (in particular, on atypical systems like [NVIDIA Jetson boards](https://developer.nvidia.com/buy-jetson)). - -**Performance:** If you would like to see estimates on how your model should perform given different video sizes, neural network type, and hardware, please see: https://deeplabcut.github.io/DLC-inferencespeed-benchmark/ - -If you have different hardware, please consider submitting your results too! https://github.com/DeepLabCut/DLC-inferencespeed-benchmark - -**What this SDK provides:** This package provides a `DLCLive` class which enables pose estimation online to provide feedback. This object loads and prepares a DeepLabCut network for inference, and will return the predicted pose for single images. - -To perform processing on poses (such as predicting the future pose of an animal given it's current pose, or to trigger external hardware like send TTL pulses to a laser for optogenetic stimulation), this object takes in a `Processor` object. Processor objects must contain two methods: process and save. - -- The `process` method takes in a pose, performs some processing, and returns processed pose. +This package contains a [DeepLabCut](http://www.mousemotorlab.org/deeplabcut) inference +pipeline for real-time applications that has minimal (software) dependencies. Thus, it +is as easy to install as possible (in particular, on atypical systems like [ +NVIDIA Jetson boards](https://developer.nvidia.com/buy-jetson)). + +If you've used DeepLabCut-Live with TensorFlow models and want to try the PyTorch +version, take a look at [_Switching from TensorFlow to PyTorch_]( +#Switching-from-TensorFlow-to-PyTorch) + +**Performance of TensorFlow models:** If you would like to see estimates on how your +model should perform given different video sizes, neural network type, and hardware, +please see: [deeplabcut.github.io/DLC-inferencespeed-benchmark/ +](https://deeplabcut.github.io/DLC-inferencespeed-benchmark/). **We're working on +getting these benchmarks for PyTorch architectures as well.** + +If you have different hardware, please consider [submitting your results too]( +https://github.com/DeepLabCut/DLC-inferencespeed-benchmark)! + +**What this SDK provides:** This package provides a `DLCLive` class which enables pose +estimation online to provide feedback. This object loads and prepares a DeepLabCut +network for inference, and will return the predicted pose for single images. + +To perform processing on poses (such as predicting the future pose of an animal given +its current pose, or to trigger external hardware like send TTL pulses to a laser for +optogenetic stimulation), this object takes in a `Processor` object. Processor objects +must contain two methods: `process` and `save`. + +- The `process` method takes in a pose, performs some processing, and returns processed +pose. - The `save` method saves any valuable data created by or used by the processor For more details and examples, see documentation [here](dlclive/processor/README.md). -###### 🔥🔥🔥🔥🔥 Note :: alone, this object does not record video or capture images from a camera. This must be done separately, i.e. see our [DeepLabCut-live GUI](https://github.com/gkane26/DeepLabCut-live-GUI).🔥🔥🔥 - -### News! -- March 2022: DeepLabCut-Live! 1.0.2 supports poetry installation `poetry install deeplabcut-live`, thanks to PR #60. -- March 2021: DeepLabCut-Live! [**version 1.0** is released](https://pypi.org/project/deeplabcut-live/), with support for tensorflow 1 and tensorflow 2! -- Feb 2021: DeepLabCut-Live! was featured in **Nature Methods**: ["Real-time behavioral analysis"](https://www.nature.com/articles/s41592-021-01072-z) -- Jan 2021: full **eLife** paper is published: ["Real-time, low-latency closed-loop feedback using markerless posture tracking"](https://elifesciences.org/articles/61909) -- Dec 2020: we talked to **RTS Suisse Radio** about DLC-Live!: ["Capture animal movements in real time"](https://www.rts.ch/play/radio/cqfd/audio/capturer-les-mouvements-des-animaux-en-temps-reel?id=11782529) - - -### Installation: - -Please see our instruction manual to install on a [Windows or Linux machine](docs/install_desktop.md) or on a [NVIDIA Jetson Development Board](docs/install_jetson.md). Note, this code works with tensorflow (TF) 1 or TF 2 models, but TF requires that whatever version you exported your model with, you must import with the same version (i.e., export with TF1.13, then use TF1.13 with DlC-Live; export with TF2.3, then use TF2.3 with DLC-live). +**🔥🔥🔥🔥🔥 Note :: alone, this object does not record video or capture images from a +camera. This must be done separately, i.e. see our [DeepLabCut-live GUI]( +https://github.com/DeepLabCut/DeepLabCut-live-GUI).🔥🔥🔥🔥🔥** + +### News! + +- **WIP 2025**: DeepLabCut-Live is implemented for models trained with the PyTorch engine! +- March 2022: DeepLabCut-Live! 1.0.2 supports poetry installation `poetry install +deeplabcut-live`, thanks to PR #60. +- March 2021: DeepLabCut-Live! [**version 1.0** is released](https://pypi.org/project/deeplabcut-live/), with support for +tensorflow 1 and tensorflow 2! +- Feb 2021: DeepLabCut-Live! was featured in **Nature Methods**: [ +"Real-time behavioral analysis"](https://www.nature.com/articles/s41592-021-01072-z) +- Jan 2021: full **eLife** paper is published: ["Real-time, low-latency closed-loop +feedback using markerless posture tracking"](https://elifesciences.org/articles/61909) +- Dec 2020: we talked to **RTS Suisse Radio** about DLC-Live!: ["Capture animal +movements in real time"]( +https://www.rts.ch/play/radio/cqfd/audio/capturer-les-mouvements-des-animaux-en-temps-reel?id=11782529) + +### Installation + +Please see our instruction manual to install on a [Windows or Linux machine]( +docs/install_desktop.md) or on a [NVIDIA Jetson Development Board]( +docs/install_jetson.md). Note, this code works with PyTorch, TensorFlow 1 or TensorFlow +2 models, but whatever engine you exported your model with, you must import with the +same version (i.e., export a PyTorch model, then install PyTorch, export with TF1.13, +then use TF1.13 with DlC-Live; export with TF2.3, then use TF2.3 with DLC-live). - available on pypi as: `pip install deeplabcut-live` @@ -46,11 +78,25 @@ Note, you can then test your installation by running: `dlc-live-test` -If installed properly, this script will i) create a temporary folder ii) download the full_dog model from the [DeepLabCut Model Zoo](http://www.mousemotorlab.org/dlc-modelzoo), iii) download a short video clip of a dog, and iv) run inference while displaying keypoints. v) remove the temporary folder. +If installed properly, this script will i) create a temporary folder ii) download the +full_dog model from the [DeepLabCut Model Zoo]( +http://www.mousemotorlab.org/dlc-modelzoo), iii) download a short video clip of +a dog, and iv) run inference while displaying keypoints. v) remove the temporary folder. DLC LIVE TEST -### Quick Start: instructions for use: +PyTorch and TensorFlow can be installed as extras with `deeplabcut-live` - though be +careful with the versions you install! + +```bash +# Install deeplabcut-live and PyTorch +`pip install deeplabcut-live[pytorch]` + +# Install deeplabcut-live and TensorFlow +`pip install deeplabcut-live[tf]` +``` + +### Quick Start: instructions for use 1. Initialize `Processor` (if desired) 2. Initialize the `DLCLive` object @@ -66,81 +112,161 @@ dlc_live.get_pose() `DLCLive` **parameters:** - - `path` = string; full path to the exported DLC model directory - - `model_type` = string; the type of model to use for inference. Types include: - - `base` = the base DeepLabCut model - - `tensorrt` = apply [tensor-rt](https://developer.nvidia.com/tensorrt) optimizations to model - - `tflite` = use [tensorflow lite](https://www.tensorflow.org/lite) inference (in progress...) - - `cropping` = list of int, optional; cropping parameters in pixel number: [x1, x2, y1, y2] - - `dynamic` = tuple, optional; defines parameters for dynamic cropping of images - - `index 0` = use dynamic cropping, bool - - `index 1` = detection threshold, float - - `index 2` = margin (in pixels) around identified points, int - - `resize` = float, optional; factor by which to resize image (resize=0.5 downsizes both width and height of image by half). Can be used to downsize large images for faster inference - - `processor` = dlc pose processor object, optional - - `display` = bool, optional; display processed image with DeepLabCut points? Can be used to troubleshoot cropping and resizing parameters, but is very slow +- `path` = string; full path to the exported DLC model directory +- `model_type` = string; the type of model to use for inference. Types include: + - `pytorch` = the base PyTorch DeepLabCut model + - `base` = the base TensorFlow DeepLabCut model + - `tensorrt` = apply [tensor-rt](https://developer.nvidia.com/tensorrt) optimizations to model + - `tflite` = use [tensorflow lite](https://www.tensorflow.org/lite) inference (in progress...) +- `cropping` = list of int, optional; cropping parameters in pixel number: [x1, x2, y1, y2] +- `dynamic` = tuple, optional; defines parameters for dynamic cropping of images + - `index 0` = use dynamic cropping, bool + - `index 1` = detection threshold, float + - `index 2` = margin (in pixels) around identified points, int +- `resize` = float, optional; factor by which to resize image (resize=0.5 downsizes + both width and height of image by half). Can be used to downsize large images for + faster inference +- `processor` = dlc pose processor object, optional +- `display` = bool, optional; display processed image with DeepLabCut points? Can be + used to troubleshoot cropping and resizing parameters, but is very slow `DLCLive` **inputs:** - - `` = path to the folder that has the `.pb` files that you acquire after running `deeplabcut.export_model` - - `` = is a numpy array of each frame +- `` = + - For TensorFlow models: path to the folder that has the `.pb` files that you + acquire after running `deeplabcut.export_model` + - For PyTorch models: path to the `.pt` file that is generated after running + `deeplabcut.export_model` +- `` = is a numpy array of each frame + +#### DLCLive - PyTorch Specific Guide + +This guide is for users who trained a model with the PyTorch engine with +`DeepLabCut 3.0`. +Once you've trained your model in [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut) +and you are happy with its performance, you can export the model to be used for live +inference with DLCLive! + +### Switching from TensorFlow to PyTorch + +This section is for users who **have already used DeepLabCut-Live** with +TensorFlow models (through DeepLabCut 1.X or 2.X) and want to switch to using the +PyTorch Engine. Some quick notes: + +- You may need to adapt your code slightly when creating the DLCLive instance. +- Processors that were created for TensorFlow models will function the same way with +PyTorch models. As multi-animal models can be used with PyTorch, the shape of the `pose` +array given to the processor may be `(num_individuals, num_keypoints, 3)`. Just call +`DLCLive(..., single_animal=True)` and it will work. ### Benchmarking/Analyzing your exported DeepLabCut models -DeepLabCut-live offers some analysis tools that allow users to peform the following operations on videos, from python or from the command line: +DeepLabCut-live offers some analysis tools that allow users to perform the following +operations on videos, from python or from the command line: + +#### Test inference speed across a range of image sizes + +Downsizing images can be done by specifying the `resize` or `pixels` parameter. Using +the `pixels` parameter will resize images to the desired number of `pixels`, without +changing the aspect ratio. Results will be saved (along with system info) to a pickle +file if you specify an output directory. + +Inside a **python** shell or script, you can run: -1. Test inference speed across a range of image sizes, downsizing images by specifying the `resize` or `pixels` parameter. Using the `pixels` parameter will resize images to the desired number of `pixels`, without changing the aspect ratio. Results will be saved (along with system info) to a pickle file if you specify an output directory. -##### python ```python -dlclive.benchmark_videos('/path/to/exported/model', ['/path/to/video1', '/path/to/video2'], output='/path/to/output', resize=[1.0, 0.75, '0.5']) -``` -##### command line +dlclive.benchmark_videos( + "/path/to/exported/model", + ["/path/to/video1", "/path/to/video2"], + output="/path/to/output", + resize=[1.0, 0.75, '0.5'], +) ``` + +From the **command line**, you can run: + +```bash dlc-live-benchmark /path/to/exported/model /path/to/video1 /path/to/video2 -o /path/to/output -r 1.0 0.75 0.5 ``` -2. Display keypoints to visually inspect the accuracy of exported models on different image sizes (note, this is slow and only for testing purposes): +#### Display keypoints to visually inspect the accuracy of exported models on different image sizes (note, this is slow and only for testing purposes): + +Inside a **python** shell or script, you can run: -##### python ```python -dlclive.benchmark_videos('/path/to/exported/model', '/path/to/video', resize=0.5, display=True, pcutoff=0.5, display_radius=4, cmap='bmy') -``` -##### command line +dlclive.benchmark_videos( + "/path/to/exported/model", + "/path/to/video", + resize=0.5, + display=True, + pcutoff=0.5, + display_radius=4, + cmap='bmy' +) ``` + +From the **command line**, you can run: + +```bash dlc-live-benchmark /path/to/exported/model /path/to/video -r 0.5 --display --pcutoff 0.5 --display-radius 4 --cmap bmy ``` -3. Analyze and create a labeled video using the exported model and desired resize parameters. This option functions similar to `deeplabcut.benchmark_videos` and `deeplabcut.create_labeled_video` (note, this is slow and only for testing purposes). +#### Analyze and create a labeled video using the exported model and desired resize parameters. + +This option functions similar to `deeplabcut.benchmark_videos` and +`deeplabcut.create_labeled_video` (note, this is slow and only for testing purposes). + +Inside a **python** shell or script, you can run: -##### python ```python -dlclive.benchmark_videos('/path/to/exported/model', '/path/to/video', resize=[1.0, 0.75, 0.5], pcutoff=0.5, display_radius=4, cmap='bmy', save_poses=True, save_video=True) +dlclive.benchmark_videos( + "/path/to/exported/model", + "/path/to/video", + resize=[1.0, 0.75, 0.5], + pcutoff=0.5, + display_radius=4, + cmap='bmy', + save_poses=True, + save_video=True, +) ``` -##### command line + +From the **command line**, you can run: + ``` dlc-live-benchmark /path/to/exported/model /path/to/video -r 0.5 --pcutoff 0.5 --display-radius 4 --cmap bmy --save-poses --save-video ``` ## License: -This project is licensed under the GNU AGPLv3. Note that the software is provided "as is", without warranty of any kind, express or implied. If you use the code or data, we ask that you please cite us! This software is available for licensing via the EPFL Technology Transfer Office (https://tto.epfl.ch/, info.tto@epfl.ch). +This project is licensed under the GNU AGPLv3. Note that the software is provided "as +is", without warranty of any kind, express or implied. If you use the code or data, we +ask that you please cite us! This software is available for licensing via the EPFL +Technology Transfer Office (https://tto.epfl.ch/, info.tto@epfl.ch). ## Community Support, Developers, & Help: -This is an actively developed package and we welcome community development and involvement. - -- If you want to contribute to the code, please read our guide [here](https://github.com/DeepLabCut/DeepLabCut/blob/master/CONTRIBUTING.md), which is provided at the main repository of DeepLabCut. - -- We are a community partner on the [![Image.sc forum](https://img.shields.io/badge/dynamic/json.svg?label=forum&url=https%3A%2F%2Fforum.image.sc%2Ftags%2Fdeeplabcut.json&query=%24.topic_list.tags.0.topic_count&colorB=brightgreen&&suffix=%20topics&logo=)](https://forum.image.sc/tags/deeplabcut). Please post help and support questions on the forum with the tag DeepLabCut. Check out their mission statement [Scientific Community Image Forum: A discussion forum for scientific image software](https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.3000340). - -- If you encounter a previously unreported bug/code issue, please post here (we encourage you to search issues first): https://github.com/DeepLabCut/DeepLabCut-live/issues - -- For quick discussions here: [![Gitter](https://badges.gitter.im/DeepLabCut/community.svg)](https://gitter.im/DeepLabCut/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) +This is an actively developed package, and we welcome community development and +involvement. + +- If you want to contribute to the code, please read our guide [here]( +https://github.com/DeepLabCut/DeepLabCut/blob/master/CONTRIBUTING.md), which is provided +at the main repository of DeepLabCut. +- We are a community partner on the [![Image.sc forum](https://img.shields.io/badge/dynamic/json.svg?label=forum&url=https%3A%2F%2Fforum.image.sc%2Ftags%2Fdeeplabcut.json&query=%24.topic_list.tags.0.topic_count&colorB=brightgreen&&suffix=%20topics&logo=)](https://forum.image.sc/tags/deeplabcut). Please post help and +support questions on the forum with the tag DeepLabCut. Check out their mission +statement [Scientific Community Image Forum: A discussion forum for scientific image +software](https://journals.plos.org/plosbiology/article?id=10.1371/journal.pbio.3000340). +- If you encounter a previously unreported bug/code issue, please post here (we +encourage you to search issues first): [github.com/DeepLabCut/DeepLabCut-live/issues]( +https://github.com/DeepLabCut/DeepLabCut-live/issues) +- For quick discussions here: [![Gitter]( +https://badges.gitter.im/DeepLabCut/community.svg)]( +https://gitter.im/DeepLabCut/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) ### Reference: -If you utilize our tool, please [cite Kane et al, eLife 2020](https://elifesciences.org/articles/61909). The preprint is available here: https://www.biorxiv.org/content/10.1101/2020.08.04.236422v2 +If you utilize our tool, please [cite Kane et al, eLife 2020](https://elifesciences.org/articles/61909). The preprint is +available here: https://www.biorxiv.org/content/10.1101/2020.08.04.236422v2 ``` @Article{Kane2020dlclive, @@ -150,4 +276,3 @@ If you utilize our tool, please [cite Kane et al, eLife 2020](https://elifescien year = {2020}, } ``` - diff --git a/dlclive/__init__.py b/dlclive/__init__.py index 2eff208..71a89d9 100644 --- a/dlclive/__init__.py +++ b/dlclive/__init__.py @@ -5,7 +5,7 @@ Licensed under GNU Lesser General Public License v3.0 """ -from dlclive.version import __version__, VERSION +from dlclive.display import Display from dlclive.dlclive import DLCLive -from dlclive.processor import Processor -from dlclive.benchmark import benchmark, benchmark_videos, download_benchmarking_data +from dlclive.processor.processor import Processor +from dlclive.version import VERSION, __version__ diff --git a/dlclive/benchmark.py b/dlclive/benchmark.py index 4cb4fb1..2f0f2af 100644 --- a/dlclive/benchmark.py +++ b/dlclive/benchmark.py @@ -5,133 +5,93 @@ Licensed under GNU Lesser General Public License v3.0 """ - +import csv import platform -import os -import time +import subprocess import sys +import time import warnings -import subprocess -import typing -import pickle +from pathlib import Path + import colorcet as cc +import cv2 +import numpy as np +import torch from PIL import ImageColor -import ruamel +from pip._internal.operations import freeze try: - from pip._internal.operations import freeze -except ImportError: - from pip.operations import freeze - -from tqdm import tqdm -import numpy as np -import tensorflow as tf -import cv2 + import pandas as pd -from dlclive import DLCLive -from dlclive import VERSION -from dlclive import __file__ as dlcfile + has_pandas = True +except ModuleNotFoundError as err: + has_pandas = False -from dlclive.utils import decode_fourcc +try: + from tqdm import tqdm + has_tqdm = True +except ModuleNotFoundError as err: + has_tqdm = False -def download_benchmarking_data( - target_dir=".", - url="http://deeplabcut.rowland.harvard.edu/datasets/dlclivebenchmark.tar.gz", -): - """ - Downloads a DeepLabCut-Live benchmarking Data (videos & DLC models). - """ - import urllib.request - import tarfile - from tqdm import tqdm - def show_progress(count, block_size, total_size): - pbar.update(block_size) - - def tarfilenamecutting(tarf): - """' auxfun to extract folder path - ie. /xyz-trainsetxyshufflez/ - """ - for memberid, member in enumerate(tarf.getmembers()): - if memberid == 0: - parent = str(member.path) - l = len(parent) + 1 - if member.path.startswith(parent): - member.path = member.path[l:] - yield member - - response = urllib.request.urlopen(url) - print( - "Downloading the benchmarking data from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format( - url - ) - ) - total_size = int(response.getheader("Content-Length")) - pbar = tqdm(unit="B", total=total_size, position=0) - filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress) - with tarfile.open(filename, mode="r:gz") as tar: - tar.extractall(target_dir, members=tarfilenamecutting(tar)) +from dlclive import DLCLive +from dlclive.utils import decode_fourcc +from dlclive.version import VERSION def get_system_info() -> dict: - """ Return summary info for system running benchmark + """ + Returns a summary of system information relevant to running benchmarking. + Returns ------- dict - Dictionary containing the following system information: - * ``host_name`` (str): name of machine - * ``op_sys`` (str): operating system - * ``python`` (str): path to python (which conda/virtual environment) - * ``device`` (tuple): (device type (``'GPU'`` or ``'CPU'```), device information) - * ``freeze`` (list): list of installed packages and versions - * ``python_version`` (str): python version - * ``git_hash`` (str, None): If installed from git repository, hash of HEAD commit - * ``dlclive_version`` (str): dlclive version from :data:`dlclive.VERSION` + A dictionary containing the following system information: + - host_name (str): Name of the machine. + - op_sys (str): Operating system. + - python (str): Path to the Python executable, indicating the conda/virtual + environment in use. + - device_type (str): Type of device used ('GPU' or 'CPU'). + - device (list): List containing the name of the GPU or CPU brand. + - freeze (list): List of installed Python packages with their versions. + - python_version (str): Version of Python in use. + - git_hash (str or None): If installed from git repository, hash of HEAD commit. + - dlclive_version (str): Version of the DLCLive package. """ - # get os - + # Get OS and host name op_sys = platform.platform() host_name = platform.node().replace(" ", "") - # A string giving the absolute path of the executable binary for the Python interpreter, on systems where this makes sense. + # Get Python executable path if platform.system() == "Windows": host_python = sys.executable.split(os.path.sep)[-2] else: host_python = sys.executable.split(os.path.sep)[-3] - # try to get git hash if possible - dlc_basedir = os.path.dirname(os.path.dirname(dlcfile)) + # Try to get git hash if possible git_hash = None + dlc_basedir = os.path.dirname(os.path.dirname(__file__)) try: - git_hash = subprocess.check_output( - ["git", "rev-parse", "HEAD"], cwd=dlc_basedir + git_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir) + .decode("utf-8") + .strip() ) - git_hash = git_hash.decode("utf-8").rstrip("\n") except subprocess.CalledProcessError: - # not installed from git repo, eg. pypi - # fine, pass quietly + # Not installed from git repo, e.g., pypi pass - # get device info (GPU or CPU) - dev = None - if tf.test.is_gpu_available(): - gpu_name = tf.test.gpu_device_name() - from tensorflow.python.client import device_lib - - dev_desc = [ - d.physical_device_desc - for d in device_lib.list_local_devices() - if d.name == gpu_name - ] - dev = [d.split(",")[1].split(":")[1].strip() for d in dev_desc] + # Get device info (GPU or CPU) + if torch.cuda.is_available(): dev_type = "GPU" + dev = [torch.cuda.get_device_name(torch.cuda.current_device())] else: from cpuinfo import get_cpu_info - dev = [get_cpu_info()["brand"]] dev_type = "CPU" + dev = [get_cpu_info()["brand_raw"]] return { "host_name": host_name, @@ -139,7 +99,6 @@ def get_system_info() -> dict: "python": host_python, "device_type": dev_type, "device": dev, - # pip freeze to get versions of all packages "freeze": list(freeze.freeze()), "python_version": sys.version, "git_hash": git_hash, @@ -148,66 +107,78 @@ def get_system_info() -> dict: def benchmark( - model_path, - video_path, - tf_config=None, - resize=None, - pixels=None, - cropping=None, - dynamic=(False, 0.5, 10), - n_frames=1000, - print_rate=False, - display=False, - pcutoff=0.0, - display_radius=3, - cmap="bmy", - save_poses=False, - save_video=False, - output=None, -) -> typing.Tuple[np.ndarray, tuple, bool, dict]: - """ Analyze DeepLabCut-live exported model on a video: - Calculate inference time, - display keypoints, or - get poses/create a labeled video + path: str | Path, + video_path: str | Path, + single_animal: bool = True, + resize: float | None = None, + pixels: int | None = None, + cropping: list[int] = None, + dynamic: tuple[bool, float, int] = (False, 0.5, 10), + n_frames: int = 1000, + print_rate: bool = False, + display: bool = False, + pcutoff: float = 0.0, + max_detections: int = 10, + display_radius: int = 3, + cmap: str = "bmy", + save_poses: bool = False, + save_video: bool = False, + output: str | Path | None = None, +) -> tuple[np.ndarray, tuple, dict]: + """Analyze DeepLabCut-live exported model on a video: + + Calculate inference time, display keypoints, or get poses/create a labeled video. Parameters ---------- - model_path : str + path : str path to exported DeepLabCut model video_path : str path to video file - tf_config : :class:`tensorflow.ConfigProto` - tensorflow session configuration + single_animal: bool + to make code behave like DLCLive for tensorflow models resize : int, optional - resize factor. Can only use one of resize or pixels. If both are provided, will use pixels. by default None + Resize factor. Can only use one of resize or pixels. If both are provided, will + use pixels. by default None pixels : int, optional - downsize image to this number of pixels, maintaining aspect ratio. Can only use one of resize or pixels. If both are provided, will use pixels. by default None + Downsize image to this number of pixels, maintaining aspect ratio. Can only use + one of resize or pixels. If both are provided, will use pixels. by default None cropping : list of int cropping parameters in pixel number: [x1, x2, y1, y2] dynamic: triple containing (state, detectiontreshold, margin) - If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), - then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is - expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. detectiontreshold), then object + boundaries are computed according to the smallest/largest x position and + smallest/largest y position of all body parts. This window is expanded by the + margin and from then on only the posture within this crop is analyzed (until the + object is lost, i.e. < detectiontreshold). The current position is utilized for + updating the crop window for the next frame (this is why the margin is important + and should be set large enough given the movement of the animal) n_frames : int, optional number of frames to run inference on, by default 1000 print_rate : bool, optional - flat to print inference rate frame by frame, by default False + flag to print inference rate frame by frame, by default False display : bool, optional - flag to display keypoints on images. Useful for checking the accuracy of exported models. + flag to display keypoints on images. Useful for checking the accuracy of + exported models. pcutoff : float, optional likelihood threshold to display keypoints + max_detections: int + for top-down models, the maximum number of individuals to detect in a frame display_radius : int, optional size (radius in pixels) of keypoint to display cmap : str, optional - a string indicating the :package:`colorcet` colormap, `options here `, by default "bmy" + a string indicating the :package:`colorcet` colormap, `options here + `, by default "bmy" save_poses : bool, optional - flag to save poses to an hdf5 file. If True, operates similar to :function:`DeepLabCut.benchmark_videos`, by default False + flag to save poses to an hdf5 file. If True, operates similar to + :function:`DeepLabCut.benchmark_videos`, by default False save_video : bool, optional - flag to save a labeled video. If True, operates similar to :function:`DeepLabCut.create_labeled_video`, by default False + flag to save a labeled video. If True, operates similar to + :function:`DeepLabCut.create_labeled_video`, by default False output : str, optional - path to directory to save pose and/or video file. If not specified, will use the directory of video_path, by default None + path to directory to save pose and/or video file. If not specified, will use + the directory of video_path, by default None Returns ------- @@ -215,8 +186,6 @@ def benchmark( vector of inference times tuple (image width, image height) - bool - tensorflow inference flag dict metadata for video @@ -234,10 +203,19 @@ def benchmark( Analyze a video (save poses to hdf5) and create a labeled video, similar to :function:`DeepLabCut.benchmark_videos` and :function:`create_labeled_video` dlclive.benchmark('/my/exported/model', 'my_video.avi', save_poses=True, save_video=True) """ + path = Path(path) + video_path = Path(video_path) + if not video_path.exists(): + raise ValueError(f"Could not find video: {video_path}: check that it exists!") - ### load video + if output is None: + output = video_path.parent + else: + output = Path(output) + output.mkdir(exist_ok=True, parents=True) - cap = cv2.VideoCapture(video_path) + # load video + cap = cv2.VideoCapture(str(video_path)) ret, frame = cap.read() n_frames = ( n_frames @@ -245,112 +223,107 @@ def benchmark( else (cap.get(cv2.CAP_PROP_FRAME_COUNT) - 1) ) n_frames = int(n_frames) - im_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - - ### get resize factor + im_size = ( + int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), + ) + # get resize factor if pixels is not None: resize = np.sqrt(pixels / (im_size[0] * im_size[1])) + if resize is not None: im_size = (int(im_size[0] * resize), int(im_size[1] * resize)) - ### create video writer - + # create video writer if save_video: colors = None - out_dir = ( - output - if output is not None - else os.path.dirname(os.path.realpath(video_path)) - ) - out_vid_base = os.path.basename(video_path) - out_vid_file = os.path.normpath( - f"{out_dir}/{os.path.splitext(out_vid_base)[0]}_DLCLIVE_LABELED.avi" - ) + out_vid_file = output / f"{video_path.stem}_DLCLIVE_LABELED.avi" fourcc = cv2.VideoWriter_fourcc(*"DIVX") fps = cap.get(cv2.CAP_PROP_FPS) - vwriter = cv2.VideoWriter(out_vid_file, fourcc, fps, im_size) - - ### check for pandas installation if using save_poses flag - - if save_poses: - try: - import pandas as pd - - use_pandas = True - except: - use_pandas = False - warnings.warn( - "Could not find installation of pandas; saving poses as a numpy array with the dimensions (n_frames, n_keypoints, [x, y, likelihood])." - ) - - ### initialize DLCLive and perform inference + print(out_vid_file) + print(fourcc) + print(fps) + print(im_size) + vid_writer = cv2.VideoWriter(str(out_vid_file), fourcc, fps, im_size) + # initialize DLCLive and perform inference inf_times = np.zeros(n_frames) poses = [] live = DLCLive( - model_path, - tf_config=tf_config, + model_path=path, + single_animal=single_animal, resize=resize, cropping=cropping, dynamic=dynamic, display=display, + max_detections=max_detections, pcutoff=pcutoff, display_radius=display_radius, display_cmap=cmap, ) poses.append(live.init_inference(frame)) - TFGPUinference = True if len(live.outputs) == 1 else False - iterator = range(n_frames) if (print_rate) or (display) else tqdm(range(n_frames)) - for i in iterator: + iterator = range(n_frames) + if print_rate or display: + iterator = tqdm(iterator) + for i in iterator: ret, frame = cap.read() - if not ret: warnings.warn( - "Did not complete {:d} frames. There probably were not enough frames in the video {}.".format( - n_frames, video_path - ) + f"Did not complete {n_frames:d} frames. There probably were not enough " + f"frames in the video {video_path}." ) break start_pose = time.time() poses.append(live.get_pose(frame)) inf_times[i] = time.time() - start_pose - if save_video: + this_pose = poses[-1] + + if single_animal: + # expand individual dimension + this_pose = this_pose[None] + + num_idv, num_bpt = this_pose.shape[:2] + num_colors = num_bpt if colors is None: all_colors = getattr(cc, cmap) colors = [ ImageColor.getcolor(c, "RGB")[::-1] - for c in all_colors[:: int(len(all_colors) / poses[-1].shape[0])] + for c in all_colors[:: int(len(all_colors) / num_colors)] ] - this_pose = poses[-1] - for j in range(this_pose.shape[0]): - if this_pose[j, 2] > pcutoff: - x = int(this_pose[j, 0]) - y = int(this_pose[j, 1]) - frame = cv2.circle( - frame, (x, y), display_radius, colors[j], thickness=-1 - ) + for j in range(num_idv): + for k in range(num_bpt): + color_idx = k + if this_pose[j, k, 2] > pcutoff: + x = int(this_pose[j, k, 0]) + y = int(this_pose[j, k, 1]) + frame = cv2.circle( + frame, + (x, y), + display_radius, + colors[color_idx], + thickness=-1, + ) if resize is not None: frame = cv2.resize(frame, im_size) - vwriter.write(frame) + vid_writer.write(frame) if print_rate: - print("pose rate = {:d}".format(int(1 / inf_times[i]))) + print(f"pose rate = {int(1 / inf_times[i]):d}") if print_rate: - print("mean pose rate = {:d}".format(int(np.mean(1 / inf_times)))) - - ### gather video and test parameterization + print(f"mean pose rate = {int(np.mean(1 / inf_times)):d}") + # gather video and test parameterization # dont want to fail here so gracefully failing on exception -- # eg. some packages of cv2 don't have CAP_PROP_CODEC_PIXEL_FORMAT try: @@ -360,17 +333,17 @@ def benchmark( try: fps = round(cap.get(cv2.CAP_PROP_FPS)) - except: + except Exception: fps = None try: pix_fmt = decode_fourcc(cap.get(cv2.CAP_PROP_CODEC_PIXEL_FORMAT)) - except: + except Exception: pix_fmt = "" try: frame_count = round(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - except: + except Exception: frame_count = None try: @@ -378,7 +351,7 @@ def benchmark( round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), ) - except: + except Exception: orig_im_size = None meta = { @@ -391,333 +364,408 @@ def benchmark( "dlclive_params": live.parameterization, } - ### close video and tensorflow session - + # close video cap.release() - live.close() - if save_video: - vwriter.release() + vid_writer.release() if save_poses: - - cfg_path = os.path.normpath(f"{model_path}/pose_cfg.yaml") - ruamel_file = ruamel.yaml.YAML() - dlc_cfg = ruamel_file.load(open(cfg_path, "r")) - bodyparts = dlc_cfg["all_joints_names"] - poses = np.array(poses) - - if use_pandas: - - poses = poses.reshape((poses.shape[0], poses.shape[1] * poses.shape[2])) - pdindex = pd.MultiIndex.from_product( - [bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"] - ) - pose_df = pd.DataFrame(poses, columns=pdindex) - - out_dir = ( - output - if output is not None - else os.path.dirname(os.path.realpath(video_path)) - ) - out_vid_base = os.path.basename(video_path) - out_dlc_file = os.path.normpath( - f"{out_dir}/{os.path.splitext(out_vid_base)[0]}_DLCLIVE_POSES.h5" + bodyparts = live.cfg["metadata"]["bodyparts"] + max_idv = np.max([p.shape[0] for p in poses]) + + poses_array = -np.ones((len(poses), max_idv, len(bodyparts), 3)) + for i, p in enumerate(poses): + num_det = len(p) + poses_array[i, :num_det] = p + poses = poses_array + + num_frames, num_idv, num_bpts = poses.shape[:3] + individuals = [f"individual-{i}" for i in range(num_idv)] + + if has_pandas: + poses = poses.reshape((num_frames, num_idv * num_bpts * 3)) + col_index = pd.MultiIndex.from_product( + [individuals, bodyparts, ["x", "y", "likelihood"]], + names=["individual", "bodyparts", "coords"], ) - pose_df.to_hdf(out_dlc_file, key="df_with_missing", mode="w") + pose_df = pd.DataFrame(poses, columns=col_index) + + out_dlc_file = output / (video_path.stem + "_DLCLIVE_POSES.h5") + try: + pose_df.to_hdf(out_dlc_file, key="df_with_missing", mode="w") + except ImportError as err: + print( + "Cannot export predictions to H5 file. Install ``pytables`` extra " + f"to export to HDF: {err}" + ) + out_csv = Path(out_dlc_file).with_suffix(".csv") + pose_df.to_csv(out_csv) else: - - out_vid_base = os.path.basename(video_path) - out_dlc_file = os.path.normpath( - f"{out_dir}/{os.path.splitext(out_vid_base)[0]}_DLCLIVE_POSES.npy" + warnings.warn( + "Could not find installation of pandas; saving poses as a numpy array " + "with the dimensions (n_frames, n_keypoints, [x, y, likelihood])." ) - np.save(out_dlc_file, poses) + np.save(str(output / (video_path.stem + "_DLCLIVE_POSES.npy")), poses) - return inf_times, im_size, TFGPUinference, meta + return inf_times, im_size, meta -def save_inf_times( - sys_info, inf_times, im_size, TFGPUinference, model=None, meta=None, output=None +def benchmark_videos( + video_path: str, + model_path: str, + model_type: str, + device: str, + precision: str = "FP32", + display=True, + pcutoff=0.5, + display_radius=5, + resize=None, + cropping=None, # Adding cropping to the function parameters + dynamic=(False, 0.5, 10), + save_poses=False, + save_dir="model_predictions", + draw_keypoint_names=False, + cmap="bmy", + get_sys_info=True, + save_video=False, ): - """ Save inference time data collected using :function:`benchmark` with system information to a pickle file. - This is primarily used through :function:`benchmark_videos` - + """ + Analyzes a video to track keypoints using a DeepLabCut model, and optionally saves + the keypoint data and the labeled video. Parameters ---------- - sys_info : tuple - system information generated by :func:`get_system_info` - inf_times : :class:`numpy.ndarray` - array of inference times generated by :func:`benchmark` - im_size : tuple or :class:`numpy.ndarray` - image size (width, height) for each benchmark run. If an array, each row corresponds to a row in inf_times - TFGPUinference: bool - flag if using tensorflow inference or numpy inference DLC model - model: str, optional - name of model - meta : dict, optional - metadata returned by :func:`benchmark` - output : str, optional - path to directory to save data. If None, uses pwd, by default None + video_path : str + Path to the video file to be analyzed. + model_path : str + Path to the DeepLabCut model. + model_type : str + Type of the model (e.g., 'onnx'). + device : str + Device to run the model on ('cpu' or 'cuda'). + precision : str, optional, default='FP32' + Precision type for the model ('FP32' or 'FP16'). + display : bool, optional, default=True + Whether to display frame with labelled key points. + pcutoff : float, optional, default=0.5 + Probability cutoff below which keypoints are not visualized. + display_radius : int, optional, default=5 + Radius of circles drawn for keypoints on video frames. + resize : tuple of int (width, height) or None, optional + Resize dimensions for video frames. e.g. if resize = 0.5, the video will be + processed in half the original size. If None, no resizing is applied. + cropping : list of int or None, optional + Cropping parameters [x1, x2, y1, y2] in pixels. If None, no cropping is applied. + dynamic : tuple, optional, default=(False, 0.5, 10) (True/false), p cutoff, margin) + Parameters for dynamic cropping. If the state is true, then dynamic cropping + will be performed. That means that if an object is detected (i.e. any body part + > detectiontreshold), then object boundaries are computed according to the + smallest/largest x position and smallest/largest y position of all body parts. + This window is expanded by the margin and from then on only the posture within + this crop is analyzed (until the object is lost, i.e. detectiontreshold), - then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is - expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. `, by default "bmy" - save_poses : bool, optional - flag to save poses to an hdf5 file. If True, operates similar to :function:`DeepLabCut.benchmark_videos`, by default False - save_video : bool, optional - flag to save a labeled video. If True, operates similar to :function:`DeepLabCut.create_labeled_video`, by default False + # Get video writer setup + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + fps = cap.get(cv2.CAP_PROP_FPS) + frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + vwriter = cv2.VideoWriter( + filename=output_video_path, + fourcc=fourcc, + fps=fps, + frameSize=(frame_width, frame_height), + ) - Example - ------- - Return a vector of inference times for 10000 frames on one video or two videos: - dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', n_frames=10000) - dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000) + while True: + ret, frame = cap.read() + if not ret: + break + # if frame_index == 0: + # pose = dlc_live.init_inference(frame) # load DLC model + try: + # pose = dlc_live.get_pose(frame) + if frame_index == 0: + # TODO trying to fix issues with dynamic cropping jumping back and forth + # between dyanmic cropped and original image + # dlc_live.dynamic = (False, dynamic[1], dynamic[2]) + pose, inf_time = dlc_live.init_inference(frame) # load DLC model + else: + # dlc_live.dynamic = dynamic + pose, inf_time = dlc_live.get_pose(frame) + except Exception as e: + print(f"Error analyzing frame {frame_index}: {e}") + continue + + poses.append({"frame": frame_index, "pose": pose}) + times.append(inf_time) + + if save_video: + # Visualize keypoints + this_pose = pose["poses"][0][0] + for j in range(this_pose.shape[0]): + if this_pose[j, 2] > pcutoff: + x, y = map(int, this_pose[j, :2]) + cv2.circle( + frame, + center=(x, y), + radius=display_radius, + color=colors[j], + thickness=-1, + ) - Return a vector of inference times, testing full size and resizing images to half the width and height for inference, for two videos - dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000, resize=[1.0, 0.5]) + if draw_keypoint_names: + cv2.putText( + frame, + text=bodyparts[j], + org=(x + 10, y), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=0.5, + color=colors[j], + thickness=1, + lineType=cv2.LINE_AA, + ) + + vwriter.write(image=frame) + frame_index += 1 - Display keypoints to check the accuracy of an exported model - dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', display=True) + cap.release() + if save_video: + vwriter.release() - Analyze a video (save poses to hdf5) and create a labeled video, similar to :function:`DeepLabCut.benchmark_videos` and :function:`create_labeled_video` - dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', save_poses=True, save_video=True) - """ + if get_sys_info: + print(get_system_info()) - # convert video_paths to list + if save_poses: + save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp=timestamp) - video_path = video_path if type(video_path) is list else [video_path] + return poses, times - # fix resize - if pixels: - pixels = pixels if type(pixels) is list else [pixels] - resize = [None for p in pixels] - elif resize: - resize = resize if type(resize) is list else [resize] - pixels = [None for r in resize] - else: - resize = [None] - pixels = [None] - - # loop over videos - - for v in video_path: - - # initialize full inference times - - inf_times = [] - im_size_out = [] - - for i in range(len(resize)): - - print(f"\nRun {i+1} / {len(resize)}\n") - - this_inf_times, this_im_size, TFGPUinference, meta = benchmark( - model_path, - v, - tf_config=tf_config, - resize=resize[i], - pixels=pixels[i], - cropping=cropping, - dynamic=dynamic, - n_frames=n_frames, - print_rate=print_rate, - display=display, - pcutoff=pcutoff, - display_radius=display_radius, - cmap=cmap, - save_poses=save_poses, - save_video=save_video, - output=output, - ) +def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp): + """ + Saves the detected keypoint poses from the video to CSV and HDF5 files. - inf_times.append(this_inf_times) - im_size_out.append(this_im_size) + Parameters + ---------- + video_path : str + Path to the analyzed video file. + save_dir : str + Directory where the pose data files will be saved. + bodyparts : list of str + List of body part names corresponding to the keypoints. + poses : list of dict + List of dictionaries containing frame numbers and corresponding pose data. - inf_times = np.array(inf_times) - im_size_out = np.array(im_size_out) + Returns + ------- + None + """ - # save results + base_filename = os.path.splitext(os.path.basename(video_path))[0] + csv_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.csv") + h5_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.h5") - if output is not None: - sys_info = get_system_info() - save_inf_times( - sys_info, - inf_times, - im_size_out, - TFGPUinference, - model=os.path.basename(model_path), - meta=meta, - output=output, - ) + # Save to CSV + with open(csv_save_path, mode="w", newline="") as file: + writer = csv.writer(file) + header = ["frame"] + [ + f"{bp}_{axis}" for bp in bodyparts for axis in ["x", "y", "confidence"] + ] + writer.writerow(header) + for entry in poses: + frame_num = entry["frame"] + pose = entry["pose"]["poses"][0][0] + row = [frame_num] + [ + item.item() if isinstance(item, torch.Tensor) else item + for kp in pose + for item in kp + ] + writer.writerow(row) + + +import argparse +import os def main(): - """Provides a command line interface :function:`benchmark_videos` - """ + """Provides a command line interface to benchmark_videos function.""" + parser = argparse.ArgumentParser( + description="Analyze a video using a DeepLabCut model and visualize keypoints." + ) + parser.add_argument("model_path", type=str, help="Path to the model.") + parser.add_argument("video_path", type=str, help="Path to the video file.") + parser.add_argument("model_type", type=str, help="Type of the model (e.g., 'DLC').") + parser.add_argument( + "device", type=str, help="Device to run the model on (e.g., 'cuda' or 'cpu')." + ) + parser.add_argument( + "-p", + "--precision", + type=str, + default="FP32", + help="Model precision (e.g., 'FP32', 'FP16').", + ) + parser.add_argument( + "-d", "--display", action="store_true", help="Display keypoints on the video." + ) + parser.add_argument( + "-c", + "--pcutoff", + type=float, + default=0.5, + help="Probability cutoff for keypoints visualization.", + ) + parser.add_argument( + "-dr", + "--display-radius", + type=int, + default=5, + help="Radius of keypoint circles in the display.", + ) + parser.add_argument( + "-r", + "--resize", + type=int, + default=None, + help="Resize video frames to [width, height].", + ) + parser.add_argument( + "-x", + "--cropping", + type=int, + nargs=4, + default=None, + help="Cropping parameters [x1, x2, y1, y2].", + ) + parser.add_argument( + "-y", + "--dynamic", + type=float, + nargs=3, + default=[False, 0.5, 10], + help="Dynamic cropping [flag, pcutoff, margin].", + ) + parser.add_argument( + "--save-poses", action="store_true", help="Save the keypoint poses to files." + ) + parser.add_argument( + "--save-video", + action="store_true", + help="Save the output video with keypoints.", + ) + parser.add_argument( + "--save-dir", + type=str, + default="model_predictions", + help="Directory to save output files.", + ) + parser.add_argument( + "--draw-keypoint-names", + action="store_true", + help="Draw keypoint names on the video.", + ) + parser.add_argument( + "--cmap", type=str, default="bmy", help="Colormap for keypoints visualization." + ) + parser.add_argument( + "--no-sys-info", + action="store_false", + help="Do not print system info.", + dest="get_sys_info", + ) - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("model_path", type=str) - parser.add_argument("video_path", type=str, nargs="+") - parser.add_argument("-o", "--output", type=str, default=None) - parser.add_argument("-n", "--n-frames", type=int, default=1000) - parser.add_argument("-r", "--resize", type=float, nargs="+") - parser.add_argument("-p", "--pixels", type=float, nargs="+") - parser.add_argument("-v", "--print-rate", default=False, action="store_true") - parser.add_argument("-d", "--display", default=False, action="store_true") - parser.add_argument("-l", "--pcutoff", default=0.5, type=float) - parser.add_argument("-s", "--display-radius", default=3, type=int) - parser.add_argument("-c", "--cmap", type=str, default="bmy") - parser.add_argument("--cropping", nargs="+", type=int, default=None) - parser.add_argument("--dynamic", nargs="+", type=float, default=[]) - parser.add_argument("--save-poses", action="store_true") - parser.add_argument("--save-video", action="store_true") args = parser.parse_args() - if (args.cropping) and (len(args.cropping) < 4): - raise Exception( - "Cropping not properly specified. Must provide 4 values: x1, x2, y1, y2" - ) - - if not args.dynamic: - args.dynamic = (False, 0.5, 10) - elif len(args.dynamic) < 3: - raise Exception( - "Dynamic cropping not properly specified. Must provide three values: 0 or 1 as boolean flag, pcutoff, and margin" - ) - else: - args.dynamic = (bool(args.dynamic[0]), args.dynamic[1], args.dynamic[2]) - + # Call the benchmark_videos function with the parsed arguments benchmark_videos( - args.model_path, - args.video_path, - output=args.output, - resize=args.resize, - pixels=args.pixels, - cropping=args.cropping, - dynamic=args.dynamic, - n_frames=args.n_frames, - print_rate=args.print_rate, + video_path=args.video_path, + model_path=args.model_path, + model_type=args.model_type, + device=args.device, + precision=args.precision, display=args.display, pcutoff=args.pcutoff, display_radius=args.display_radius, - cmap=args.cmap, + resize=tuple(args.resize) if args.resize else None, + cropping=args.cropping, + dynamic=tuple(args.dynamic), save_poses=args.save_poses, + save_dir=args.save_dir, + draw_keypoint_names=args.draw_keypoint_names, + cmap=args.cmap, + get_sys_info=args.get_sys_info, save_video=args.save_video, ) diff --git a/dlclive/benchmark_pytorch.py b/dlclive/benchmark_pytorch.py new file mode 100644 index 0000000..bd5826f --- /dev/null +++ b/dlclive/benchmark_pytorch.py @@ -0,0 +1,484 @@ +import csv +import os +import platform +import subprocess +import sys +import time + +import colorcet as cc +import cv2 +import h5py +import numpy as np +import torch +from PIL import ImageColor +from pip._internal.operations import freeze + +from dlclive import DLCLive +from dlclive.version import VERSION + + +def get_system_info() -> dict: + """ + Returns a summary of system information relevant to running benchmarking. + + Returns + ------- + dict + A dictionary containing the following system information: + - host_name (str): Name of the machine. + - op_sys (str): Operating system. + - python (str): Path to the Python executable, indicating the conda/virtual environment in use. + - device_type (str): Type of device used ('GPU' or 'CPU'). + - device (list): List containing the name of the GPU or CPU brand. + - freeze (list): List of installed Python packages with their versions. + - python_version (str): Version of Python in use. + - git_hash (str or None): If installed from git repository, hash of HEAD commit. + - dlclive_version (str): Version of the DLCLive package. + """ + + # Get OS and host name + op_sys = platform.platform() + host_name = platform.node().replace(" ", "") + + # Get Python executable path + if platform.system() == "Windows": + host_python = sys.executable.split(os.path.sep)[-2] + else: + host_python = sys.executable.split(os.path.sep)[-3] + + # Try to get git hash if possible + git_hash = None + dlc_basedir = os.path.dirname(os.path.dirname(__file__)) + try: + git_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir) + .decode("utf-8") + .strip() + ) + except subprocess.CalledProcessError: + # Not installed from git repo, e.g., pypi + pass + + # Get device info (GPU or CPU) + if torch.cuda.is_available(): + dev_type = "GPU" + dev = [torch.cuda.get_device_name(torch.cuda.current_device())] + else: + from cpuinfo import get_cpu_info + + dev_type = "CPU" + dev = [get_cpu_info()["brand_raw"]] + + return { + "host_name": host_name, + "op_sys": op_sys, + "python": host_python, + "device_type": dev_type, + "device": dev, + "freeze": list(freeze.freeze()), + "python_version": sys.version, + "git_hash": git_hash, + "dlclive_version": VERSION, + } + + +def analyze_video( + video_path: str, + model_path: str, + model_type: str, + device: str, + precision: str = "FP32", + snapshot: str = None, + display=True, + pcutoff=0.5, + display_radius=5, + resize=None, + cropping=None, # Adding cropping to the function parameters + dynamic=(False, 0.5, 10), + save_poses=False, + save_dir="model_predictions", + draw_keypoint_names=False, + cmap="bmy", + get_sys_info=True, + save_video=False, +): + """ + Analyzes a video to track keypoints using a DeepLabCut model, and optionally saves the keypoint data and the labeled video. + + Parameters + ---------- + video_path : str + Path to the video file to be analyzed. + model_path : str + Path to the DeepLabCut model. + model_type : str + Type of the model (e.g., 'onnx'). + device : str + Device to run the model on ('cpu' or 'cuda'). + precision : str, optional, default='FP32' + Precision type for the model ('FP32' or 'FP16'). + snapshot : str, optional + Snapshot to use for the model, if using pytorch as model type. + display : bool, optional, default=True + Whether to display frame with labelled key points. + pcutoff : float, optional, default=0.5 + Probability cutoff below which keypoints are not visualized. + display_radius : int, optional, default=5 + Radius of circles drawn for keypoints on video frames. + resize : tuple of int (width, height) or None, optional + Resize dimensions for video frames. e.g. if resize = 0.5, the video will be processed in half the original size. If None, no resizing is applied. + cropping : list of int or None, optional + Cropping parameters [x1, x2, y1, y2] in pixels. If None, no cropping is applied. + dynamic : tuple, optional, default=(False, 0.5, 10) (True/false), p cutoff, margin) + Parameters for dynamic cropping. If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. pcutoff: + x, y = map(int, this_pose[j, :2]) + cv2.circle( + frame, + center=(x, y), + radius=display_radius, + color=colors[j], + thickness=-1, + ) + + if draw_keypoint_names: + cv2.putText( + frame, + text=bodyparts[j], + org=(x + 10, y), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=0.5, + color=colors[j], + thickness=1, + lineType=cv2.LINE_AA, + ) + + vwriter.write(image=frame) + frame_index += 1 + + cap.release() + if save_video: + vwriter.release() + + if get_sys_info: + print(get_system_info()) + + if save_poses: + save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp=timestamp) + + return poses, times + + +def save_poses_to_files(video_path, save_dir, bodyparts, poses, timestamp): + """ + Saves the detected keypoint poses from the video to CSV and HDF5 files. + + Parameters + ---------- + video_path : str + Path to the analyzed video file. + save_dir : str + Directory where the pose data files will be saved. + bodyparts : list of str + List of body part names corresponding to the keypoints. + poses : list of dict + List of dictionaries containing frame numbers and corresponding pose data. + + Returns + ------- + None + """ + + base_filename = os.path.splitext(os.path.basename(video_path))[0] + csv_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.csv") + h5_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.h5") + + # Save to CSV + with open(csv_save_path, mode="w", newline="") as file: + writer = csv.writer(file) + header = ["frame"] + [ + f"{bp}_{axis}" for bp in bodyparts for axis in ["x", "y", "confidence"] + ] + writer.writerow(header) + for entry in poses: + frame_num = entry["frame"] + pose = entry["pose"]["poses"][0][0] + row = [frame_num] + [ + item.item() if isinstance(item, torch.Tensor) else item + for kp in pose + for item in kp + ] + writer.writerow(row) + + # Save to HDF5 + with h5py.File(h5_save_path, "w") as hf: + hf.create_dataset(name="frames", data=[entry["frame"] for entry in poses]) + for i, bp in enumerate(bodyparts): + hf.create_dataset( + name=f"{bp}_x", + data=[ + ( + entry["pose"]["poses"][0][0][i, 0].item() + if isinstance(entry["pose"]["poses"][0][0][i, 0], torch.Tensor) + else entry["pose"]["poses"][0][0][i, 0] + ) + for entry in poses + ], + ) + hf.create_dataset( + name=f"{bp}_y", + data=[ + ( + entry["pose"]["poses"][0][0][i, 1].item() + if isinstance(entry["pose"]["poses"][0][0][i, 1], torch.Tensor) + else entry["pose"]["poses"][0][0][i, 1] + ) + for entry in poses + ], + ) + hf.create_dataset( + name=f"{bp}_confidence", + data=[ + ( + entry["pose"]["poses"][0][0][i, 2].item() + if isinstance(entry["pose"]["poses"][0][0][i, 2], torch.Tensor) + else entry["pose"]["poses"][0][0][i, 2] + ) + for entry in poses + ], + ) + + +import argparse +import os + + +def main(): + """Provides a command line interface to analyze_video function.""" + + parser = argparse.ArgumentParser( + description="Analyze a video using a DeepLabCut model and visualize keypoints." + ) + parser.add_argument("model_path", type=str, help="Path to the model.") + parser.add_argument("video_path", type=str, help="Path to the video file.") + parser.add_argument("model_type", type=str, help="Type of the model (e.g., 'DLC').") + parser.add_argument( + "device", type=str, help="Device to run the model on (e.g., 'cuda' or 'cpu')." + ) + parser.add_argument( + "-p", + "--precision", + type=str, + default="FP32", + help="Model precision (e.g., 'FP32', 'FP16').", + ) + parser.add_argument( + "-s", + "--snapshot", + type=str, + default=None, + help="Path to a specific model snapshot.", + ) + parser.add_argument( + "-d", "--display", action="store_true", help="Display keypoints on the video." + ) + parser.add_argument( + "-c", + "--pcutoff", + type=float, + default=0.5, + help="Probability cutoff for keypoints visualization.", + ) + parser.add_argument( + "-dr", + "--display-radius", + type=int, + default=5, + help="Radius of keypoint circles in the display.", + ) + parser.add_argument( + "-r", + "--resize", + type=int, + default=None, + help="Resize video frames to [width, height].", + ) + parser.add_argument( + "-x", + "--cropping", + type=int, + nargs=4, + default=None, + help="Cropping parameters [x1, x2, y1, y2].", + ) + parser.add_argument( + "-y", + "--dynamic", + type=float, + nargs=3, + default=[False, 0.5, 10], + help="Dynamic cropping [flag, pcutoff, margin].", + ) + parser.add_argument( + "--save-poses", action="store_true", help="Save the keypoint poses to files." + ) + parser.add_argument( + "--save-video", + action="store_true", + help="Save the output video with keypoints.", + ) + parser.add_argument( + "--save-dir", + type=str, + default="model_predictions", + help="Directory to save output files.", + ) + parser.add_argument( + "--draw-keypoint-names", + action="store_true", + help="Draw keypoint names on the video.", + ) + parser.add_argument( + "--cmap", type=str, default="bmy", help="Colormap for keypoints visualization." + ) + parser.add_argument( + "--no-sys-info", + action="store_false", + help="Do not print system info.", + dest="get_sys_info", + ) + + args = parser.parse_args() + + # Call the analyze_video function with the parsed arguments + analyze_video( + video_path=args.video_path, + model_path=args.model_path, + model_type=args.model_type, + device=args.device, + precision=args.precision, + snapshot=args.snapshot, + display=args.display, + pcutoff=args.pcutoff, + display_radius=args.display_radius, + resize=tuple(args.resize) if args.resize else None, + cropping=args.cropping, + dynamic=tuple(args.dynamic), + save_poses=args.save_poses, + save_dir=args.save_dir, + draw_keypoint_names=args.draw_keypoint_names, + cmap=args.cmap, + get_sys_info=args.get_sys_info, + save_video=args.save_video, + ) + + +if __name__ == "__main__": + main() diff --git a/dlclive/benchmark_tf.py b/dlclive/benchmark_tf.py new file mode 100644 index 0000000..d955496 --- /dev/null +++ b/dlclive/benchmark_tf.py @@ -0,0 +1,717 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" + +import os +import pickle +import platform +import subprocess +import sys +import time +import typing +import warnings + +import colorcet as cc +import ruamel +from PIL import ImageColor + +try: + from pip._internal.operations import freeze +except ImportError: + from pip.operations import freeze + +import cv2 +import numpy as np +import tensorflow as tf +from dlclive import VERSION, DLCLive +from dlclive import __file__ as dlcfile +from dlclive.utils import decode_fourcc +from tqdm import tqdm + + +def download_benchmarking_data( + target_dir=".", + url="http://deeplabcut.rowland.harvard.edu/datasets/dlclivebenchmark.tar.gz", +): + """ + Downloads a DeepLabCut-Live benchmarking Data (videos & DLC models). + """ + import tarfile + import urllib.request + + from tqdm import tqdm + + def show_progress(count, block_size, total_size): + pbar.update(block_size) + + def tarfilenamecutting(tarf): + """' auxfun to extract folder path + ie. /xyz-trainsetxyshufflez/ + """ + for memberid, member in enumerate(tarf.getmembers()): + if memberid == 0: + parent = str(member.path) + l = len(parent) + 1 + if member.path.startswith(parent): + member.path = member.path[l:] + yield member + + response = urllib.request.urlopen(url) + print( + "Downloading the benchmarking data from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format( + url + ) + ) + total_size = int(response.getheader("Content-Length")) + pbar = tqdm(unit="B", total=total_size, position=0) + filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress) + with tarfile.open(filename, mode="r:gz") as tar: + tar.extractall(target_dir, members=tarfilenamecutting(tar)) + + +def get_system_info() -> dict: + """Return summary info for system running benchmark + Returns + ------- + dict + Dictionary containing the following system information: + * ``host_name`` (str): name of machine + * ``op_sys`` (str): operating system + * ``python`` (str): path to python (which conda/virtual environment) + * ``device`` (tuple): (device type (``'GPU'`` or ``'CPU'```), device information) + * ``freeze`` (list): list of installed packages and versions + * ``python_version`` (str): python version + * ``git_hash`` (str, None): If installed from git repository, hash of HEAD commit + * ``dlclive_version`` (str): dlclive version from :data:`dlclive.VERSION` + """ + + # get os + + op_sys = platform.platform() + host_name = platform.node().replace(" ", "") + + # A string giving the absolute path of the executable binary for the Python interpreter, on systems where this makes sense. + if platform.system() == "Windows": + host_python = sys.executable.split(os.path.sep)[-2] + else: + host_python = sys.executable.split(os.path.sep)[-3] + + # try to get git hash if possible + dlc_basedir = os.path.dirname(os.path.dirname(dlcfile)) + git_hash = None + try: + git_hash = subprocess.check_output( + ["git", "rev-parse", "HEAD"], cwd=dlc_basedir + ) + git_hash = git_hash.decode("utf-8").rstrip("\n") + except subprocess.CalledProcessError: + # not installed from git repo, eg. pypi + # fine, pass quietly + pass + + # get device info (GPU or CPU) + dev = None + if tf.test.is_gpu_available(): + gpu_name = tf.test.gpu_device_name() + from tensorflow.python.client import device_lib + + dev_desc = [ + d.physical_device_desc + for d in device_lib.list_local_devices() + if d.name == gpu_name + ] + dev = [d.split(",")[1].split(":")[1].strip() for d in dev_desc] + dev_type = "GPU" + else: + from cpuinfo import get_cpu_info + + dev = [get_cpu_info()["brand"]] + dev_type = "CPU" + + return { + "host_name": host_name, + "op_sys": op_sys, + "python": host_python, + "device_type": dev_type, + "device": dev, + # pip freeze to get versions of all packages + "freeze": list(freeze.freeze()), + "python_version": sys.version, + "git_hash": git_hash, + "dlclive_version": VERSION, + } + + +def benchmark( + model_path, + video_path, + tf_config=None, + resize=None, + pixels=None, + cropping=None, + dynamic=(False, 0.5, 10), + n_frames=1000, + print_rate=False, + display=False, + pcutoff=0.0, + display_radius=3, + cmap="bmy", + save_poses=False, + save_video=False, + output=None, +) -> typing.Tuple[np.ndarray, tuple, bool, dict]: + """Analyze DeepLabCut-live exported model on a video: + Calculate inference time, + display keypoints, or + get poses/create a labeled video + + Parameters + ---------- + model_path : str + path to exported DeepLabCut model + video_path : str + path to video file + tf_config : :class:`tensorflow.ConfigProto` + tensorflow session configuration + resize : int, optional + resize factor. Can only use one of resize or pixels. If both are provided, will use pixels. by default None + pixels : int, optional + downsize image to this number of pixels, maintaining aspect ratio. Can only use one of resize or pixels. If both are provided, will use pixels. by default None + cropping : list of int + cropping parameters in pixel number: [x1, x2, y1, y2] + dynamic: triple containing (state, detectiontreshold, margin) + If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), + then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is + expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. `, by default "bmy" + save_poses : bool, optional + flag to save poses to an hdf5 file. If True, operates similar to :function:`DeepLabCut.benchmark_videos`, by default False + save_video : bool, optional + flag to save a labeled video. If True, operates similar to :function:`DeepLabCut.create_labeled_video`, by default False + output : str, optional + path to directory to save pose and/or video file. If not specified, will use the directory of video_path, by default None + + Returns + ------- + :class:`numpy.ndarray` + vector of inference times + tuple + (image width, image height) + bool + tensorflow inference flag + dict + metadata for video + + Example + ------- + Return a vector of inference times for 10000 frames: + dlclive.benchmark('/my/exported/model', 'my_video.avi', n_frames=10000) + + Return a vector of inference times, resizing images to half the width and height for inference + dlclive.benchmark('/my/exported/model', 'my_video.avi', n_frames=10000, resize=0.5) + + Display keypoints to check the accuracy of an exported model + dlclive.benchmark('/my/exported/model', 'my_video.avi', display=True) + + Analyze a video (save poses to hdf5) and create a labeled video, similar to :function:`DeepLabCut.benchmark_videos` and :function:`create_labeled_video` + dlclive.benchmark('/my/exported/model', 'my_video.avi', save_poses=True, save_video=True) + """ + + ### load video + + cap = cv2.VideoCapture(video_path) + ret, frame = cap.read() + n_frames = ( + n_frames + if (n_frames > 0) and (n_frames < cap.get(cv2.CAP_PROP_FRAME_COUNT) - 1) + else (cap.get(cv2.CAP_PROP_FRAME_COUNT) - 1) + ) + n_frames = int(n_frames) + im_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + ### get resize factor + + if pixels is not None: + resize = np.sqrt(pixels / (im_size[0] * im_size[1])) + if resize is not None: + im_size = (int(im_size[0] * resize), int(im_size[1] * resize)) + + ### create video writer + + if save_video: + colors = None + out_dir = ( + output + if output is not None + else os.path.dirname(os.path.realpath(video_path)) + ) + out_vid_base = os.path.basename(video_path) + out_vid_file = os.path.normpath( + f"{out_dir}/{os.path.splitext(out_vid_base)[0]}_DLCLIVE_LABELED.avi" + ) + fourcc = cv2.VideoWriter_fourcc(*"DIVX") + fps = cap.get(cv2.CAP_PROP_FPS) + vwriter = cv2.VideoWriter(out_vid_file, fourcc, fps, im_size) + + ### check for pandas installation if using save_poses flag + + if save_poses: + try: + import pandas as pd + + use_pandas = True + except: + use_pandas = False + warnings.warn( + "Could not find installation of pandas; saving poses as a numpy array with the dimensions (n_frames, n_keypoints, [x, y, likelihood])." + ) + + ### initialize DLCLive and perform inference + + inf_times = np.zeros(n_frames) + poses = [] + + live = DLCLive( + model_path, + model_type="base", + tf_config=tf_config, + resize=resize, + cropping=cropping, + dynamic=dynamic, + display=display, + pcutoff=pcutoff, + display_radius=display_radius, + display_cmap=cmap, + ) + + poses.append(live.init_inference(frame)) + TFGPUinference = True if len(live.runner.outputs) == 1 else False + + iterator = range(n_frames) if (print_rate) or (display) else tqdm(range(n_frames)) + for i in iterator: + ret, frame = cap.read() + + if not ret: + warnings.warn( + "Did not complete {:d} frames. There probably were not enough frames in the video {}.".format( + n_frames, video_path + ) + ) + break + + start_pose = time.time() + poses.append(live.get_pose(frame)) + inf_times[i] = time.time() - start_pose + + if save_video: + if colors is None: + all_colors = getattr(cc, cmap) + colors = [ + ImageColor.getcolor(c, "RGB")[::-1] + for c in all_colors[:: int(len(all_colors) / poses[-1].shape[0])] + ] + + this_pose = poses[-1] + for j in range(this_pose.shape[0]): + if this_pose[j, 2] > pcutoff: + x = int(this_pose[j, 0]) + y = int(this_pose[j, 1]) + frame = cv2.circle( + frame, (x, y), display_radius, colors[j], thickness=-1 + ) + + if resize is not None: + frame = cv2.resize(frame, im_size) + vwriter.write(frame) + + if print_rate: + print("pose rate = {:d}".format(int(1 / inf_times[i]))) + + if print_rate: + print("mean pose rate = {:d}".format(int(np.mean(1 / inf_times)))) + + ### gather video and test parameterization + + # dont want to fail here so gracefully failing on exception -- + # eg. some packages of cv2 don't have CAP_PROP_CODEC_PIXEL_FORMAT + try: + fourcc = decode_fourcc(cap.get(cv2.CAP_PROP_FOURCC)) + except: + fourcc = "" + + try: + fps = round(cap.get(cv2.CAP_PROP_FPS)) + except: + fps = None + + try: + pix_fmt = decode_fourcc(cap.get(cv2.CAP_PROP_CODEC_PIXEL_FORMAT)) + except: + pix_fmt = "" + + try: + frame_count = round(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + except: + frame_count = None + + try: + orig_im_size = ( + round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), + ) + except: + orig_im_size = None + + meta = { + "video_path": video_path, + "video_codec": fourcc, + "video_pixel_format": pix_fmt, + "video_fps": fps, + "video_total_frames": frame_count, + "original_frame_size": orig_im_size, + "dlclive_params": live.parameterization, + } + + ### close video and tensorflow session + + cap.release() + live.close() + + if save_video: + vwriter.release() + + if save_poses: + cfg_path = os.path.normpath(f"{model_path}/pose_cfg.yaml") + ruamel_file = ruamel.yaml.YAML() + dlc_cfg = ruamel_file.load(open(cfg_path, "r")) + bodyparts = dlc_cfg["all_joints_names"] + poses = np.array(poses) + + if use_pandas: + poses = poses.reshape((poses.shape[0], poses.shape[1] * poses.shape[2])) + pdindex = pd.MultiIndex.from_product( + [bodyparts, ["x", "y", "likelihood"]], names=["bodyparts", "coords"] + ) + pose_df = pd.DataFrame(poses, columns=pdindex) + + out_dir = ( + output + if output is not None + else os.path.dirname(os.path.realpath(video_path)) + ) + out_vid_base = os.path.basename(video_path) + out_dlc_file = os.path.normpath( + f"{out_dir}/{os.path.splitext(out_vid_base)[0]}_DLCLIVE_POSES.h5" + ) + pose_df.to_hdf(out_dlc_file, key="df_with_missing", mode="w") + + else: + out_vid_base = os.path.basename(video_path) + out_dlc_file = os.path.normpath( + f"{out_dir}/{os.path.splitext(out_vid_base)[0]}_DLCLIVE_POSES.npy" + ) + np.save(out_dlc_file, poses) + + return inf_times, im_size, TFGPUinference, meta + + +def save_inf_times( + sys_info, inf_times, im_size, TFGPUinference, model=None, meta=None, output=None +): + """Save inference time data collected using :function:`benchmark` with system information to a pickle file. + This is primarily used through :function:`benchmark_videos` + + + Parameters + ---------- + sys_info : tuple + system information generated by :func:`get_system_info` + inf_times : :class:`numpy.ndarray` + array of inference times generated by :func:`benchmark` + im_size : tuple or :class:`numpy.ndarray` + image size (width, height) for each benchmark run. If an array, each row corresponds to a row in inf_times + TFGPUinference: bool + flag if using tensorflow inference or numpy inference DLC model + model: str, optional + name of model + meta : dict, optional + metadata returned by :func:`benchmark` + output : str, optional + path to directory to save data. If None, uses pwd, by default None + + Returns + ------- + bool + flag indicating successful save + """ + + output = output if output is not None else os.getcwd() + model_type = None + if model is not None: + if "resnet" in model: + model_type = "resnet" + elif "mobilenet" in model: + model_type = "mobilenet" + else: + model_type = None + + fn_ind = 0 + base_name = ( + f"benchmark_{sys_info['host_name']}_{sys_info['device_type']}_{fn_ind}.pickle" + ) + out_file = os.path.normpath(f"{output}/{base_name}") + while os.path.isfile(out_file): + fn_ind += 1 + base_name = f"benchmark_{sys_info['host_name']}_{sys_info['device_type']}_{fn_ind}.pickle" + out_file = os.path.normpath(f"{output}/{base_name}") + + # summary stats (mean inference time & standard error of mean) + stats = zip( + np.mean(inf_times, 1), + np.std(inf_times, 1) * 1.0 / np.sqrt(np.shape(inf_times)[1]), + ) + + # for stat in stats: + # print("Stats:", stat) + + data = { + "model": model, + "model_type": model_type, + "TFGPUinference": TFGPUinference, + "im_size": im_size, + "inference_times": inf_times, + "stats": stats, + } + + data.update(sys_info) + if meta: + data.update(meta) + + os.makedirs(os.path.normpath(output), exist_ok=True) + pickle.dump(data, open(out_file, "wb")) + + return True + + +def benchmark_videos( + model_path, + video_path, + output=None, + n_frames=1000, + tf_config=None, + resize=None, + pixels=None, + cropping=None, + dynamic=(False, 0.5, 10), + print_rate=False, + display=False, + pcutoff=0.5, + display_radius=3, + cmap="bmy", + save_poses=False, + save_video=False, +): + """Analyze videos using DeepLabCut-live exported models. + Analyze multiple videos and/or multiple options for the size of the video + by specifying a resizing factor or the number of pixels to use in the image (keeping aspect ratio constant). + Options to record inference times (to examine inference speed), + display keypoints to visually check the accuracy, + or save poses to an hdf5 file as in :function:`deeplabcut.benchmark_videos` and + create a labeled video as in :function:`deeplabcut.create_labeled_video`. + + Parameters + ---------- + model_path : str + path to exported DeepLabCut model + video_path : str or list + path to video file or list of paths to video files + output : str + path to directory to save results + tf_config : :class:`tensorflow.ConfigProto` + tensorflow session configuration + resize : int, optional + resize factor. Can only use one of resize or pixels. If both are provided, will use pixels. by default None + pixels : int, optional + downsize image to this number of pixels, maintaining aspect ratio. Can only use one of resize or pixels. If both are provided, will use pixels. by default None + cropping : list of int + cropping parameters in pixel number: [x1, x2, y1, y2] + dynamic: triple containing (state, detectiontreshold, margin) + If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), + then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is + expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. `, by default "bmy" + save_poses : bool, optional + flag to save poses to an hdf5 file. If True, operates similar to :function:`DeepLabCut.benchmark_videos`, by default False + save_video : bool, optional + flag to save a labeled video. If True, operates similar to :function:`DeepLabCut.create_labeled_video`, by default False + + Example + ------- + Return a vector of inference times for 10000 frames on one video or two videos: + dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', n_frames=10000) + dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000) + + Return a vector of inference times, testing full size and resizing images to half the width and height for inference, for two videos + dlclive.benchmark_videos('/my/exported/model', ['my_video1.avi', 'my_video2.avi'], n_frames=10000, resize=[1.0, 0.5]) + + Display keypoints to check the accuracy of an exported model + dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', display=True) + + Analyze a video (save poses to hdf5) and create a labeled video, similar to :function:`DeepLabCut.benchmark_videos` and :function:`create_labeled_video` + dlclive.benchmark_videos('/my/exported/model', 'my_video.avi', save_poses=True, save_video=True) + """ + + # convert video_paths to list + + video_path = video_path if type(video_path) is list else [video_path] + + # fix resize + + if pixels: + pixels = pixels if type(pixels) is list else [pixels] + resize = [None for p in pixels] + elif resize: + resize = resize if type(resize) is list else [resize] + pixels = [None for r in resize] + else: + resize = [None] + pixels = [None] + + # loop over videos + + for v in video_path: + # initialize full inference times + + inf_times = [] + im_size_out = [] + + for i in range(len(resize)): + print(f"\nRun {i+1} / {len(resize)}\n") + + this_inf_times, this_im_size, TFGPUinference, meta = benchmark( + model_path, + v, + tf_config=tf_config, + resize=resize[i], + pixels=pixels[i], + cropping=cropping, + dynamic=dynamic, + n_frames=n_frames, + print_rate=print_rate, + display=display, + pcutoff=pcutoff, + display_radius=display_radius, + cmap=cmap, + save_poses=save_poses, + save_video=save_video, + output=output, + ) + + inf_times.append(this_inf_times) + im_size_out.append(this_im_size) + + inf_times = np.array(inf_times) + im_size_out = np.array(im_size_out) + + # save results + + if output is not None: + sys_info = get_system_info() + save_inf_times( + sys_info, + inf_times, + im_size_out, + TFGPUinference, + model=os.path.basename(model_path), + meta=meta, + output=output, + ) + + +def main(): + """Provides a command line interface :function:`benchmark_videos`""" + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("model_path", type=str) + parser.add_argument("video_path", type=str, nargs="+") + parser.add_argument("-o", "--output", type=str, default=None) + parser.add_argument("-n", "--n-frames", type=int, default=1000) + parser.add_argument("-r", "--resize", type=float, nargs="+") + parser.add_argument("-p", "--pixels", type=float, nargs="+") + parser.add_argument("-v", "--print-rate", default=False, action="store_true") + parser.add_argument("-d", "--display", default=False, action="store_true") + parser.add_argument("-l", "--pcutoff", default=0.5, type=float) + parser.add_argument("-s", "--display-radius", default=3, type=int) + parser.add_argument("-c", "--cmap", type=str, default="bmy") + parser.add_argument("--cropping", nargs="+", type=int, default=None) + parser.add_argument("--dynamic", nargs="+", type=float, default=[]) + parser.add_argument("--save-poses", action="store_true") + parser.add_argument("--save-video", action="store_true") + args = parser.parse_args() + + if (args.cropping) and (len(args.cropping) < 4): + raise Exception( + "Cropping not properly specified. Must provide 4 values: x1, x2, y1, y2" + ) + + if not args.dynamic: + args.dynamic = (False, 0.5, 10) + elif len(args.dynamic) < 3: + raise Exception( + "Dynamic cropping not properly specified. Must provide three values: 0 or 1 as boolean flag, pcutoff, and margin" + ) + else: + args.dynamic = (bool(args.dynamic[0]), args.dynamic[1], args.dynamic[2]) + + benchmark_videos( + args.model_path, + args.video_path, + output=args.output, + resize=args.resize, + pixels=args.pixels, + cropping=args.cropping, + dynamic=args.dynamic, + n_frames=args.n_frames, + print_rate=args.print_rate, + display=args.display, + pcutoff=args.pcutoff, + display_radius=args.display_radius, + cmap=args.cmap, + save_poses=args.save_poses, + save_video=args.save_video, + ) + + +if __name__ == "__main__": + main() diff --git a/dlclive/check_install/check_install.py b/dlclive/check_install/check_install.py index 7601533..30d6e79 100755 --- a/dlclive/check_install/check_install.py +++ b/dlclive/check_install/check_install.py @@ -5,19 +5,16 @@ Licensed under GNU Lesser General Public License v3.0 """ - -import sys +import argparse import shutil -import warnings - -from dlclive import benchmark_videos +import sys import urllib.request -import argparse +import warnings from pathlib import Path -from dlclibrary.dlcmodelzoo.modelzoo_download import ( - download_huggingface_model, -) +from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model + +from dlclive.benchmark_tf import benchmark_videos MODEL_NAME = "superanimal_quadruped" SNAPSHOT_NAME = "snapshot-700000.pb" @@ -27,28 +24,33 @@ def urllib_pbar(count, blockSize, totalSize): percent = int(count * blockSize * 100 / totalSize) outstr = f"{round(percent)}%" sys.stdout.write(outstr) - sys.stdout.write("\b"*len(outstr)) + sys.stdout.write("\b" * len(outstr)) sys.stdout.flush() def main(): parser = argparse.ArgumentParser( - description="Test DLC-Live installation by downloading and evaluating a demo DLC project!") - parser.add_argument('--nodisplay', action='store_false', help="Run the test without displaying tracking") + description="Test DLC-Live installation by downloading and evaluating a demo DLC project!" + ) + parser.add_argument( + "--nodisplay", + action="store_false", + help="Run the test without displaying tracking", + ) args = parser.parse_args() display = args.nodisplay if not display: - print('Running without displaying video') + print("Running without displaying video") # make temporary directory in $HOME # TODO: why create this temp directory in $HOME? print("\nCreating temporary directory...\n") - tmp_dir = Path().home() / 'dlc-live-tmp' - tmp_dir.mkdir(mode=0o775,exist_ok=True) + tmp_dir = Path().home() / "dlc-live-tmp" + tmp_dir.mkdir(mode=0o775, exist_ok=True) - video_file = str(tmp_dir / 'dog_clip.avi') - model_dir = tmp_dir / 'DLC_Dog_resnet_50_iteration-0_shuffle-0' + video_file = str(tmp_dir / "dog_clip.avi") + model_dir = tmp_dir / "DLC_Dog_resnet_50_iteration-0_shuffle-0" # download dog test video from github: # TODO: Should check if the video's already there before downloading it (should have been cloned with the files) @@ -58,25 +60,31 @@ def main(): # download model from the DeepLabCut Model Zoo if Path(model_dir / SNAPSHOT_NAME).exists(): - print('Model already downloaded, using cached version') + print("Model already downloaded, using cached version") else: print("Downloading full_dog model from the DeepLabCut Model Zoo...") download_huggingface_model(MODEL_NAME, model_dir) # assert these things exist so we can give informative error messages assert Path(video_file).exists(), f"Missing video file {video_file}" - assert Path(model_dir / SNAPSHOT_NAME).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}" + assert Path( + model_dir / SNAPSHOT_NAME + ).exists(), f"Missing model file {model_dir / SNAPSHOT_NAME}" # run benchmark videos print("\n Running inference...\n") - benchmark_videos(str(model_dir), video_file, display=display, resize=0.5, pcutoff=0.25) + benchmark_videos( + str(model_dir), video_file, display=display, resize=0.5, pcutoff=0.25 + ) # deleting temporary files print("\n Deleting temporary files...\n") try: shutil.rmtree(tmp_dir) except PermissionError: - warnings.warn(f'Could not delete temporary directory {str(tmp_dir)} due to a permissions error, but otherwise dlc-live seems to be working fine!') + warnings.warn( + f"Could not delete temporary directory {str(tmp_dir)} due to a permissions error, but otherwise dlc-live seems to be working fine!" + ) print("\nDone!\n") diff --git a/dlclive/core/__init__.py b/dlclive/core/__init__.py new file mode 100644 index 0000000..117d127 --- /dev/null +++ b/dlclive/core/__init__.py @@ -0,0 +1,10 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# diff --git a/dlclive/core/config.py b/dlclive/core/config.py new file mode 100644 index 0000000..1305cf9 --- /dev/null +++ b/dlclive/core/config.py @@ -0,0 +1,28 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Helpers for configuration file IO""" +from pathlib import Path + +import ruamel.yaml + + +def read_yaml(file_path: str | Path) -> dict: + file_path = Path(file_path).resolve() + if not file_path.exists(): + raise FileNotFoundError( + f"The pose configuration file for the exported model at {str(file_path)} " + "was not found. Please check the path to the exported model directory" + ) + + with open(file_path, "r") as f: + cfg = ruamel.yaml.YAML(typ="safe", pure=True).load(f) + + return cfg diff --git a/dlclive/core/inferenceutils.py b/dlclive/core/inferenceutils.py new file mode 100644 index 0000000..c160f40 --- /dev/null +++ b/dlclive/core/inferenceutils.py @@ -0,0 +1,1313 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import heapq +import itertools +import multiprocessing +import operator +import pickle +import warnings +from collections import defaultdict +from dataclasses import dataclass +from math import erf, sqrt +from typing import Any, Iterable, Tuple + +import networkx as nx +import numpy as np +import pandas as pd +from scipy.optimize import linear_sum_assignment +from scipy.spatial import cKDTree +from scipy.spatial.distance import cdist, pdist +from scipy.special import softmax +from scipy.stats import chi2, gaussian_kde +from tqdm import tqdm + + +def _conv_square_to_condensed_indices(ind_row, ind_col, n): + if ind_row == ind_col: + raise ValueError("There are no diagonal elements in condensed matrices.") + + if ind_row < ind_col: + ind_row, ind_col = ind_col, ind_row + return n * ind_col - ind_col * (ind_col + 1) // 2 + ind_row - 1 - ind_col + + +Position = Tuple[float, float] + + +@dataclass(frozen=True) +class Joint: + pos: Position + confidence: float = 1.0 + label: int = None + idx: int = None + group: int = -1 + + +class Link: + def __init__(self, j1, j2, affinity=1): + self.j1 = j1 + self.j2 = j2 + self.affinity = affinity + self._length = sqrt((j1.pos[0] - j2.pos[0]) ** 2 + (j1.pos[1] - j2.pos[1]) ** 2) + + def __repr__(self): + return ( + f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}" + ) + + @property + def confidence(self): + return self.j1.confidence * self.j2.confidence + + @property + def idx(self): + return self.j1.idx, self.j2.idx + + @property + def length(self): + return self._length + + @length.setter + def length(self, length): + self._length = length + + def to_vector(self): + return [*self.j1.pos, *self.j2.pos] + + +class Assembly: + def __init__(self, size): + self.data = np.full((size, 4), np.nan) + self.confidence = 0 # 0 by default, overwritten otherwise with `add_joint` + self._affinity = 0 + self._links = [] + self._visible = set() + self._idx = set() + self._dict = dict() + + def __len__(self): + return len(self._visible) + + def __contains__(self, assembly): + return bool(self._visible.intersection(assembly._visible)) + + def __add__(self, other): + if other in self: + raise ValueError("Assemblies contain shared joints.") + + assembly = Assembly(self.data.shape[0]) + for link in self._links + other._links: + assembly.add_link(link) + return assembly + + @classmethod + def from_array(cls, array): + n_bpts, n_cols = array.shape + + # if a single coordinate is NaN for a bodypart, set all to NaN + array[np.isnan(array).any(axis=-1)] = np.nan + + ass = cls(size=n_bpts) + ass.data[:, :n_cols] = array + visible = np.flatnonzero(~np.isnan(array).any(axis=1)) + if n_cols < 3: # Only xy coordinates are being set + ass.data[visible, 2] = 1 # Set detection confidence to 1 + ass._visible.update(visible) + return ass + + @property + def xy(self): + return self.data[:, :2] + + @property + def extent(self): + bbox = np.empty(4) + bbox[:2] = np.nanmin(self.xy, axis=0) + bbox[2:] = np.nanmax(self.xy, axis=0) + return bbox + + @property + def area(self): + x1, y1, x2, y2 = self.extent + return (x2 - x1) * (y2 - y1) + + @property + def confidence(self): + return np.nanmean(self.data[:, 2]) + + @confidence.setter + def confidence(self, confidence): + self.data[:, 2] = confidence + + @property + def soft_identity(self): + data = self.data[~np.isnan(self.data).any(axis=1)] + unq, idx, cnt = np.unique(data[:, 3], return_inverse=True, return_counts=True) + avg = np.bincount(idx, weights=data[:, 2]) / cnt + soft = softmax(avg) + return dict(zip(unq.astype(int), soft)) + + @property + def affinity(self): + n_links = self.n_links + if not n_links: + return 0 + return self._affinity / n_links + + @property + def n_links(self): + return len(self._links) + + def intersection_with(self, other): + x11, y11, x21, y21 = self.extent + x12, y12, x22, y22 = other.extent + x1 = max(x11, x12) + y1 = max(y11, y12) + x2 = min(x21, x22) + y2 = min(y21, y22) + if x2 < x1 or y2 < y1: + return 0 + ll = np.array([x1, y1]) + ur = np.array([x2, y2]) + xy1 = self.xy[~np.isnan(self.xy).any(axis=1)] + xy2 = other.xy[~np.isnan(other.xy).any(axis=1)] + in1 = np.all((xy1 >= ll) & (xy1 <= ur), axis=1).sum() + in2 = np.all((xy2 >= ll) & (xy2 <= ur), axis=1).sum() + return min(in1 / len(self), in2 / len(other)) + + def add_joint(self, joint): + if joint.label in self._visible or joint.label is None: + return False + self.data[joint.label] = *joint.pos, joint.confidence, joint.group + self._visible.add(joint.label) + self._idx.add(joint.idx) + return True + + def remove_joint(self, joint): + if joint.label not in self._visible: + return False + self.data[joint.label] = np.nan + self._visible.remove(joint.label) + self._idx.remove(joint.idx) + return True + + def add_link(self, link, store_dict=False): + if store_dict: + # Selective copy; deepcopy is >5x slower + self._dict = { + "data": self.data.copy(), + "_affinity": self._affinity, + "_links": self._links.copy(), + "_visible": self._visible.copy(), + "_idx": self._idx.copy(), + } + i1, i2 = link.idx + if i1 in self._idx and i2 in self._idx: + self._affinity += link.affinity + self._links.append(link) + return False + if link.j1.label in self._visible and link.j2.label in self._visible: + return False + self.add_joint(link.j1) + self.add_joint(link.j2) + self._affinity += link.affinity + self._links.append(link) + return True + + def calc_pairwise_distances(self): + return pdist(self.xy, metric="sqeuclidean") + + +class Assembler: + def __init__( + self, + data, + *, + max_n_individuals, + n_multibodyparts, + graph=None, + paf_inds=None, + greedy=False, + pcutoff=0.1, + min_affinity=0.05, + min_n_links=2, + max_overlap=0.8, + identity_only=False, + nan_policy="little", + force_fusion=False, + add_discarded=False, + window_size=0, + method="m1", + ): + self.data = data + self.metadata = self.parse_metadata(self.data) + self.max_n_individuals = max_n_individuals + self.n_multibodyparts = n_multibodyparts + self.n_uniquebodyparts = self.n_keypoints - n_multibodyparts + self.greedy = greedy + self.pcutoff = pcutoff + self.min_affinity = min_affinity + self.min_n_links = min_n_links + self.max_overlap = max_overlap + self._has_identity = "identity" in self[0] + if identity_only and not self._has_identity: + warnings.warn( + "The network was not trained with identity; setting `identity_only` to False." + ) + self.identity_only = identity_only & self._has_identity + self.nan_policy = nan_policy + self.force_fusion = force_fusion + self.add_discarded = add_discarded + self.window_size = window_size + self.method = method + self.graph = graph or self.metadata["paf_graph"] + self.paf_inds = paf_inds or self.metadata["paf"] + self._gamma = 0.01 + self._trees = dict() + self.safe_edge = False + self._kde = None + self.assemblies = dict() + self.unique = dict() + + def __getitem__(self, item): + return self.data[self.metadata["imnames"][item]] + + @classmethod + def empty( + cls, + max_n_individuals, + n_multibodyparts, + n_uniquebodyparts, + graph, + paf_inds, + greedy=False, + pcutoff=0.1, + min_affinity=0.05, + min_n_links=2, + max_overlap=0.8, + identity_only=False, + nan_policy="little", + force_fusion=False, + add_discarded=False, + window_size=0, + method="m1", + ): + # Dummy data + n_bodyparts = n_multibodyparts + n_uniquebodyparts + data = { + "metadata": { + "all_joints_names": ["" for _ in range(n_bodyparts)], + "PAFgraph": graph, + "PAFinds": paf_inds, + }, + "0": {}, + } + return cls( + data, + max_n_individuals=max_n_individuals, + n_multibodyparts=n_multibodyparts, + graph=graph, + paf_inds=paf_inds, + greedy=greedy, + pcutoff=pcutoff, + min_affinity=min_affinity, + min_n_links=min_n_links, + max_overlap=max_overlap, + identity_only=identity_only, + nan_policy=nan_policy, + force_fusion=force_fusion, + add_discarded=add_discarded, + window_size=window_size, + method=method, + ) + + @property + def n_keypoints(self): + return self.metadata["num_joints"] + + def calibrate(self, train_data_file): + df = pd.read_hdf(train_data_file) + try: + df.drop("single", level="individuals", axis=1, inplace=True) + except KeyError: + pass + n_bpts = len(df.columns.get_level_values("bodyparts").unique()) + if n_bpts == 1: + warnings.warn("There is only one keypoint; skipping calibration...") + return + + xy = df.to_numpy().reshape((-1, n_bpts, 2)) + frac_valid = np.mean(~np.isnan(xy), axis=(1, 2)) + # Only keeps skeletons that are more than 90% complete + xy = xy[frac_valid >= 0.9] + if not xy.size: + warnings.warn("No complete poses were found. Skipping calibration...") + return + + # TODO Normalize dists by longest length? + # TODO Smarter imputation technique (Bayesian? Grassmann averages?) + dists = np.vstack([pdist(data, "sqeuclidean") for data in xy]) + mu = np.nanmean(dists, axis=0) + missing = np.isnan(dists) + dists = np.where(missing, mu, dists) + try: + kde = gaussian_kde(dists.T) + kde.mean = mu + self._kde = kde + self.safe_edge = True + except np.linalg.LinAlgError: + # Covariance matrix estimation fails due to numerical singularities + warnings.warn( + "The assembler could not be robustly calibrated. Continuing without it..." + ) + + def calc_assembly_mahalanobis_dist( + self, assembly, return_proba=False, nan_policy="little" + ): + if self._kde is None: + raise ValueError("Assembler should be calibrated first with training data.") + + dists = assembly.calc_pairwise_distances() - self._kde.mean + mask = np.isnan(dists) + # Distance is undefined if the assembly is empty + if not len(assembly) or mask.all(): + if return_proba: + return np.inf, 0 + return np.inf + + if nan_policy == "little": + inds = np.flatnonzero(~mask) + dists = dists[inds] + inv_cov = self._kde.inv_cov[np.ix_(inds, inds)] + # Correct distance to account for missing observations + factor = self._kde.d / len(inds) + else: + # Alternatively, reduce contribution of missing values to the Mahalanobis + # distance to zero by substituting the corresponding means. + dists[mask] = 0 + mask.fill(False) + inv_cov = self._kde.inv_cov + factor = 1 + dot = dists @ inv_cov + mahal = factor * sqrt(np.sum((dot * dists), axis=-1)) + if return_proba: + proba = 1 - chi2.cdf(mahal, np.sum(~mask)) + return mahal, proba + return mahal + + def calc_link_probability(self, link): + if self._kde is None: + raise ValueError("Assembler should be calibrated first with training data.") + + i = link.j1.label + j = link.j2.label + ind = _conv_square_to_condensed_indices(i, j, self.n_multibodyparts) + mu = self._kde.mean[ind] + sigma = self._kde.covariance[ind, ind] + z = (link.length**2 - mu) / sigma + return 2 * (1 - 0.5 * (1 + erf(abs(z) / sqrt(2)))) + + @staticmethod + def _flatten_detections(data_dict): + ind = 0 + coordinates = data_dict["coordinates"][0] + confidence = data_dict["confidence"] + ids = data_dict.get("identity", None) + if ids is None: + ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence] + else: + ids = [arr.argmax(axis=1) for arr in ids] + for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids)): + if not np.any(coords): + continue + for xy, p, g in zip(coords, conf, id_): + joint = Joint(tuple(xy), p.item(), i, ind, g) + ind += 1 + yield joint + + def extract_best_links(self, joints_dict, costs, trees=None): + links = [] + for ind in self.paf_inds: + s, t = self.graph[ind] + dets_s = joints_dict.get(s, None) + dets_t = joints_dict.get(t, None) + if dets_s is None or dets_t is None: + continue + if ind not in costs: + continue + lengths = costs[ind]["distance"] + if np.isinf(lengths).all(): + continue + aff = costs[ind][self.method].copy() + aff[np.isnan(aff)] = 0 + + if trees: + vecs = np.vstack( + [[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t] + ) + dists = [] + for n, tree in enumerate(trees, start=1): + d, _ = tree.query(vecs) + dists.append(np.exp(-self._gamma * n * d)) + w = np.mean(dists, axis=0) + aff *= w.reshape(aff.shape) + + if self.greedy: + conf = np.asarray( + [ + [det_s.confidence * det_t.confidence for det_t in dets_t] + for det_s in dets_s + ] + ) + rows, cols = np.where( + (conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity) + ) + candidates = sorted( + zip(rows, cols, aff[rows, cols], lengths[rows, cols]), + key=lambda x: x[2], + reverse=True, + ) + i_seen = set() + j_seen = set() + for i, j, w, l in candidates: + if i not in i_seen and j not in j_seen: + i_seen.add(i) + j_seen.add(j) + links.append(Link(dets_s[i], dets_t[j], w)) + if len(i_seen) == self.max_n_individuals: + break + else: # Optimal keypoint pairing + inds_s = sorted( + range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True + )[: self.max_n_individuals] + inds_t = sorted( + range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True + )[: self.max_n_individuals] + keep_s = [ + ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff + ] + keep_t = [ + ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff + ] + aff = aff[np.ix_(keep_s, keep_t)] + rows, cols = linear_sum_assignment(aff, maximize=True) + for row, col in zip(rows, cols): + w = aff[row, col] + if w >= self.min_affinity: + links.append(Link(dets_s[keep_s[row]], dets_t[keep_t[col]], w)) + return links + + def _fill_assembly(self, assembly, lookup, assembled, safe_edge, nan_policy): + stack = [] + visited = set() + tabu = [] + counter = itertools.count() + + def push_to_stack(i): + for j, link in lookup[i].items(): + if j in assembly._idx: + continue + if link.idx in visited: + continue + heapq.heappush(stack, (-link.affinity, next(counter), link)) + visited.add(link.idx) + + for idx in assembly._idx: + push_to_stack(idx) + + while stack and len(assembly) < self.n_multibodyparts: + _, _, best = heapq.heappop(stack) + i, j = best.idx + if i in assembly._idx: + new_ind = j + elif j in assembly._idx: + new_ind = i + else: + continue + if new_ind in assembled: + continue + if safe_edge: + d_old = self.calc_assembly_mahalanobis_dist( + assembly, nan_policy=nan_policy + ) + success = assembly.add_link(best, store_dict=True) + if not success: + assembly._dict = dict() + continue + d = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy) + if d < d_old: + push_to_stack(new_ind) + try: + _, _, link = heapq.heappop(tabu) + heapq.heappush(stack, (-link.affinity, next(counter), link)) + except IndexError: + pass + else: + heapq.heappush(tabu, (d - d_old, next(counter), best)) + assembly.__dict__.update(assembly._dict) + assembly._dict = dict() + else: + assembly.add_link(best) + push_to_stack(new_ind) + + def build_assemblies(self, links): + lookup = defaultdict(dict) + for link in links: + i, j = link.idx + lookup[i][j] = link + lookup[j][i] = link + + assemblies = [] + assembled = set() + + # Fill the subsets with unambiguous, complete individuals + G = nx.Graph([link.idx for link in links]) + for chain in nx.connected_components(G): + if len(chain) == self.n_multibodyparts: + edges = [tuple(sorted(edge)) for edge in G.edges(chain)] + assembly = Assembly(self.n_multibodyparts) + for link in links: + i, j = link.idx + if (i, j) in edges: + success = assembly.add_link(link) + if success: + lookup[i].pop(j) + lookup[j].pop(i) + assembled.update(assembly._idx) + assemblies.append(assembly) + + if len(assemblies) == self.max_n_individuals: + return assemblies, assembled + + for link in sorted(links, key=lambda x: x.affinity, reverse=True): + if any(i in assembled for i in link.idx): + continue + assembly = Assembly(self.n_multibodyparts) + assembly.add_link(link) + self._fill_assembly( + assembly, lookup, assembled, self.safe_edge, self.nan_policy + ) + for link in assembly._links: + i, j = link.idx + lookup[i].pop(j) + lookup[j].pop(i) + assembled.update(assembly._idx) + assemblies.append(assembly) + + # Fuse superfluous assemblies + n_extra = len(assemblies) - self.max_n_individuals + if n_extra > 0: + if self.safe_edge: + ds_old = [ + self.calc_assembly_mahalanobis_dist(assembly) + for assembly in assemblies + ] + while len(assemblies) > self.max_n_individuals: + ds = [] + for i, j in itertools.combinations(range(len(assemblies)), 2): + if assemblies[j] not in assemblies[i]: + temp = assemblies[i] + assemblies[j] + d = self.calc_assembly_mahalanobis_dist(temp) + delta = d - max(ds_old[i], ds_old[j]) + ds.append((i, j, delta, d, temp)) + if not ds: + break + min_ = sorted(ds, key=lambda x: x[2]) + i, j, delta, d, new = min_[0] + if delta < 0 or len(min_) == 1: + assemblies[i] = new + assemblies.pop(j) + ds_old[i] = d + ds_old.pop(j) + else: + break + elif self.force_fusion: + assemblies = sorted(assemblies, key=len) + for nrow in range(n_extra): + assembly = assemblies[nrow] + candidates = [a for a in assemblies[nrow:] if assembly not in a] + if not candidates: + continue + if len(candidates) == 1: + candidate = candidates[0] + else: + dists = [] + for cand in candidates: + d = cdist(assembly.xy, cand.xy) + dists.append(np.nanmin(d)) + candidate = candidates[np.argmin(dists)] + ind = assemblies.index(candidate) + assemblies[ind] += assembly + else: + store = dict() + for assembly in assemblies: + if len(assembly) != self.n_multibodyparts: + for i in assembly._idx: + store[i] = assembly + used = [link for assembly in assemblies for link in assembly._links] + unconnected = [link for link in links if link not in used] + for link in unconnected: + i, j = link.idx + try: + if store[j] not in store[i]: + temp = store[i] + store[j] + store[i].__dict__.update(temp.__dict__) + assemblies.remove(store[j]) + for idx in store[j]._idx: + store[idx] = store[i] + except KeyError: + pass + + # Second pass without edge safety + for assembly in assemblies: + if len(assembly) != self.n_multibodyparts: + self._fill_assembly(assembly, lookup, assembled, False, "") + assembled.update(assembly._idx) + + return assemblies, assembled + + def _assemble(self, data_dict, ind_frame): + joints = list(self._flatten_detections(data_dict)) + if not joints: + return None, None + + bag = defaultdict(list) + for joint in joints: + bag[joint.label].append(joint) + + assembled = set() + + if self.n_uniquebodyparts: + unique = np.full((self.n_uniquebodyparts, 3), np.nan) + for n, ind in enumerate(range(self.n_multibodyparts, self.n_keypoints)): + dets = bag[ind] + if not dets: + continue + if len(dets) > 1: + det = max(dets, key=lambda x: x.confidence) + else: + det = dets[0] + # Mark the unique body parts as assembled anyway so + # they are not used later on to fill assemblies. + assembled.update(d.idx for d in dets) + if det.confidence <= self.pcutoff and not self.add_discarded: + continue + unique[n] = *det.pos, det.confidence + if np.isnan(unique).all(): + unique = None + else: + unique = None + + if not any(i in bag for i in range(self.n_multibodyparts)): + return None, unique + + if self.n_multibodyparts == 1: + assemblies = [] + for joint in bag[0]: + if joint.confidence >= self.pcutoff: + ass = Assembly(self.n_multibodyparts) + ass.add_joint(joint) + assemblies.append(ass) + return assemblies, unique + + if self.max_n_individuals == 1: + get_attr = operator.attrgetter("confidence") + ass = Assembly(self.n_multibodyparts) + for ind in range(self.n_multibodyparts): + joints = bag[ind] + if not joints: + continue + ass.add_joint(max(joints, key=get_attr)) + return [ass], unique + + if self.identity_only: + assemblies = [] + get_attr = operator.attrgetter("group") + temp = sorted( + (joint for joint in joints if np.isfinite(joint.confidence)), + key=get_attr, + ) + groups = itertools.groupby(temp, get_attr) + for _, group in groups: + ass = Assembly(self.n_multibodyparts) + for joint in sorted(group, key=lambda x: x.confidence, reverse=True): + if ( + joint.confidence >= self.pcutoff + and joint.label < self.n_multibodyparts + ): + ass.add_joint(joint) + if len(ass): + assemblies.append(ass) + assembled.update(ass._idx) + else: + trees = [] + for j in range(1, self.window_size + 1): + tree = self._trees.get(ind_frame - j, None) + if tree is not None: + trees.append(tree) + + links = self.extract_best_links(bag, data_dict["costs"], trees) + if self._kde: + for link in links[::-1]: + p = max(self.calc_link_probability(link), 0.001) + link.affinity *= p + if link.affinity < self.min_affinity: + links.remove(link) + + if self.window_size >= 1 and links: + # Store selected edges for subsequent frames + vecs = np.vstack([link.to_vector() for link in links]) + self._trees[ind_frame] = cKDTree(vecs) + + assemblies, assembled_ = self.build_assemblies(links) + assembled.update(assembled_) + + # Remove invalid assemblies + discarded = set( + joint + for joint in joints + if joint.idx not in assembled and np.isfinite(joint.confidence) + ) + for assembly in assemblies[::-1]: + if 0 < assembly.n_links < self.min_n_links or not len(assembly): + for link in assembly._links: + discarded.update((link.j1, link.j2)) + assemblies.remove(assembly) + if 0 < self.max_overlap < 1: # Non-maximum pose suppression + if self._kde is not None: + scores = [ + -self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies + ] + else: + scores = [ass._affinity for ass in assemblies] + lst = list(zip(scores, assemblies)) + assemblies = [] + while lst: + temp = max(lst, key=lambda x: x[0]) + lst.remove(temp) + assemblies.append(temp[1]) + for pair in lst[::-1]: + if temp[1].intersection_with(pair[1]) >= self.max_overlap: + lst.remove(pair) + if len(assemblies) > self.max_n_individuals: + assemblies = sorted(assemblies, key=len, reverse=True) + for assembly in assemblies[self.max_n_individuals :]: + for link in assembly._links: + discarded.update((link.j1, link.j2)) + assemblies = assemblies[: self.max_n_individuals] + + if self.add_discarded and discarded: + # Fill assemblies with unconnected body parts + for joint in sorted(discarded, key=lambda x: x.confidence, reverse=True): + if self.safe_edge: + for assembly in assemblies: + if joint.label in assembly._visible: + continue + d_old = self.calc_assembly_mahalanobis_dist(assembly) + assembly.add_joint(joint) + d = self.calc_assembly_mahalanobis_dist(assembly) + if d < d_old: + break + assembly.remove_joint(joint) + else: + dists = [] + for i, assembly in enumerate(assemblies): + if joint.label in assembly._visible: + continue + d = cdist(assembly.xy, np.atleast_2d(joint.pos)) + dists.append((i, np.nanmin(d))) + if not dists: + continue + min_ = sorted(dists, key=lambda x: x[1]) + ind, _ = min_[0] + assemblies[ind].add_joint(joint) + + return assemblies, unique + + def assemble(self, chunk_size=1, n_processes=None): + self.assemblies = dict() + self.unique = dict() + # Spawning (rather than forking) multiple processes does not + # work nicely with the GUI or interactive sessions. + # In that case, we fall back to the serial assembly. + if chunk_size == 0 or multiprocessing.get_start_method() == "spawn": + for i, data_dict in enumerate(tqdm(self)): + assemblies, unique = self._assemble(data_dict, i) + if assemblies: + self.assemblies[i] = assemblies + if unique is not None: + self.unique[i] = unique + else: + global wrapped # Hack to make the function pickable + + def wrapped(i): + return i, self._assemble(self[i], i) + + n_frames = len(self.metadata["imnames"]) + with multiprocessing.Pool(n_processes) as p: + with tqdm(total=n_frames) as pbar: + for i, (assemblies, unique) in p.imap_unordered( + wrapped, range(n_frames), chunksize=chunk_size + ): + if assemblies: + self.assemblies[i] = assemblies + if unique is not None: + self.unique[i] = unique + pbar.update() + + def from_pickle(self, pickle_path): + with open(pickle_path, "rb") as file: + data = pickle.load(file) + self.unique = data.pop("single", {}) + self.assemblies = data + + @staticmethod + def parse_metadata(data): + params = dict() + params["joint_names"] = data["metadata"]["all_joints_names"] + params["num_joints"] = len(params["joint_names"]) + params["paf_graph"] = data["metadata"]["PAFgraph"] + params["paf"] = data["metadata"].get( + "PAFinds", np.arange(len(params["joint_names"])) + ) + params["bpts"] = params["ibpts"] = range(params["num_joints"]) + params["imnames"] = [fn for fn in list(data) if fn != "metadata"] + return params + + def to_h5(self, output_name): + data = np.full( + ( + len(self.metadata["imnames"]), + self.max_n_individuals, + self.n_multibodyparts, + 4, + ), + fill_value=np.nan, + ) + for ind, assemblies in self.assemblies.items(): + for n, assembly in enumerate(assemblies): + data[ind, n] = assembly.data + index = pd.MultiIndex.from_product( + [ + ["scorer"], + map(str, range(self.max_n_individuals)), + map(str, range(self.n_multibodyparts)), + ["x", "y", "likelihood"], + ], + names=["scorer", "individuals", "bodyparts", "coords"], + ) + temp = data[..., :3].reshape((data.shape[0], -1)) + df = pd.DataFrame(temp, columns=index) + df.to_hdf(output_name, key="ass") + + def to_pickle(self, output_name): + data = dict() + for ind, assemblies in self.assemblies.items(): + data[ind] = [ass.data for ass in assemblies] + if self.unique: + data["single"] = self.unique + with open(output_name, "wb") as file: + pickle.dump(data, file, pickle.HIGHEST_PROTOCOL) + + +@dataclass +class MatchedPrediction: + """A match between a prediction and a ground truth assembly + + The ground truth assembly should be None f the prediction was not matched to any GT, + and the OKS should be 0. + + Attributes: + prediction: A prediction made by a pose model. + score: The confidence score for the prediction. + ground_truth: If None, then this prediction is not matched to any ground truth + (this can happen when there are more predicted individuals than GT). + Otherwise, the ground truth assembly to which this prediction is matched. + oks: The OKS score between the prediction and the ground truth pose. + """ + + prediction: Assembly + score: float + ground_truth: Assembly | None + oks: float + + +def calc_object_keypoint_similarity( + xy_pred, + xy_true, + sigma, + margin=0, + symmetric_kpts=None, +): + visible_gt = ~np.isnan(xy_true).all(axis=1) + if visible_gt.sum() < 2: # At least 2 points needed to calculate scale + return np.nan + + true = xy_true[visible_gt] + scale_squared = np.product(np.ptp(true, axis=0) + np.spacing(1) + margin * 2) + if np.isclose(scale_squared, 0): + return np.nan + + k_squared = (2 * sigma) ** 2 + denom = 2 * scale_squared * k_squared + if symmetric_kpts is None: + pred = xy_pred[visible_gt] + pred[np.isnan(pred)] = np.inf + dist_squared = np.sum((pred - true) ** 2, axis=1) + oks = np.exp(-dist_squared / denom) + return np.mean(oks) + else: + oks = [] + xy_preds = [xy_pred] + combos = ( + pair + for l in range(len(symmetric_kpts)) + for pair in itertools.combinations(symmetric_kpts, l + 1) + ) + for pairs in combos: + # Swap corresponding keypoints + tmp = xy_pred.copy() + for pair in pairs: + tmp[pair, :] = tmp[pair[::-1], :] + xy_preds.append(tmp) + for xy_pred in xy_preds: + pred = xy_pred[visible_gt] + pred[np.isnan(pred)] = np.inf + dist_squared = np.sum((pred - true) ** 2, axis=1) + oks.append(np.mean(np.exp(-dist_squared / denom))) + return max(oks) + + +def match_assemblies( + predictions: list[Assembly], + ground_truth: list[Assembly], + sigma: float, + margin: int = 0, + symmetric_kpts: list[tuple[int, int]] | None = None, + greedy_matching: bool = False, + greedy_oks_threshold: float = 0.0, +) -> tuple[int, list[MatchedPrediction]]: + """Matches assemblies to ground truth predictions + + Returns: + int: the total number of valid ground truth assemblies + list[MatchedPrediction]: a list containing all valid predictions, potentially + matched to ground truth assemblies. + """ + # Only consider assemblies of at least two keypoints + predictions = [a for a in predictions if len(a) > 1] + ground_truth = [a for a in ground_truth if len(a) > 1] + num_ground_truth = len(ground_truth) + + # Sort predictions by score + inds_pred = np.argsort( + [ins.affinity if ins.n_links else ins.confidence for ins in predictions] + )[::-1] + predictions = np.asarray(predictions)[inds_pred] + + # indices of unmatched ground truth assemblies + matched = [ + MatchedPrediction( + prediction=p, + score=(p.affinity if p.n_links else p.confidence), + ground_truth=None, + oks=0.0, + ) + for p in predictions + ] + + # Greedy assembly matching like in pycocotools + if greedy_matching: + matched_gt_indices = set() + for idx, pred in enumerate(predictions): + oks = [ + calc_object_keypoint_similarity( + pred.xy, + gt.xy, + sigma, + margin, + symmetric_kpts, + ) + for gt in ground_truth + ] + if np.all(np.isnan(oks)): + continue + + ind_best = np.nanargmax(oks) + + # if this gt already matched, and not a crowd, continue + if ind_best in matched_gt_indices: + continue + + # Only match the pred to the GT if the OKS value is above a given threshold + if oks[ind_best] < greedy_oks_threshold: + continue + + matched_gt_indices.add(ind_best) + matched[idx].ground_truth = ground_truth[ind_best] + matched[idx].oks = oks[ind_best] + + # Global rather than greedy assembly matching + else: + inds_true = list(range(len(ground_truth))) + mat = np.zeros((len(predictions), len(ground_truth))) + for i, a_pred in enumerate(predictions): + for j, a_true in enumerate(ground_truth): + oks = calc_object_keypoint_similarity( + a_pred.xy, + a_true.xy, + sigma, + margin, + symmetric_kpts, + ) + if ~np.isnan(oks): + mat[i, j] = oks + rows, cols = linear_sum_assignment(mat, maximize=True) + for row, col in zip(rows, cols): + matched[row].ground_truth = ground_truth[col] + matched[row].oks = mat[row, col] + _ = inds_true.remove(col) + + return num_ground_truth, matched + + +def parse_ground_truth_data_file(h5_file): + df = pd.read_hdf(h5_file) + try: + df.drop("single", axis=1, level="individuals", inplace=True) + except KeyError: + pass + # Cast columns of dtype 'object' to float to avoid TypeError + # further down in _parse_ground_truth_data. + cols = df.select_dtypes(include="object").columns + if cols.to_list(): + df[cols] = df[cols].astype("float") + n_individuals = len(df.columns.get_level_values("individuals").unique()) + n_bodyparts = len(df.columns.get_level_values("bodyparts").unique()) + data = df.to_numpy().reshape((df.shape[0], n_individuals, n_bodyparts, -1)) + return _parse_ground_truth_data(data) + + +def _parse_ground_truth_data(data): + gt = dict() + for i, arr in enumerate(data): + temp = [] + for row in arr: + if np.isnan(row[:, :2]).all(): + continue + ass = Assembly.from_array(row) + temp.append(ass) + if not temp: + continue + gt[i] = temp + return gt + + +def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)): + if not hasattr(Assembly, criterion): + raise ValueError(f"Invalid criterion {criterion}.") + + if len(qs) != 2: + raise ValueError( + "Two percentiles (for lower and upper bounds) should be given." + ) + + tuples = [] + for frame_ind, assemblies in dict_of_assemblies.items(): + for assembly in assemblies: + tuples.append((frame_ind, getattr(assembly, criterion))) + frame_inds, vals = zip(*tuples) + vals = np.asarray(vals) + lo, up = np.percentile(vals, qs, interpolation="nearest") + inds = np.flatnonzero((vals < lo) | (vals > up)).tolist() + return list(set(frame_inds[i] for i in inds)) + + +def _compute_precision_and_recall( + num_gt_assemblies: int, + oks_values: np.ndarray, + oks_threshold: float, + recall_thresholds: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Computes the precision and recall scores at a given OKS threshold + + Args: + num_gt_assemblies: the number of ground truth assemblies (used to compute false + negatives + true positives). + oks_values: the OKS value to the matched GT assembly for each prediction + oks_threshold: the OKS threshold at which recall and precision are being + computed + recall_thresholds: the recall thresholds to use to compute scores + + Returns: + The precision and recall arrays at each recall threshold + """ + tp = np.cumsum(oks_values >= oks_threshold) + fp = np.cumsum(oks_values < oks_threshold) + rc = tp / num_gt_assemblies + pr = tp / (fp + tp + np.spacing(1)) + recall = rc[-1] + + # Guarantee precision decreases monotonically, see + # https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173 + for i in range(len(pr) - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] + + inds_rc = np.searchsorted(rc, recall_thresholds, side="left") + precision = np.zeros(inds_rc.shape) + valid = inds_rc < len(pr) + precision[valid] = pr[inds_rc[valid]] + return precision, recall + + +def evaluate_assembly_greedy( + assemblies_gt: dict[Any, list[Assembly]], + assemblies_pred: dict[Any, list[Assembly]], + oks_sigma: float, + oks_thresholds: Iterable[float], + margin: int | float = 0, + symmetric_kpts: list[tuple[int, int]] | None = None, +) -> dict: + """Runs greedy mAP evaluation, as done by pycocotools + + Args: + assemblies_gt: A dictionary mapping image ID (e.g. filepath) to ground truth + assemblies. Should contain all the same keys as ``assemblies_pred``. + assemblies_pred: A dictionary mapping image ID (e.g. filepath) to predicted + assemblies. Should contain all the same keys as ``assemblies_gt``. + oks_sigma: The sigma to use to compute OKS values for keypoints . + oks_thresholds: The OKS thresholds at which to compute precision & recall. + margin: The margin to use to compute bounding boxes from keypoints. + symmetric_kpts: The symmetric keypoints in the dataset. + """ + recall_thresholds = np.linspace( # np.linspace(0, 1, 101) + start=0.0, stop=1.00, num=int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True + ) + precisions = [] + recalls = [] + for oks_t in oks_thresholds: + all_matched = [] + total_gt_assemblies = 0 + for ind, gt_assembly in assemblies_gt.items(): + pred_assemblies = assemblies_pred.get(ind, []) + num_gt_assemblies, matched = match_assemblies( + pred_assemblies, + gt_assembly, + oks_sigma, + margin, + symmetric_kpts, + greedy_matching=True, + greedy_oks_threshold=oks_t, + ) + all_matched.extend(matched) + total_gt_assemblies += num_gt_assemblies + + if len(all_matched) == 0: + precisions.append(0.0) + recalls.append(0.0) + continue + + # Global sort of assemblies (across all images) by score + scores = np.asarray([-m.score for m in all_matched]) + sorted_pred_indices = np.argsort(scores, kind="mergesort") + oks = np.asarray([match.oks for match in all_matched])[sorted_pred_indices] + + # Compute prediction and recall + p, r = _compute_precision_and_recall( + total_gt_assemblies, oks, oks_t, recall_thresholds + ) + precisions.append(p) + recalls.append(r) + + precisions = np.asarray(precisions) + recalls = np.asarray(recalls) + return { + "precisions": precisions, + "recalls": recalls, + "mAP": precisions.mean(), + "mAR": recalls.mean(), + } + + +def evaluate_assembly( + ass_pred_dict, + ass_true_dict, + oks_sigma=0.072, + oks_thresholds=np.linspace(0.5, 0.95, 10), + margin=0, + symmetric_kpts=None, + greedy_matching=False, + with_tqdm: bool = True, +): + if greedy_matching: + return evaluate_assembly_greedy( + ass_true_dict, + ass_pred_dict, + oks_sigma=oks_sigma, + oks_thresholds=oks_thresholds, + margin=margin, + symmetric_kpts=symmetric_kpts, + ) + + # sigma is taken as the median of all COCO keypoint standard deviations + all_matched = [] + total_gt_assemblies = 0 + + gt_assemblies = ass_true_dict.items() + if with_tqdm: + gt_assemblies = tqdm(gt_assemblies) + + for ind, gt_assembly in gt_assemblies: + pred_assemblies = ass_pred_dict.get(ind, []) + num_gt, matched = match_assemblies( + pred_assemblies, + gt_assembly, + oks_sigma, + margin, + symmetric_kpts, + greedy_matching, + ) + all_matched.extend(matched) + total_gt_assemblies += num_gt + + if not all_matched: + return { + "precisions": np.array([]), + "recalls": np.array([]), + "mAP": 0.0, + "mAR": 0.0, + } + + conf_pred = np.asarray([match.score for match in all_matched]) + idx = np.argsort(-conf_pred, kind="mergesort") + # Sort matching score (OKS) in descending order of assembly affinity + oks = np.asarray([match.oks for match in all_matched])[idx] + recall_thresholds = np.linspace(0, 1, 101) + precisions = [] + recalls = [] + for t in oks_thresholds: + p, r = _compute_precision_and_recall( + total_gt_assemblies, oks, t, recall_thresholds + ) + precisions.append(p) + recalls.append(r) + + precisions = np.asarray(precisions) + recalls = np.asarray(recalls) + return { + "precisions": precisions, + "recalls": recalls, + "mAP": precisions.mean(), + "mAR": recalls.mean(), + } diff --git a/dlclive/core/runner.py b/dlclive/core/runner.py new file mode 100644 index 0000000..00295d2 --- /dev/null +++ b/dlclive/core/runner.py @@ -0,0 +1,96 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Base runner for DeepLabCut-Live""" +import abc +from pathlib import Path + +import numpy as np + + +class BaseRunner(abc.ABC): + """Base runner for live pose estimation using DeepLabCut-Live. + + Args: + path: The path to the model to run inference with. + + Attributes: + cfg: The pose configuration data. + path: The path to the model to run inference with. + """ + + def __init__(self, path: str | Path) -> None: + self.path = Path(path) + self.cfg = None + + @abc.abstractmethod + def close(self) -> None: + """Clears any resources used by the runner.""" + pass + + @abc.abstractmethod + def get_pose(self, frame: np.ndarray | None, **kwargs) -> np.ndarray | None: + """ + Abstract method to calculate and retrieve the pose of an object or system + based on the given input frame of data. This method must be implemented + by any subclass inheriting from this abstract base class to define the + specific approach for pose estimation. + + Parameters + ---------- + frame : np.ndarray + The input data or image frame used for estimating the pose. Typically + represents visual data such as video or image frames. + kwargs : dict, optional + Additional keyword arguments that may be required for specific pose + estimation techniques implemented in the subclass. + + Returns + ------- + np.ndarray + The estimated pose resulting from the pose estimation process. The + structure of the array may depend on the specific implementation + but typically represents transformations or coordinates. + """ + pass + + @abc.abstractmethod + def init_inference(self, frame: np.ndarray | None, **kwargs) -> np.ndarray | None: + """ + Initializes inference process on the provided frame. + + This method serves as an abstract base method, meant to be implemented by + subclasses. It takes an input image frame and optional additional parameters + to set up and perform inference. The method must return a processed result + as a numpy array. + + Parameters + ---------- + frame : np.ndarray + The input image frame for which inference needs to be set up. + kwargs : dict, optional + Additional parameters that may be required for specific implementation + of the inference initialization. + + Returns + ------- + np.ndarray + The result of the inference after being initialized and processed. + """ + pass + + @abc.abstractmethod + def read_config(self): + """Reads the pose configuration file. + + Raises: + FileNotFoundError: if the pose configuration file does not exist + """ + pass diff --git a/dlclive/display.py b/dlclive/display.py index cc324d8..e349c2c 100644 --- a/dlclive/display.py +++ b/dlclive/display.py @@ -5,28 +5,25 @@ Licensed under GNU Lesser General Public License v3.0 """ +from tkinter import Label, Tk -from tkinter import Tk, Label import colorcet as cc -from PIL import Image, ImageTk, ImageDraw +from PIL import Image, ImageDraw, ImageTk -class Display(object): +class Display: """ Simple object to display frames with DLC labels. Parameters ----------- - cmap : string - string indicating the Matoplotlib colormap to use. + cmap: string + The Matplotlib colormap to use. pcutoff : float likelihood threshold to display points """ def __init__(self, cmap="bmy", radius=3, pcutoff=0.5): - """ Constructor method - """ - self.cmap = cmap self.colors = None self.radius = radius @@ -34,8 +31,8 @@ def __init__(self, cmap="bmy", radius=3, pcutoff=0.5): self.window = None def set_display(self, im_size, bodyparts): - """ Create tkinter window to display image - + """Create tkinter window to display image + Parameters ---------- im_size : tuple @@ -64,46 +61,46 @@ def display_frame(self, frame, pose=None): pose :class:`numpy.ndarray` the pose estimated by DeepLabCut for the image """ - im_size = (frame.shape[1], frame.shape[0]) - if pose is not None: - if self.window is None: self.set_display(im_size, pose.shape[0]) img = Image.fromarray(frame) draw = ImageDraw.Draw(img) + if len(pose.shape) == 2: + pose = pose[None] for i in range(pose.shape[0]): - if pose[i, 2] > self.pcutoff: - try: - x0 = ( - pose[i, 0] - self.radius - if pose[i, 0] - self.radius > 0 - else 0 - ) - x1 = ( - pose[i, 0] + self.radius - if pose[i, 0] + self.radius < im_size[0] - else im_size[1] - ) - y0 = ( - pose[i, 1] - self.radius - if pose[i, 1] - self.radius > 0 - else 0 - ) - y1 = ( - pose[i, 1] + self.radius - if pose[i, 1] + self.radius < im_size[1] - else im_size[0] - ) - coords = [x0, y0, x1, y1] - draw.ellipse( - coords, fill=self.colors[i], outline=self.colors[i] - ) - except Exception as e: - print(e) + for j in range(pose.shape[1]): + if pose[i, j, 2] > self.pcutoff: + try: + x0 = ( + pose[i, j, 0] - self.radius + if pose[i, j, 0] - self.radius > 0 + else 0 + ) + x1 = ( + pose[i, j, 0] + self.radius + if pose[i, j, 0] + self.radius < im_size[0] + else im_size[1] + ) + y0 = ( + pose[i, j, 1] - self.radius + if pose[i, j, 1] - self.radius > 0 + else 0 + ) + y1 = ( + pose[i, j, 1] + self.radius + if pose[i, j, 1] + self.radius < im_size[1] + else im_size[0] + ) + coords = [x0, y0, x1, y1] + draw.ellipse( + coords, fill=self.colors[j], outline=self.colors[j] + ) + except Exception as e: + print(e) img_tk = ImageTk.PhotoImage(image=img, master=self.window) self.lab.configure(image=img_tk) @@ -113,5 +110,4 @@ def destroy(self): """ Destroys the opencv image window """ - self.window.destroy() diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index 210671e..74a1ffa 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -4,90 +4,145 @@ Licensed under GNU Lesser General Public License v3.0 """ +from __future__ import annotations -import os -import ruamel.yaml -import glob -import warnings -import numpy as np -import tensorflow as tf -import typing from pathlib import Path -from typing import Optional, Tuple, List - -try: - TFVER = [int(v) for v in tf.__version__.split(".")] - if TFVER[1] < 14: - from tensorflow.contrib.tensorrt import trt_convert as trt - else: - from tensorflow.python.compiler.tensorrt import trt_convert as trt -except Exception: - pass - -from dlclive.graph import ( - read_graph, - finalize_graph, - get_output_nodes, - get_output_tensors, - extract_graph, -) -from dlclive.pose import extract_cnn_output, argmax_pose_predict, multi_pose_predict +from typing import Any + +import numpy as np + +import dlclive.factory as factory +import dlclive.utils as utils +from dlclive.core.runner import BaseRunner from dlclive.display import Display -from dlclive import utils -from dlclive.exceptions import DLCLiveError, DLCLiveWarning -if typing.TYPE_CHECKING: - from dlclive.processor import Processor +from dlclive.exceptions import DLCLiveError +from dlclive.processor import Processor + -class DLCLive(object): +class DLCLive: """ - Object that loads a DLC network and performs inference on single images (e.g. images captured from a camera feed) + Class that loads a DLC network and performs inference on single images (e.g. + images captured from a camera feed) Parameters ----------- - path : string - Full path to exported model directory + model_path: Path + Full path to exported model (created when `deeplabcut.export_model(...)` was + called). For PyTorch models, this is a single model file. For TensorFlow models, + this is a directory containing the model snapshots. model_type: string, optional - which model to use: 'base', 'tensorrt' for tensorrt optimized graph, 'lite' for tensorflow lite optimized graph - - precision : string, optional - precision of model weights, only for model_type='tensorrt'. Can be 'FP16' (default), 'FP32', or 'INT8' - - cropping : list of int - cropping parameters in pixel number: [x1, x2, y1, y2] + Which model to use. For the PyTorch engine, options are [`pytorch`]. For the + TensorFlow engine, options are [`base`, `tensorrt`, `lite`]. + + precision: string, optional + Precision of model weights, for model_type "pytorch" and "tensorrt". Options + are, for different model_types: + "pytorch": {"FP32", "FP16"} + "tensorrt": {"FP32", "FP16", "INT8"} + + tf_config: + TensorFlow only. Optional ConfigProto for the TensorFlow session. + + single_animal: bool, default=True + PyTorch only. If True, the predicted pose array returned by the runner will be + (num_bodyparts, 3). As multi-animal pose estimation can be run with the PyTorch + engine, setting this to False means the returned pose array will be of shape + (num_detections, num_bodyparts, 3). + + device: str, optional, default=None + PyTorch only. The device on which to run inference, e.g. "cpu", "cuda" or + "cuda:0". If set to None or "auto", the device will be automatically selected + based on CUDA availability. + + top_down_config: dict, optional, default=None + PyTorch only. Configuration settings for top-down pose estimation models. Must + be provided when running top-down models and `top_down_dynamic` is None. The + parameters in the dict will be given to the `TopDownConfig` class (in + `dlclive/pose_estimation_pytorch/runner.py`). The `crop_size` does not need to + be set, as it will be read from the model configuration file. + Example parameters: + >>> # Running a top-down model with basic parameters + >>> top_down_config = { + >>> "bbox_cutoff": 0.5, # min confidence score for a bbox to be used + >>> "max_detections": 3, # max number of detections to return in a frame + >>> } + >>> # Running a top-down model with skip-frames + >>> top_down_config = { + >>> "bbox_cutoff": 0.5, # min confidence score for a bbox to be used + >>> "max_detections": 3, # max number of detections to return in a frame + >>> "skip_frames": { # only run the detector every 5 frames + >>> "skip": 5, # number of frames to skip between detections + >>> "margin": 5, # margin (in pixels) to use when generating bboxes + >>> }, + >>> } + + top_down_dynamic: dict, optional, default=None + PyTorch only. Single animal only. Top-down models do not need a detector to be + used for single animal pose estimation. This is equivalent to dynamic cropping + in TensorFlow or for bottom-up models, but crops are resized to the input size + required by the model. Pose estimation is never run on the full image. If no + animal is detected, the image is split into N by M "patches", and we run pose + estimation on the batch of patches. Pose is kept from the patch with the + highest likelyhood. No need to provide the `top_down_crop_size` parameter, as it + set using the model configuration file. + The parameters (except "type") will be passed to the `TopDownDynamicCropper` + class (in `dlclive/pose_estimation_pytorch/dynamic_cropping.py` + + Example parameters: + >>> top_down_dynamic = { + >>> "type": "TopDownDynamicCropper", + >>> "min_bbox_size": (50, 50), + >>> } + + cropping: list of int + Cropping parameters in pixel number: [x1, x2, y1, y2] dynamic: triple containing (state, detectiontreshold, margin) - If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), - then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is - expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. detectiontreshold), then object + boundaries are computed according to the smallest/largest x position and + smallest/largest y position of all body parts. This window is expanded by the + margin and from then on only the posture within this crop is analyzed (until the + object is lost, i.e. dict | None: + return self.runner.cfg - def read_config(self): - """ Reads configuration yaml file + def read_config(self) -> None: + """Reads configuration yaml file Raises ------ FileNotFoundError - error thrown if pose configuration file does nott exist + error thrown if pose configuration file does not exist """ - - cfg_path = Path(self.path).resolve() / "pose_cfg.yaml" - if not cfg_path.exists(): - raise FileNotFoundError( - f"The pose configuration file for the exported model at {str(cfg_path)} was not found. Please check the path to the exported model directory" - ) - - ruamel_file = ruamel.yaml.YAML() - self.cfg = ruamel_file.load(open(str(cfg_path), "r")) + self.runner.read_config() @property - def parameterization(self) -> dict: - """ - Return - Returns - ------- - """ + def parameterization( + self, + ) -> dict: return {param: getattr(self, param) for param in self.PARAMETERS} - def process_frame(self, frame): + def process_frame(self, frame: np.ndarray) -> np.ndarray: """ Crops an image according to the object's cropping and dynamic properties. @@ -195,38 +240,45 @@ def process_frame(self, frame): frame :class:`numpy.ndarray` processed frame: convert type, crop, convert color """ - - if frame.dtype != np.uint8: - - frame = utils.convert_to_ubyte(frame) - if self.cropping: - frame = frame[ self.cropping[2] : self.cropping[3], self.cropping[0] : self.cropping[1] ] if self.dynamic[0]: - if self.pose is not None: + # Deal with PyTorch multi-animal models + if len(self.pose.shape) == 3: + if len(self.pose) == 0: + pose = np.zeros((1, 3)) + elif len(self.pose) == 1: + pose = self.pose[0] + else: + raise ValueError( + "Cannot use Dynamic Cropping - more than 1 individual found" + ) - detected = self.pose[:, 2] > self.dynamic[1] + else: + pose = self.pose + detected = pose[:, 2] >= self.dynamic[1] if np.any(detected): + h, w = frame.shape[0], frame.shape[1] - x = self.pose[detected, 0] - y = self.pose[detected, 1] + x = pose[detected, 0] + y = pose[detected, 1] + xmin, xmax = int(np.min(x)), int(np.max(x)) + ymin, ymax = int(np.min(y)), int(np.max(y)) - x1 = int(max([0, int(np.amin(x)) - self.dynamic[2]])) - x2 = int(min([frame.shape[1], int(np.amax(x)) + self.dynamic[2]])) - y1 = int(max([0, int(np.amin(y)) - self.dynamic[2]])) - y2 = int(min([frame.shape[0], int(np.amax(y)) + self.dynamic[2]])) - self.dynamic_cropping = [x1, x2, y1, y2] + x1 = max([0, xmin - self.dynamic[2]]) + x2 = min([w, xmax + self.dynamic[2]]) + y1 = max([0, ymin - self.dynamic[2]]) + y2 = min([h, ymax + self.dynamic[2]]) + self.dynamic_cropping = [x1, x2, y1, y2] frame = frame[y1:y2, x1:x2] else: - self.dynamic_cropping = None if self.resize != 1: @@ -237,9 +289,10 @@ def process_frame(self, frame): return frame - def init_inference(self, frame=None, **kwargs): + def init_inference(self, frame=None, **kwargs) -> np.ndarray: """ - Load model and perform inference on first frame -- the first inference is usually very slow. + Load model and perform inference on first frame -- the first inference is + usually very slow. Parameters ----------- @@ -248,135 +301,20 @@ def init_inference(self, frame=None, **kwargs): Returns -------- - pose :class:`numpy.ndarray` - the pose estimated by DeepLabCut for the input image + pose: the pose estimated by DeepLabCut for the input image """ + if frame is None: + raise DLCLiveError("No frame provided to initialize inference.") - # get model file - - model_file = glob.glob(os.path.normpath(self.path + "/*.pb"))[0] - if not os.path.isfile(model_file): - raise FileNotFoundError( - "The model file {} does not exist.".format(model_file) - ) - - # process frame - - if frame is None and (self.model_type == "tflite"): - raise DLCLiveError( - "No image was passed to initialize inference. An image must be passed to the init_inference method" - ) - - if frame is not None: - if frame.ndim == 2: - self.convert2rgb = True - processed_frame = self.process_frame(frame) - - # load model - - if self.model_type == "base": - - graph_def = read_graph(model_file) - graph = finalize_graph(graph_def) - self.sess, self.inputs, self.outputs = extract_graph( - graph, tf_config=self.tf_config - ) - - elif self.model_type == "tflite": - - ### - # the frame size needed to initialize the tflite model as - # tflite does not support saving a model with dynamic input size - ### - - # get input and output tensor names from graph_def - graph_def = read_graph(model_file) - graph = finalize_graph(graph_def) - output_nodes = get_output_nodes(graph) - output_nodes = [on.replace("DLC/", "") for on in output_nodes] - - tf_version_2 = tf.__version__[0] == '2' - - if tf_version_2: - converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( - model_file, - ["Placeholder"], - output_nodes, - input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]}, - ) - else: - converter = tf.lite.TFLiteConverter.from_frozen_graph( - model_file, - ["Placeholder"], - output_nodes, - input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]}, - ) - - try: - tflite_model = converter.convert() - except Exception: - raise DLCLiveError( - ( - "This model cannot be converted to tensorflow lite format. " - "To use tensorflow lite for live inference, " - "make sure to set TFGPUinference=False " - "when exporting the model from DeepLabCut" - ) - ) - - self.tflite_interpreter = tf.lite.Interpreter(model_content=tflite_model) - self.tflite_interpreter.allocate_tensors() - self.inputs = self.tflite_interpreter.get_input_details() - self.outputs = self.tflite_interpreter.get_output_details() - - elif self.model_type == "tensorrt": - - graph_def = read_graph(model_file) - graph = finalize_graph(graph_def) - output_tensors = get_output_tensors(graph) - output_tensors = [ot.replace("DLC/", "") for ot in output_tensors] - - if (TFVER[0] > 1) | (TFVER[0] == 1 & TFVER[1] >= 14): - converter = trt.TrtGraphConverter( - input_graph_def=graph_def, - nodes_blacklist=output_tensors, - is_dynamic_op=True, - ) - graph_def = converter.convert() - else: - graph_def = trt.create_inference_graph( - input_graph_def=graph_def, - outputs=output_tensors, - max_batch_size=1, - precision_mode=self.precision, - is_dynamic_op=True, - ) - - graph = finalize_graph(graph_def) - self.sess, self.inputs, self.outputs = extract_graph( - graph, tf_config=self.tf_config - ) - - else: - - raise DLCLiveError( - "model_type = {} is not supported. model_type must be 'base', 'tflite', or 'tensorrt'".format( - self.model_type - ) - ) - - # get pose of first frame (first inference is often very slow) - - if frame is not None: - pose = self.get_pose(frame, **kwargs) - else: - pose = None + if frame.ndim >= 2: + self.convert2rgb = True + processed_frame = self.process_frame(frame) + self.pose = self.runner.init_inference(processed_frame) self.is_initialized = True + return self._post_process_pose(processed_frame, **kwargs) - return pose - - def get_pose(self, frame=None, **kwargs): + def get_pose(self, frame: np.ndarray | None = None, **kwargs) -> np.ndarray: """ Get the pose of an image @@ -389,92 +327,45 @@ def get_pose(self, frame=None, **kwargs): -------- pose :class:`numpy.ndarray` the pose estimated by DeepLabCut for the input image + inf_time:class: `float` + the pose inference time """ - if frame is None: raise DLCLiveError("No frame provided for live pose estimation") - frame = self.process_frame(frame) - - if self.model_type in ["base", "tensorrt"]: - - pose_output = self.sess.run( - self.outputs, feed_dict={self.inputs: np.expand_dims(frame, axis=0)} - ) - - elif self.model_type == "tflite": - - self.tflite_interpreter.set_tensor( - self.inputs[0]["index"], - np.expand_dims(frame, axis=0).astype(np.float32), - ) - self.tflite_interpreter.invoke() - - if len(self.outputs) > 1: - pose_output = [ - self.tflite_interpreter.get_tensor(self.outputs[0]["index"]), - self.tflite_interpreter.get_tensor(self.outputs[1]["index"]), - ] - else: - pose_output = self.tflite_interpreter.get_tensor( - self.outputs[0]["index"] - ) - - else: + if frame.ndim >= 2: + self.convert2rgb = True - raise DLCLiveError( - "model_type = {} is not supported. model_type must be 'base', 'tflite', or 'tensorrt'".format( - self.model_type - ) - ) - - # check if using TFGPUinference flag - # if not, get pose from network output - - if len(pose_output) > 1: - scmap, locref = extract_cnn_output(pose_output, self.cfg) - num_outputs = self.cfg.get("num_outputs", 1) - if num_outputs > 1: - self.pose = multi_pose_predict( - scmap, locref, self.cfg["stride"], num_outputs - ) - else: - self.pose = argmax_pose_predict(scmap, locref, self.cfg["stride"]) - else: - pose = np.array(pose_output[0]) - self.pose = pose[:, [1, 0, 2]] + processed_frame = self.process_frame(frame) + self.pose = self.runner.get_pose(processed_frame) + return self._post_process_pose(processed_frame, **kwargs) + def _post_process_pose(self, processed_frame: np.ndarray, **kwargs) -> np.ndarray: + """Post-processes the frame and pose.""" # display image if display=True before correcting pose for cropping/resizing - if self.display is not None: - self.display.display_frame(frame, self.pose) + self.display.display_frame(processed_frame, self.pose) # if frame is cropped, convert pose coordinates to original frame coordinates - if self.resize is not None: - self.pose[:, :2] *= 1 / self.resize + self.pose[..., :2] *= 1 / self.resize if self.cropping is not None: - self.pose[:, 0] += self.cropping[0] - self.pose[:, 1] += self.cropping[2] + self.pose[..., 0] += self.cropping[0] + self.pose[..., 1] += self.cropping[2] if self.dynamic_cropping is not None: - self.pose[:, 0] += self.dynamic_cropping[0] - self.pose[:, 1] += self.dynamic_cropping[2] + self.pose[..., 0] += self.dynamic_cropping[0] + self.pose[..., 1] += self.dynamic_cropping[2] # process the pose - if self.processor: self.pose = self.processor.process(self.pose, **kwargs) return self.pose - def close(self): - """ Close tensorflow session - """ - - self.sess.close() - self.sess = None + def close(self) -> None: self.is_initialized = False + self.runner.close() if self.display is not None: self.display.destroy() diff --git a/dlclive/exceptions.py b/dlclive/exceptions.py index 5d7a1aa..13c7c88 100644 --- a/dlclive/exceptions.py +++ b/dlclive/exceptions.py @@ -7,12 +7,12 @@ class DLCLiveError(Exception): - """ Generic error type for incorrect use of the DLCLive class """ + """Generic error type for incorrect use of the DLCLive class""" pass class DLCLiveWarning(Warning): - """ Generic warning for incorrect use of the DLCLive class """ + """Generic warning for incorrect use of the DLCLive class""" pass diff --git a/dlclive/factory.py b/dlclive/factory.py new file mode 100644 index 0000000..0c22e22 --- /dev/null +++ b/dlclive/factory.py @@ -0,0 +1,56 @@ +"""Factory to build runners for DeepLabCut-Live inference""" +from __future__ import annotations + +from pathlib import Path +from typing import Literal + +from dlclive.core.runner import BaseRunner + + +def build_runner( + model_type: Literal["pytorch", "tensorflow", "base", "tensorrt", "lite"], + model_path: str | Path, + **kwargs, +) -> BaseRunner: + """ + + Parameters + ---------- + model_type: str, optional + Which model to use. For the PyTorch engine, options are [`pytorch`]. For the + TensorFlow engine, options are [`base`, `tensorrt`, `lite`]. + model_path: str, Path + Full path to exported model (created when `deeplabcut.export_model(...)` was + called). For PyTorch models, this is a single model file. For TensorFlow models, + this is a directory containing the model snapshots. + + kwargs: dict, optional + PyTorch Engine Kwargs: + + TensorFlow Engine Kwargs: + + Returns + ------- + + """ + if model_type.lower() == "pytorch": + from dlclive.pose_estimation_pytorch.runner import PyTorchRunner + + valid = {"device", "precision", "single_animal", "dynamic", "top_down_config"} + return PyTorchRunner(model_path, **filter_keys(valid, kwargs)) + + elif model_type.lower() in ("tensorflow", "base", "tensorrt", "lite"): + from dlclive.pose_estimation_tensorflow.runner import TensorFlowRunner + + if model_type.lower() == "tensorflow": + model_type = "base" + + valid = {"tf_config", "precision"} + return TensorFlowRunner(model_path, model_type, **filter_keys(valid, kwargs)) + + raise ValueError(f"Unknown model type: {model_type}") + + +def filter_keys(valid: set[str], kwargs: dict) -> dict: + """Filters the keys in kwargs, only keeping those in valid.""" + return {k: v for k, v in kwargs.items() if k in valid} diff --git a/dlclive/graph.py b/dlclive/graph.py index 0841b46..72b3b76 100644 --- a/dlclive/graph.py +++ b/dlclive/graph.py @@ -106,12 +106,13 @@ def get_output_tensors(graph): def get_input_tensor(graph): - input_tensor = str(graph.get_operations()[0].name) + ":0" return input_tensor -def extract_graph(graph, tf_config=None): +def extract_graph( + graph, tf_config=None +) -> tuple[tf.Session, tf.Tensor, list[tf.Tensor]]: """ Initializes a tensorflow session with the specified graph and extracts the model's inputs and outputs diff --git a/dlclive/live_inference.py b/dlclive/live_inference.py new file mode 100644 index 0000000..6db7597 --- /dev/null +++ b/dlclive/live_inference.py @@ -0,0 +1,366 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" + +import csv +import os +import platform +import subprocess +import sys +import time + +import colorcet as cc +import cv2 +import h5py +import torch +from PIL import ImageColor +from pip._internal.operations import freeze + +from dlclive import VERSION, DLCLive + + +def get_system_info() -> dict: + """ + Returns a summary of system information relevant to running benchmarking. + + Returns + ------- + dict + A dictionary containing the following system information: + - host_name (str): Name of the machine. + - op_sys (str): Operating system. + - python (str): Path to the Python executable, indicating the conda/virtual environment in use. + - device_type (str): Type of device used ('GPU' or 'CPU'). + - device (list): List containing the name of the GPU or CPU brand. + - freeze (list): List of installed Python packages with their versions. + - python_version (str): Version of Python in use. + - git_hash (str or None): If installed from git repository, hash of HEAD commit. + - dlclive_version (str): Version of the DLCLive package. + """ + + # Get OS and host name + op_sys = platform.platform() + host_name = platform.node().replace(" ", "") + + # Get Python executable path + if platform.system() == "Windows": + host_python = sys.executable.split(os.path.sep)[-2] + else: + host_python = sys.executable.split(os.path.sep)[-3] + + # Try to get git hash if possible + git_hash = None + dlc_basedir = os.path.dirname(os.path.dirname(__file__)) + try: + git_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir) + .decode("utf-8") + .strip() + ) + except subprocess.CalledProcessError: + # Not installed from git repo, e.g., pypi + pass + + # Get device info (GPU or CPU) + if torch.cuda.is_available(): + dev_type = "GPU" + dev = [torch.cuda.get_device_name(torch.cuda.current_device())] + else: + from cpuinfo import get_cpu_info + + dev_type = "CPU" + dev = [get_cpu_info()["brand_raw"]] + + return { + "host_name": host_name, + "op_sys": op_sys, + "python": host_python, + "device_type": dev_type, + "device": dev, + "freeze": list(freeze.freeze()), + "python_version": sys.version, + "git_hash": git_hash, + "dlclive_version": VERSION, + } + + +def analyze_live_video( + model_path: str, + model_type: str, + device: str, + camera: float = 0, + experiment_name: str = "Test", + precision: str = "FP32", + snapshot: str = None, + display=True, + pcutoff=0.5, + display_radius=5, + resize=None, + cropping=None, # Adding cropping to the function parameters + dynamic=(False, 0.5, 10), + save_poses=False, + save_dir="model_predictions", + draw_keypoint_names=False, + cmap="bmy", + get_sys_info=True, + save_video=False, +): + """ + Analyzes a video to track keypoints using a DeepLabCut model, and optionally saves the keypoint data and the labeled video. + + Parameters + ---------- + model_path : str + Path to the DeepLabCut model. + model_type : str + Type of the model (e.g., 'onnx'). + device : str + Device to run the model on ('cpu' or 'cuda'). + camera : float, default=0 (webcam) + The camera to record the live video from. + experiment_name : str, default = "Test" + Prefix to label generated pose and video files + precision : str, optional, default='FP32' + Precision type for the model ('FP32' or 'FP16'). + display : bool, optional, default=True + Whether to display frame with labelled key points. + pcutoff : float, optional, default=0.5 + Probability cutoff below which keypoints are not visualized. + display_radius : int, optional, default=5 + Radius of circles drawn for keypoints on video frames. + resize : tuple of int (width, height) or None, optional + Resize dimensions for video frames. e.g. if resize = 0.5, the video will be processed in half the original size. If None, no resizing is applied. + cropping : list of int or None, optional + Cropping parameters [x1, x2, y1, y2] in pixels. If None, no cropping is applied. + dynamic : tuple, optional, default=(False, 0.5, 10) (True/false), p cutoff, margin) + Parameters for dynamic cropping. If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. pcutoff: + x, y = map(int, this_pose[j, :2]) + cv2.circle( + frame, + center=(x, y), + radius=display_radius, + color=colors[j], + thickness=-1, + ) + + if draw_keypoint_names: + cv2.putText( + frame, + text=bodyparts[j], + org=(x + 10, y), + fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=0.5, + color=colors[j], + thickness=1, + lineType=cv2.LINE_AA, + ) + if save_video: + vwriter.write(image=frame) + frame_index += 1 + + # Display the frame + if display: + cv2.imshow("DLCLive", frame) + + # Add key press check for quitting + if cv2.waitKey(1) & 0xFF == ord("q"): + break + + cap.release() + + if save_video: + vwriter.release() + + cv2.destroyAllWindows() + + if get_sys_info: + print(get_system_info()) + + if save_poses: + save_poses_to_files( + experiment_name, save_dir, bodyparts, poses, timestamp=timestamp + ) + + return poses, times + + +def save_poses_to_files(experiment_name, save_dir, bodyparts, poses, timestamp): + """ + Saves the detected keypoint poses from the video to CSV and HDF5 files. + + Parameters + ---------- + video_path : str + Path to the analyzed video file. + save_dir : str + Directory where the pose data files will be saved. + bodyparts : list of str + List of body part names corresponding to the keypoints. + poses : list of dict + List of dictionaries containing frame numbers and corresponding pose data. + + Returns + ------- + None + """ + base_filename = os.path.splitext(os.path.basename(experiment_name))[0] + csv_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.csv") + h5_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.h5") + + # Save to CSV + with open(csv_save_path, mode="w", newline="") as file: + writer = csv.writer(file) + header = ["frame"] + [ + f"{bp}_{axis}" for bp in bodyparts for axis in ["x", "y", "confidence"] + ] + writer.writerow(header) + for entry in poses: + frame_num = entry["frame"] + pose_data = entry["pose"]["poses"][0][0] + # Convert tensor data to numeric values + row = [frame_num] + [ + item.item() if isinstance(item, torch.Tensor) else item + for kp in pose_data + for item in kp + ] + writer.writerow(row) + + # Save to HDF5 + with h5py.File(h5_save_path, "w") as hf: + hf.create_dataset(name="frames", data=[entry["frame"] for entry in poses]) + for i, bp in enumerate(bodyparts): + hf.create_dataset( + name=f"{bp}_x", + data=[ + ( + entry["pose"]["poses"][0][0][i, 0].item() + if isinstance(entry["pose"]["poses"][0][0][i, 0], torch.Tensor) + else entry["pose"]["poses"][0][0][i, 0] + ) + for entry in poses + ], + ) + hf.create_dataset( + name=f"{bp}_y", + data=[ + ( + entry["pose"]["poses"][0][0][i, 1].item() + if isinstance(entry["pose"]["poses"][0][0][i, 1], torch.Tensor) + else entry["pose"]["poses"][0][0][i, 1] + ) + for entry in poses + ], + ) + hf.create_dataset( + name=f"{bp}_confidence", + data=[ + ( + entry["pose"]["poses"][0][0][i, 2].item() + if isinstance(entry["pose"]["poses"][0][0][i, 2], torch.Tensor) + else entry["pose"]["poses"][0][0][i, 2] + ) + for entry in poses + ], + ) diff --git a/dlclive/pose_estimation_pytorch/__init__.py b/dlclive/pose_estimation_pytorch/__init__.py new file mode 100644 index 0000000..ae45b8d --- /dev/null +++ b/dlclive/pose_estimation_pytorch/__init__.py @@ -0,0 +1,15 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from dlclive.pose_estimation_pytorch.dynamic_cropping import ( + DynamicCropper, + TopDownDynamicCropper, +) +from dlclive.pose_estimation_pytorch.models import PoseModel diff --git a/dlclive/pose_estimation_pytorch/data/__init__.py b/dlclive/pose_estimation_pytorch/data/__init__.py new file mode 100644 index 0000000..2fabb6e --- /dev/null +++ b/dlclive/pose_estimation_pytorch/data/__init__.py @@ -0,0 +1,10 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" +from dlclive.pose_estimation_pytorch.data.image import ( + top_down_crop, + top_down_crop_torch, +) diff --git a/dlclive/pose_estimation_pytorch/data/image.py b/dlclive/pose_estimation_pytorch/data/image.py new file mode 100644 index 0000000..c6f1705 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/data/image.py @@ -0,0 +1,140 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" + +import cv2 +import numpy as np +import torch +from torchvision.transforms import functional as F + + +def fix_bbox_aspect_ratio( + bbox: tuple[float, float, float, float] | np.ndarray | torch.Tensor, + margin: int | float, + out_w: int, + out_h: int, +) -> tuple[int, int, int, int]: + x, y, w, h = bbox + cx = x + w / 2 + cy = y + h / 2 + w += 2 * margin + h += 2 * margin + + input_ratio = w / h + output_ratio = out_w / out_h + if input_ratio > output_ratio: # h/w < h0/w0 => h' = w * h0/w0 + h = w / output_ratio + elif input_ratio < output_ratio: # w/h < w0/h0 => w' = h * w0/h0 + w = h * output_ratio + + # cx,cy,w,h will now give the right ratio -> check if padding is needed + x1, y1 = int(round(cx - (w / 2))), int(round(cy - (h / 2))) + x2, y2 = int(round(cx + (w / 2))), int(round(cy + (h / 2))) + + return x1, y1, x2, y2 + + +def crop_corners( + bbox: tuple[int, int, int, int], + image_size: tuple[int, int], + center_padding: bool = True, +) -> tuple[int, int, int, int, int, int, int, int]: + """""" + x1, y1, x2, y2 = bbox + img_w, img_h = image_size + + # pad symmetrically - compute total padding across axis + pad_left, pad_right, pad_top, pad_bottom = 0, 0, 0, 0 + if x1 < 0: + pad_left = -x1 + x1 = 0 + if x2 > img_w: + pad_right = x2 - img_w + x2 = img_w + if y1 < 0: + pad_top = -y1 + y1 = 0 + if y2 > img_h: + pad_bottom = y2 - img_h + y2 = img_h + + pad_x = pad_left + pad_right + pad_y = pad_top + pad_bottom + if center_padding: + pad_left = pad_x // 2 + pad_top = pad_y // 2 + + return x1, y1, x2, y2, pad_left, pad_top, pad_x, pad_y + + +def top_down_crop( + image: np.ndarray | torch.Tensor, + bbox: tuple[float, float, float, float] | np.ndarray | torch.Tensor, + output_size: tuple[int, int], + margin: int = 0, + center_padding: bool = False, +) -> tuple[np.array, tuple[int, int], tuple[float, float]]: + """ + Crops images around bounding boxes for top-down pose estimation. Computes offsets so + that coordinates in the original image can be mapped to the cropped one; + + x_cropped = (x - offset_x) / scale_x + x_cropped = (y - offset_y) / scale_y + + Bounding boxes are expected to be in COCO-format (xywh). + + Args: + image: (h, w, c) the image to crop + bbox: (4,) the bounding box to crop around + output_size: the (width, height) of the output cropped image + margin: a margin to add around the bounding box before cropping + center_padding: whether to center the image in the padding if any is needed + + Returns: + cropped_image, (offset_x, offset_y), (scale_x, scale_y) + """ + image_h, image_w, c = image.shape + img_size = (image_w, image_h) + out_w, out_h = output_size + + bbox = fix_bbox_aspect_ratio(bbox, margin, out_w, out_h) + x1, y1, x2, y2, pad_left, pad_top, pad_x, pad_y = crop_corners( + bbox, img_size, center_padding + ) + w, h = x2 - x1, y2 - y1 + crop_w, crop_h = w + pad_x, h + pad_y + + # crop the pixels we care about + image_crop = np.zeros((crop_h, crop_w, c), dtype=image.dtype) + image_crop[pad_top : pad_top + h, pad_left : pad_left + w] = image[y1:y2, x1:x2] + + # resize the cropped image + image = cv2.resize(image_crop, (out_w, out_h), interpolation=cv2.INTER_LINEAR) + + # compute scale and offset + offset = x1 - pad_left, y1 - pad_top + scale = crop_w / out_w, crop_h / out_h + return image, offset, scale + + +def top_down_crop_torch( + image: torch.Tensor, + bbox: tuple[float, float, float, float] | torch.Tensor, + output_size: tuple[int, int], + margin: int = 0, +) -> tuple[torch.Tensor, tuple[int, int], tuple[float, float]]: + """""" + out_w, out_h = output_size + + x1, y1, x2, y2 = fix_bbox_aspect_ratio(bbox, margin, out_w, out_h) + h, w = x2 - x1, y2 - y1 + + F.resized_crop(image, y1, x1, h, w, [out_h, out_w]) + + scale = w / out_w, h / out_h + offset = x1, y1 + crop = F.resized_crop(image, y1, x1, h, w, [out_h, out_w]) + return crop, offset, scale diff --git a/dlclive/pose_estimation_pytorch/dynamic_cropping.py b/dlclive/pose_estimation_pytorch/dynamic_cropping.py new file mode 100644 index 0000000..ae5991f --- /dev/null +++ b/dlclive/pose_estimation_pytorch/dynamic_cropping.py @@ -0,0 +1,546 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Modules to dynamically crop individuals out of videos to improve video analysis""" +from __future__ import annotations + +import math +from dataclasses import dataclass, field +from typing import Optional + +import torch +import torchvision.transforms.functional as F + + +@dataclass +class DynamicCropper: + """ + If the state is true, then dynamic cropping will be performed. That means that + if an object is detected (i.e. any body part > detection threshold), then object + boundaries are computed according to the smallest/largest x position and + smallest/largest y position of all body parts. This window is expanded by the + margin and from then on only the posture within this crop is analyzed (until the + object is lost, i.e. < detection threshold). The current position is utilized for + updating the crop window for the next frame (this is why the margin is important + and should be set large enough given the movement of the animal). + + Attributes: + threshold: float + The threshold score for bodyparts above which an individual is deemed to + have been detected. + margin: int + The margin used to expand an individuals bounding box before cropping it. + + Examples: + >>> import torch.nn as nn + >>> + >>> model: nn.Module # pose estimation model + >>> frames: torch.Tensor # shape (num_frames, 3, H, W) + >>> + >>> dynamic = DynamicCropper(threshold=0.6, margin=25) + >>> predictions = [] + >>> for image in frames: + >>> image = dynamic.crop(image) + >>> + >>> outputs = model(image) + >>> preds = model.get_predictions(outputs) + >>> pose = preds["bodypart"]["poses"] + >>> + >>> dynamic.update(pose) + >>> predictions.append(pose) + >>> + """ + + threshold: float + margin: int + _crop: tuple[int, int, int, int] | None = field(default=None, repr=False) + _shape: tuple[int, int] | None = field(default=None, repr=False) + + def crop(self, image: torch.Tensor) -> torch.Tensor: + """Crops an input image according to the dynamic cropping parameters. + + Args: + image: The image to crop, of shape (1, C, H, W). + + Returns: + The cropped image of shape (1, C, H', W'), where [H', W'] is the size of + the crop. + + Raises: + RuntimeError: if there is not exactly one image in the batch to crop, or if + `crop` was previously called with an image of a different width or + height. + """ + if len(image) != 1: + raise RuntimeError( + "DynamicCropper can only be used with batch size 1 (found image " + f"shape: {image.shape})" + ) + + if self._shape is None: + self._shape = image.shape[3], image.shape[2] + + if image.shape[3] != self._shape[0] or image.shape[2] != self._shape[1]: + raise RuntimeError( + "All frames must have the same shape; The first frame had (W, H) " + f"{self._shape} but the current frame has shape {image.shape}." + ) + + if self._crop is None: + return image + + x0, y0, x1, y1 = self._crop + return image[:, :, y0:y1, x0:x1] + + def update(self, pose: torch.Tensor) -> torch.Tensor: + """Updates the dynamic crop according to the pose model output. + + Uses the pose predicted by the model to update the dynamic crop parameters for + the next frame. Scales the pose predicted in the cropped image back to the + original image space and returns it. + + Args: + pose: The pose that was predicted by the pose estimation model in the + cropped image coordinate space. + + Returns: + The pose, with coordinates updated to the full image space. + """ + if self._shape is None: + raise RuntimeError(f"You must call `crop` before calling `update`.") + + # offset the pose to the original image space + offset_x, offset_y = 0, 0 + if self._crop is not None: + offset_x, offset_y = self._crop[:2] + pose[..., 0] = pose[..., 0] + offset_x + pose[..., 1] = pose[..., 1] + offset_y + + # check whether keypoints can be used for dynamic cropping + keypoints = pose[..., :3].reshape(-1, 3) + keypoints = keypoints[~torch.any(torch.isnan(keypoints), dim=1)] + if len(keypoints) == 0: + self.reset() + return pose + + mask = keypoints[:, 2] >= self.threshold + if torch.all(~mask): + self.reset() + return pose + + # set the crop coordinates + x0 = self._min_value(keypoints[:, 0], self._shape[0]) + x1 = self._max_value(keypoints[:, 0], self._shape[0]) + y0 = self._min_value(keypoints[:, 1], self._shape[1]) + y1 = self._max_value(keypoints[:, 1], self._shape[1]) + crop_w, crop_h = x1 - x0, y1 - y0 + if crop_w == 0 or crop_h == 0: + self.reset() + else: + self._crop = x0, y0, x1, y1 + + return pose + + def reset(self) -> None: + """Resets the DynamicCropper to not crop the next frame""" + self._crop = None + + @staticmethod + def build( + dynamic: bool, threshold: float, margin: int + ) -> Optional["DynamicCropper"]: + """Builds the DynamicCropper based on the given parameters + + Args: + dynamic: Whether dynamic cropping should be used + threshold: The threshold score for bodyparts above which an individual is + deemed to have been detected. + margin: The margin used to expand an individuals bounding box before + cropping it. + + Returns: + None if dynamic is False + DynamicCropper to use if dynamic is True + """ + if not dynamic: + return None + + return DynamicCropper(threshold, margin) + + def _min_value(self, coordinates: torch.Tensor, maximum: int) -> int: + """Returns: min(coordinates - margin), clipped to [0, maximum]""" + return self._clip( + int(math.floor(torch.min(coordinates).item() - self.margin)), + maximum, + ) + + def _max_value(self, coordinates: torch.Tensor, maximum: int) -> int: + """Returns: max(coordinates + margin), clipped to [0, maximum]""" + return self._clip( + int(math.ceil(torch.max(coordinates).item() + self.margin)), + maximum, + ) + + def _clip(self, value: int, maximum: int) -> int: + """Returns: The value clipped to [0, maximum]""" + return min(max(value, 0), maximum) + + +class TopDownDynamicCropper(DynamicCropper): + """Dynamic cropping for top-down models used on single animal videos. + + The `TopDownDynamicCropper` can be used instead of an object detector to analyze + videos **containing a single animal** with top-down models. + + At frame 0, the full frame is split into (n, m) image patches, with a given overlap + between the patches. Patches are then + - Resized to the input size required by the model with a top-down crop. + - Stacked into a batch and given to the pose estimation model + - The output poses for each patch are post-processed: the patch containing the + highest average score prediction is selected as the patch containing the + individual, and the pose from that patch is selected as the predicted pose. + + At frame n, one of two things can happen: + - If the individual was successfully detected at frame n - 1, a bounding box + is generated from the predicted pose and used as the bounding box for the + next frame. + - If the individual was not detected at frame n - 1, patches are cropped as in + frame 0 and the pose selected as in frame 0 + + An individual is considered to be successfully detected if: + - at least `min_hq_keypoints` keypoint have scores above the `threshold` + + The bounding box is generated from the keypoints (either from all keypoints or only + the ones above the threshold) with a margin around the keypoints. If the bounding + box is smaller than a set minimum size, it is expanded to that size. + + Args: + top_down_crop_size: The (width, height) of to resize crops to. + patch_counts: The number of patches along the (width, height) of the images when + no crop is found. + patch_overlap: The amount of overlapping pixels between adjacent patches. + min_bbox_size: The minimum (width, height) for a detected bounding box. If the + bounding box computed from the keypoints is smaller than this value, it + will be expanded to these values. + threshold: The threshold score for bodyparts above which an individual is + considered to be detected. + margin: The margin to add around keypoints when generating bounding boxes. + min_hq_keypoints: The minimum number of keypoints above the threshold required + for the individual to be considered detected and a bounding box to be + computed from the pose. + bbox_from_hq: If True, only keypoints above the score threshold will be used + to compute the bounding boxes. + store_crops: Useful for debugging. When True, all crops are stored in the + `crop_history` attribute. + **kwargs: Key-word arguments passed to the DynamicCropper base class. + + Attributes: + min_bbox_size: tuple[int, int]. The minimum (width, height) for a detected + bounding box. If the bounding box computed from the keypoints is smaller + than this value, it will be expanded to these values. + min_hq_keypoints: int. The minimum number of keypoints above the threshold + required for the individual to be considered detected and a bounding box to + be computed from the pose. + bbox_from_hq: bool. If True, only keypoints above the score threshold will be + used to compute the bounding boxes. + store_crops: bool. Useful for debugging. When True, all crops are stored in the + `crop_history` attribute. + crop_history: list[list[tuple[int, int, int, int]]. Empty list if `store_crops` + is False. Every time `crop` is called, a list is appended to the + `crop_history` attribute. This list is empty if no crop was used for the + frame, otherwise a list containing a single (x, y, w, h) tuple is appended. + """ + + def __init__( + self, + top_down_crop_size: tuple[int, int] = (256, 256), + patch_counts: tuple[int, int] = (3, 2), + patch_overlap: int = 50, + min_bbox_size: tuple[int, int] = (50, 50), + threshold: float = 0.25, + margin: int = 10, + min_hq_keypoints: int = 2, + bbox_from_hq: bool = True, + store_crops: bool = False, + **kwargs, + ) -> None: + super().__init__(threshold=threshold, margin=margin, **kwargs) + self.top_down_crop_size = top_down_crop_size + self.min_bbox_size = min_bbox_size + self.min_hq_keypoints = min_hq_keypoints + self.bbox_from_hq = bbox_from_hq + + self._patch_counts = patch_counts + self._patch_overlap = patch_overlap + self._patches = [] + self._patch_offsets = [] + self._td_ratio = self.top_down_crop_size[0] / self.top_down_crop_size[1] + + self.crop_history = [] + self.store_crops = store_crops + + def patch_counts(self) -> tuple[int, int]: + """Returns: the number of patches created for an image.""" + return self._patch_counts + + def num_patches(self) -> int: + """Returns: the total number of patches created for an image.""" + return self._patch_counts[0] * self._patch_counts[1] + + def crop(self, image: torch.Tensor) -> torch.Tensor: + """Crops an input image according to the dynamic cropping parameters. + + Args: + image: The image to crop, of shape (1, C, H, W). + + Returns: + The cropped image of shape (B, C, H', W'), where [H', W'] is the size of + the crop. + + Raises: + RuntimeError: if there is not exactly one image in the batch to crop, or if + `crop` was previously called with an image of a different W or H. + """ + if len(image) != 1: + raise RuntimeError( + "DynamicCropper can only be used with batch size 1 (found image " + f"shape: {image.shape})" + ) + + if self._shape is None: + self._shape = image.shape[3], image.shape[2] + self._patches = self.generate_patches() + + if image.shape[3] != self._shape[0] or image.shape[2] != self._shape[1]: + raise RuntimeError( + "All frames must have the same shape; The first frame had (W, H) " + f"{self._shape} but the current frame has shape {image.shape}." + ) + + if self._crop is None: + if self.store_crops: + self.crop_history.append([]) + return self._crop_patches(image) + + if self.store_crops: + self.crop_history.append([self._crop]) + + return self._crop_bounding_box(image, self._crop) + + def update(self, pose: torch.Tensor) -> torch.Tensor: + """Updates the dynamic crop according to the pose model output. + + Uses the pose predicted by the model to update the dynamic crop parameters for + the next frame. Scales the pose predicted in the cropped image back to the + original image space and returns it. + + Args: + pose: The pose that was predicted by the pose estimation model in the + cropped image coordinate space. + + Returns: + The pose, with coordinates updated to the full image space. + """ + if self._shape is None: + raise RuntimeError(f"You must call `crop` before calling `update`.") + + # check whether this was a patched crop + batch_size = pose.shape[0] + if batch_size > 1: + pose = self._extract_best_patch(pose) + + if self._crop is None: + raise RuntimeError( + "The _crop should never be `None` when `update` is called. Ensure you " + "always alternate between `crop` and `update`." + ) + + # offset and rescale the pose to the original image space + out_w, out_h = self.top_down_crop_size + offset_x, offset_y, w, h = self._crop + scale_x, scale_y = w / out_w, h / out_h + pose[..., 0] = (pose[..., 0] * scale_x) + offset_x + pose[..., 1] = (pose[..., 1] * scale_y) + offset_y + pose[..., 0] = torch.clip(pose[..., 0], 0, self._shape[0]) + pose[..., 1] = torch.clip(pose[..., 1], 0, self._shape[1]) + + # check whether keypoints can be used for dynamic cropping + keypoints = pose[..., :3].reshape(-1, 3) + keypoints = keypoints[~torch.any(torch.isnan(keypoints), dim=1)] + if len(keypoints) == 0: + self.reset() + return pose + + mask = keypoints[:, 2] >= self.threshold + if torch.sum(mask) < self.min_hq_keypoints: + self.reset() + return pose + + if self.bbox_from_hq: + keypoints = keypoints[mask] + + # set the crop coordinates + x0 = self._min_value(keypoints[:, 0], self._shape[0]) + x1 = self._max_value(keypoints[:, 0], self._shape[0]) + y0 = self._min_value(keypoints[:, 1], self._shape[1]) + y1 = self._max_value(keypoints[:, 1], self._shape[1]) + crop_w, crop_h = x1 - x0, y1 - y0 + if crop_w == 0 or crop_h == 0: + self.reset() + else: + self._crop = self._prepare_bounding_box(x0, y0, x1, y1) + + return pose + + def _prepare_bounding_box( + self, x1: int, y1: int, x2: int, y2: int + ) -> tuple[int, int, int, int]: + """Prepares the bounding box for cropping. + + Adds a margin around the bounding box, then transforms it into the target aspect + ratio required for crops given as inputs to the model. + + Args: + x1: The x coordinate for the top-left corner of the bounding box. + y1: The y coordinate for the top-left corner of the bounding box. + x2: The x coordinate for the bottom-right corner of the bounding box. + y2: The y coordinate for the bottom-right corner of the bounding box. + + Returns: + The (x, y, w, h) coordinates for the prepared bounding box. + """ + x1 -= self.margin + x2 += self.margin + y1 -= self.margin + y2 += self.margin + w, h = x2 - x1, y2 - y1 + cx, cy = x1 + w / 2, y1 + h / 2 + + input_ratio = w / h + if input_ratio > self._td_ratio: # h/w < h0/w0 => h' = w * h0/w0 + h = w / self._td_ratio + elif input_ratio < self._td_ratio: # w/h < w0/h0 => w' = h * w0/h0 + w = h * self._td_ratio + + x1, y1 = int(round(cx - (w / 2))), int(round(cy - (h / 2))) + w, h = max(int(w), self.min_bbox_size[0]), max(int(h), self.min_bbox_size[1]) + return x1, y1, w, h + + def _crop_bounding_box( + self, + image: torch.Tensor, + bbox: tuple[int, int, int, int], + ) -> torch.Tensor: + """Applies a top-down crop to an image given a bounding box. + + Args: + image: The image to crop, of shape (1, C, H, W). + bbox: The bounding box to crop out of the image. + + Returns: + The cropped and resized image. + """ + x1, y1, w, h = bbox + out_w, out_h = self.top_down_crop_size + return F.resized_crop(image, y1, x1, h, w, [out_h, out_w]) + + def _crop_patches(self, image: torch.Tensor) -> torch.Tensor: + """Crops patches from the image. + + Args: + image: The image to crop patches from, of shape (1, C, H, W). + + Returns: + The patches, of shape (B, C, H', W'), where [H', W'] is the crop size. + """ + patches = [self._crop_bounding_box(image, patch) for patch in self._patches] + return torch.cat(patches, dim=0) + + def _extract_best_patch(self, pose: torch.Tensor) -> torch.Tensor: + """Extracts the best pose prediction from patches. + + Args: + pose: The predicted pose, of shape (b, num_idv, num_kpt, 3). The number of + individuals must be 1. + + Returns: + The selected pose, of shape [1, N, K, 3] + """ + # check that only 1 prediction was made in each image + if pose.shape[1] != 1: + raise ValueError( + "The TopDownDynamicCropper can only be used with models predicting " + f"a single individual per image. Found {pose.shape[0]} " + f"predictions." + ) + + # compute the score for each individual + idv_scores = torch.mean(pose[:, 0, :, 2], dim=1) + + # get the index of the best patch + best_patch = torch.argmax(idv_scores) + + # set the crop to the one used for the best patch + self._crop = self._patches[best_patch] + + return pose[best_patch : best_patch + 1] + + def generate_patches(self) -> list[tuple[int, int, int, int]]: + """Generates patch coordinates for splitting an image. + + Returns: + A list of patch coordinates as tuples (x0, y0, x1, y1). + """ + patch_xs = self.split_array( + self._shape[0], self._patch_counts[0], self._patch_overlap + ) + patch_ys = self.split_array( + self._shape[1], self._patch_counts[1], self._patch_overlap + ) + + patches = [] + for y0, y1 in patch_ys: + for x0, x1 in patch_xs: + patches.append(self._prepare_bounding_box(x0, y0, x1, y1)) + + return patches + + @staticmethod + def split_array(size: int, n: int, overlap: int) -> list[tuple[int, int]]: + """ + Splits an array into n segments of equal size, where the overlap between each + segment is at least a given value. + + Args: + size: The size of the array. + n: The number of segments to split the array into. + overlap: The minimum overlap between each segment. + + Returns: + (start_index, end_index) pairs for each segment. The end index is exclusive. + """ + if n < 1: + raise ValueError(f"Array must be split into at least 1 segment. Found {n}.") + + # FIXME - auto-correct the overlap to spread it out more evenly + padded_size = size + (n - 1) * overlap + segment_size = (padded_size // n) + (padded_size % n > 0) + segments = [] + end = overlap + for i in range(n): + start = end - overlap + end = start + segment_size + if end > size: + end = size + start = end - segment_size + + segments.append((start, end)) + + return segments diff --git a/dlclive/pose_estimation_pytorch/models/__init__.py b/dlclive/pose_estimation_pytorch/models/__init__.py new file mode 100644 index 0000000..edd4e27 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/__init__.py @@ -0,0 +1,9 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" + +from dlclive.pose_estimation_pytorch.models.model import PoseModel +from dlclive.pose_estimation_pytorch.models.detectors import DETECTORS, BaseDetector diff --git a/dlclive/pose_estimation_pytorch/models/backbones/__init__.py b/dlclive/pose_estimation_pytorch/models/backbones/__init__.py new file mode 100644 index 0000000..0d32951 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/backbones/__init__.py @@ -0,0 +1,17 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from dlclive.pose_estimation_pytorch.models.backbones.base import ( + BACKBONES, + BaseBackbone, +) +from dlclive.pose_estimation_pytorch.models.backbones.cspnext import CSPNeXt +from dlclive.pose_estimation_pytorch.models.backbones.hrnet import HRNet +from dlclive.pose_estimation_pytorch.models.backbones.resnet import DLCRNet, ResNet diff --git a/dlclive/pose_estimation_pytorch/models/backbones/base.py b/dlclive/pose_estimation_pytorch/models/backbones/base.py new file mode 100644 index 0000000..4bc6b4f --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/backbones/base.py @@ -0,0 +1,141 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import logging +import shutil +from abc import ABC, abstractmethod +from pathlib import Path + +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download + +from dlclive.pose_estimation_pytorch.models.registry import build_from_cfg, Registry + +BACKBONES = Registry("backbones", build_func=build_from_cfg) + + +class BaseBackbone(ABC, nn.Module): + """Base Backbone class for pose estimation. + + Attributes: + stride: the stride for the backbone + freeze_bn_weights: freeze weights of batch norm layers during training + freeze_bn_stats: freeze stats of batch norm layers during training + """ + + def __init__( + self, + stride: int | float, + freeze_bn_weights: bool = True, + freeze_bn_stats: bool = True, + ): + super().__init__() + self.stride = stride + self.freeze_bn_weights = freeze_bn_weights + self.freeze_bn_stats = freeze_bn_stats + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Abstract method for the forward pass through the backbone. + + Args: + x: Input tensor of shape (batch_size, channels, height, width). + + Returns: + a feature map for the input, of shape (batch_size, c', h', w') + """ + pass + + def freeze_batch_norm_layers(self) -> None: + """Freezes batch norm layers + + Running mean + var are always given to F.batch_norm, except when the layer is + in `train` mode and track_running_stats is False, see + https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html + So to 'freeze' the running stats, the only way is to set the layer to "eval" + mode. + """ + for module in self.modules(): + if isinstance(module, nn.BatchNorm2d): + if self.freeze_bn_weights: + module.weight.requires_grad = False + module.bias.requires_grad = False + if self.freeze_bn_stats: + module.eval() + + def train(self, mode: bool = True) -> None: + """Sets the module in training or evaluation mode. + + Args: + mode: whether to set training mode (True) or evaluation mode (False) + """ + super().train(mode) + if self.freeze_bn_weights or self.freeze_bn_stats: + self.freeze_batch_norm_layers() + + +class HuggingFaceWeightsMixin: + """Mixin for backbones where the pretrained weights are stored on HuggingFace""" + + def __init__( + self, + backbone_weight_folder: str | Path | None = None, + repo_id: str = "DeepLabCut/DeepLabCut-Backbones", + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + if backbone_weight_folder is None: + backbone_weight_folder = Path(__file__).parent / "pretrained_weights" + else: + backbone_weight_folder = Path(backbone_weight_folder).resolve() + + self.backbone_weight_folder = backbone_weight_folder + self.repo_id = repo_id + + def download_weights(self, filename: str, force: bool = False) -> Path: + """Downloads the backbone weights from the HuggingFace repo + + Args: + filename: The name of the model file to download in the repo. + force: Whether to re-download the file if it already exists locally. + + Returns: + The path to the model snapshot. + """ + model_path = self.backbone_weight_folder / filename + if model_path.exists(): + if not force: + return model_path + model_path.unlink() + + logging.info(f"Downloading the pre-trained backbone to {model_path}") + self.backbone_weight_folder.mkdir(exist_ok=True, parents=False) + output_path = Path( + hf_hub_download( + self.repo_id, filename, cache_dir=self.backbone_weight_folder + ) + ) + + # resolve gets the actual path if the output path is a symlink + output_path = output_path.resolve() + # move to the target path + output_path.rename(model_path) + + # delete downloaded artifacts + uid, rid = self.repo_id.split("/") + artifact_dir = self.backbone_weight_folder / f"models--{uid}--{rid}" + if artifact_dir.exists(): + shutil.rmtree(artifact_dir) + + return model_path diff --git a/dlclive/pose_estimation_pytorch/models/backbones/cspnext.py b/dlclive/pose_estimation_pytorch/models/backbones/cspnext.py new file mode 100644 index 0000000..681f2ba --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/backbones/cspnext.py @@ -0,0 +1,208 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Implementation of the CSPNeXt Backbone + +Based on the ``mmdetection`` CSPNeXt implementation. For more information, see: + + +For more details about this architecture, see `RTMDet: An Empirical Study of Designing +Real-Time Object Detectors`: https://arxiv.org/abs/1711.05101. +""" +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from dlclive.pose_estimation_pytorch.models.backbones.base import ( + BACKBONES, + BaseBackbone, + HuggingFaceWeightsMixin, +) +from dlclive.pose_estimation_pytorch.models.modules.csp import ( + CSPConvModule, + CSPLayer, + SPPBottleneck, +) + + +@dataclass(frozen=True) +class CSPNeXtLayerConfig: + """Configuration for a CSPNeXt layer""" + + in_channels: int + out_channels: int + num_blocks: int + add_identity: bool + use_spp: bool + + +@BACKBONES.register_module +class CSPNeXt(HuggingFaceWeightsMixin, BaseBackbone): + """CSPNeXt Backbone + + Args: + model_name: The model variant to build. If ``pretrained==True``, must be one of + the variants for which weights are available on HuggingFace (in the + `DeepLabCut/DeepLabCut-Backbones` hub, e.g. `cspnext_m`). + pretrained: Whether to load pretrained weights for the model. + arch: The model architecture to build. Must be one of the keys of the + ``CSPNeXt.ARCH`` attribute (e.g. `P5`, `P6`, ...). + expand_ratio: Ratio used to adjust the number of channels of the hidden layer. + deepen_factor: Number of blocks in each CSP layer is multiplied by this value. + widen_factor: Number of channels in each layer is multiplied by this value. + out_indices: The branch indices to output. If a tuple of integers, the outputs + are returned as a list of tensors. If a single integer, a tensor is returned + containing the configured index. + channel_attention: Add chanel attention to all stages + norm_layer: The type of normalization layer to use. + activation_fn: The type of activation function to use. + **kwargs: BaseBackbone kwargs. + """ + + ARCH: dict[str, list[CSPNeXtLayerConfig]] = { + "P5": [ + CSPNeXtLayerConfig(64, 128, 3, True, False), + CSPNeXtLayerConfig(128, 256, 6, True, False), + CSPNeXtLayerConfig(256, 512, 6, True, False), + CSPNeXtLayerConfig(512, 1024, 3, False, True), + ], + "P6": [ + CSPNeXtLayerConfig(64, 128, 3, True, False), + CSPNeXtLayerConfig(128, 256, 6, True, False), + CSPNeXtLayerConfig(256, 512, 6, True, False), + CSPNeXtLayerConfig(512, 768, 3, True, False), + CSPNeXtLayerConfig(768, 1024, 3, False, True), + ], + } + + def __init__( + self, + model_name: str = "cspnext_m", + pretrained: bool = False, + arch: str = "P5", + expand_ratio: float = 0.5, + deepen_factor: float = 0.67, + widen_factor: float = 0.75, + out_indices: int | tuple[int, ...] = -1, + channel_attention: bool = True, + norm_layer: str = "SyncBN", + activation_fn: str = "SiLU", + **kwargs, + ) -> None: + super().__init__(stride=32, **kwargs) + if arch not in self.ARCH: + raise ValueError( + f"Unknown `CSPNeXT` architecture: {arch}. Must be one of " + f"{self.ARCH.keys()}" + ) + + self.model_name = model_name + self.layer_configs = self.ARCH[arch] + self.stem_out_channels = self.layer_configs[0].in_channels + self.spp_kernel_sizes = (5, 9, 13) + + # stem has stride 2 + self.stem = nn.Sequential( + CSPConvModule( + in_channels=3, + out_channels=int(self.stem_out_channels * widen_factor // 2), + kernel_size=3, + padding=1, + stride=2, + norm_layer=norm_layer, + activation_fn=activation_fn, + ), + CSPConvModule( + in_channels=int(self.stem_out_channels * widen_factor // 2), + out_channels=int(self.stem_out_channels * widen_factor // 2), + kernel_size=3, + padding=1, + stride=1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ), + CSPConvModule( + in_channels=int(self.stem_out_channels * widen_factor // 2), + out_channels=int(self.stem_out_channels * widen_factor), + kernel_size=3, + padding=1, + stride=1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ), + ) + self.layers = ["stem"] + + for i, layer_cfg in enumerate(self.layer_configs): + layer_cfg: CSPNeXtLayerConfig + in_channels = int(layer_cfg.in_channels * widen_factor) + out_channels = int(layer_cfg.out_channels * widen_factor) + num_blocks = max(round(layer_cfg.num_blocks * deepen_factor), 1) + stage = [] + conv_layer = CSPConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + stage.append(conv_layer) + if layer_cfg.use_spp: + spp = SPPBottleneck( + out_channels, + out_channels, + kernel_sizes=self.spp_kernel_sizes, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + stage.append(spp) + + csp_layer = CSPLayer( + out_channels, + out_channels, + num_blocks=num_blocks, + add_identity=layer_cfg.add_identity, + expand_ratio=expand_ratio, + channel_attention=channel_attention, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + stage.append(csp_layer) + self.add_module(f"stage{i + 1}", nn.Sequential(*stage)) + self.layers.append(f"stage{i + 1}") + + self.single_output = isinstance(out_indices, int) + if self.single_output: + if out_indices == -1: + out_indices = len(self.layers) - 1 + out_indices = (out_indices,) + self.out_indices = out_indices + + if pretrained: + weights_filename = f"{model_name}.pt" + weights_path = self.download_weights(weights_filename, force=False) + snapshot = torch.load(weights_path, map_location="cpu", weights_only=True) + self.load_state_dict(snapshot["state_dict"]) + + def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]: + outs = [] + for i, layer_name in enumerate(self.layers): + layer = getattr(self, layer_name) + x = layer(x) + if i in self.out_indices: + outs.append(x) + + if self.single_output: + return outs[-1] + + return tuple(outs) diff --git a/dlclive/pose_estimation_pytorch/models/backbones/hrnet.py b/dlclive/pose_estimation_pytorch/models/backbones/hrnet.py new file mode 100644 index 0000000..942399d --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/backbones/hrnet.py @@ -0,0 +1,122 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F + +from dlclive.pose_estimation_pytorch.models.backbones.base import ( + BACKBONES, + BaseBackbone, +) + + +@BACKBONES.register_module +class HRNet(BaseBackbone): + """HRNet backbone. + + This version returns high-resolution feature maps of size 1/4 * original_image_size. + This is obtained using bilinear interpolation and concatenation of all the outputs + of the HRNet stages. + + The model outputs 4 branches, with strides 4, 8, 16 and 32. + + Args: + stride: The stride of the HRNet. Should always be 4, except for custom models. + model_name: Any HRNet variant available through timm (e.g., 'hrnet_w32', + 'hrnet_w48'). See timm for more options. + pretrained: If True, loads the backbone with ImageNet pretrained weights from + timm. + interpolate_branches: Needed for DEKR. Instead of returning features from the + high-resolution branch, interpolates all other branches to the same shape + and concatenates them. + increased_channel_count: As described by timm, it "allows grabbing increased + channel count features using part of the classification head" (otherwise, + the default features are returned). + kwargs: BaseBackbone kwargs + + Attributes: + model: the HRNet model + """ + + def __init__( + self, + stride: int = 4, + model_name: str = "hrnet_w32", + pretrained: bool = False, + interpolate_branches: bool = False, + increased_channel_count: bool = False, + **kwargs, + ) -> None: + super().__init__(stride=stride, **kwargs) + self.model = _load_hrnet(model_name, pretrained, increased_channel_count) + self.interpolate_branches = interpolate_branches + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the HRNet backbone. + + Args: + x: Input tensor of shape (batch_size, channels, height, width). + + Returns: + the feature map + + Example: + >>> import torch + >>> from dlclive.models.backbones import HRNet + >>> backbone = HRNet(model_name='hrnet_w32', pretrained=False) + >>> x = torch.randn(1, 3, 256, 256) + >>> y = backbone(x) + """ + y_list = self.model(x) + if not self.interpolate_branches: + return y_list[0] + + x0_h, x0_w = y_list[0].size(2), y_list[0].size(3) + x = torch.cat( + [ + y_list[0], + F.interpolate(y_list[1], size=(x0_h, x0_w), mode="bilinear"), + F.interpolate(y_list[2], size=(x0_h, x0_w), mode="bilinear"), + F.interpolate(y_list[3], size=(x0_h, x0_w), mode="bilinear"), + ], + 1, + ) + return x + + +def _load_hrnet( + model_name: str, + pretrained: bool, + increased_channel_count: bool, +) -> nn.Module: + """Loads a TIMM HRNet model. + + Args: + model_name: Any HRNet variant available through timm (e.g., 'hrnet_w32', + 'hrnet_w48'). See timm for more options. + pretrained: If True, loads the backbone with ImageNet pretrained weights from + timm. + increased_channel_count: As described by timm, it "allows grabbing increased + channel count features using part of the classification head" (otherwise, + the default features are returned). + + Returns: + the HRNet model + """ + # First stem conv is used for stride 2 features, so only return branches 1-4 + return timm.create_model( + model_name, + pretrained=pretrained, + features_only=True, + feature_location="incre" if increased_channel_count else "", + out_indices=(1, 2, 3, 4), + ) diff --git a/dlclive/pose_estimation_pytorch/models/backbones/resnet.py b/dlclive/pose_estimation_pytorch/models/backbones/resnet.py new file mode 100644 index 0000000..f661159 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/backbones/resnet.py @@ -0,0 +1,151 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +import timm +import torch +import torch.nn as nn +from torchvision.transforms.functional import resize + +from dlclive.pose_estimation_pytorch.models.backbones.base import ( + BACKBONES, + BaseBackbone, +) + + +@BACKBONES.register_module +class ResNet(BaseBackbone): + """ResNet backbone. + + This class represents a typical ResNet backbone for pose estimation. + + Attributes: + model: the ResNet model + """ + + def __init__( + self, + model_name: str = "resnet50", + output_stride: int = 32, + pretrained: bool = False, + drop_path_rate: float = 0.0, + drop_block_rate: float = 0.0, + **kwargs, + ) -> None: + """Initialize the ResNet backbone. + + Args: + model_name: Name of the ResNet model to use, e.g., 'resnet50', 'resnet101' + output_stride: Output stride of the network, 32, 16, or 8. + pretrained: If True, initializes with ImageNet pretrained weights. + drop_path_rate: Stochastic depth drop-path rate + drop_block_rate: Drop block rate + kwargs: BaseBackbone kwargs + """ + super().__init__(stride=output_stride, **kwargs) + self.model = timm.create_model( + model_name, + output_stride=output_stride, + pretrained=pretrained, + drop_path_rate=drop_path_rate, + drop_block_rate=drop_block_rate, + ) + self.model.fc = nn.Identity() # remove the FC layer + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the ResNet backbone. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: Output tensor. + Example: + >>> import torch + >>> from dlclive.models.backbones import ResNet + >>> backbone = ResNet(model_name='resnet50', pretrained=False) + >>> x = torch.randn(1, 3, 256, 256) + >>> y = backbone(x) + + Expected Output Shape: + If input size is (batch_size, 3, shape_x, shape_y), the output shape + will be (batch_size, 3, shape_x//16, shape_y//16) + """ + return self.model.forward_features(x) + + +@BACKBONES.register_module +class DLCRNet(ResNet): + def __init__( + self, + model_name: str = "resnet50", + output_stride: int = 32, + pretrained: bool = True, + **kwargs, + ) -> None: + super().__init__(model_name, output_stride, pretrained, **kwargs) + self.interm_features = {} + self.model.layer1[2].register_forward_hook(self._get_features("bank1")) + self.model.layer2[2].register_forward_hook(self._get_features("bank2")) + self.conv_block1 = self._make_conv_block( + in_channels=512, out_channels=512, kernel_size=3, stride=2 + ) + self.conv_block2 = self._make_conv_block( + in_channels=512, out_channels=128, kernel_size=1, stride=1 + ) + self.conv_block3 = self._make_conv_block( + in_channels=256, out_channels=256, kernel_size=3, stride=2 + ) + self.conv_block4 = self._make_conv_block( + in_channels=256, out_channels=256, kernel_size=3, stride=2 + ) + self.conv_block5 = self._make_conv_block( + in_channels=256, out_channels=128, kernel_size=1, stride=1 + ) + + def _make_conv_block( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + momentum: float = 0.001, # (1 - decay) + ) -> torch.nn.Sequential: + return nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + + def _get_features(self, name): + def hook(model, input, output): + self.interm_features[name] = output.detach() + + return hook + + def forward(self, x): + out = super().forward(x) + + # Fuse intermediate features + bank_2_s8 = self.interm_features["bank2"] + bank_1_s4 = self.interm_features["bank1"] + bank_2_s16 = self.conv_block1(bank_2_s8) + bank_2_s16 = self.conv_block2(bank_2_s16) + bank_1_s8 = self.conv_block3(bank_1_s4) + bank_1_s16 = self.conv_block4(bank_1_s8) + bank_1_s16 = self.conv_block5(bank_1_s16) + # Resizing here is required to guarantee all shapes match, as + # Conv2D(..., padding='same') is invalid for strided convolutions. + h, w = out.shape[-2:] + bank_1_s16 = resize(bank_1_s16, [h, w], antialias=True) + bank_2_s16 = resize(bank_2_s16, [h, w], antialias=True) + + return torch.cat((bank_1_s16, bank_2_s16, out), dim=1) diff --git a/dlclive/pose_estimation_pytorch/models/detectors/__init__.py b/dlclive/pose_estimation_pytorch/models/detectors/__init__.py new file mode 100644 index 0000000..e9a99a6 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/detectors/__init__.py @@ -0,0 +1,16 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from dlclive.pose_estimation_pytorch.models.detectors.base import ( + DETECTORS, + BaseDetector, +) +from dlclive.pose_estimation_pytorch.models.detectors.fasterRCNN import FasterRCNN +from dlclive.pose_estimation_pytorch.models.detectors.ssd import SSDLite diff --git a/dlclive/pose_estimation_pytorch/models/detectors/base.py b/dlclive/pose_estimation_pytorch/models/detectors/base.py new file mode 100644 index 0000000..bcd9fb0 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/detectors/base.py @@ -0,0 +1,56 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from dlclive.pose_estimation_pytorch.models.registry import Registry, build_from_cfg + +DETECTORS = Registry("detectors", build_func=build_from_cfg) + + +class BaseDetector(ABC, nn.Module): + """ + Definition of the class BaseDetector object. + This is an abstract class defining the common structure and inference for detectors. + """ + + def __init__( + self, + freeze_bn_stats: bool = False, + freeze_bn_weights: bool = False, + pretrained: bool = False, + ) -> None: + super().__init__() + self.freeze_bn_stats = freeze_bn_stats + self.freeze_bn_weights = freeze_bn_weights + self._pretrained = pretrained + + @abstractmethod + def forward( + self, x: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None + ) -> list[dict[str, torch.Tensor]]: + """ + Forward pass of the detector + + Args: + x: images to be processed + targets: ground-truth boxes present in each images + + Returns: + losses: {'loss_name': loss_value} + detections: for each of the b images, {"boxes": bounding_boxes} + """ + pass diff --git a/dlclive/pose_estimation_pytorch/models/detectors/fasterRCNN.py b/dlclive/pose_estimation_pytorch/models/detectors/fasterRCNN.py new file mode 100644 index 0000000..f250b9a --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/detectors/fasterRCNN.py @@ -0,0 +1,74 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import torchvision.models.detection as detection + +from dlclive.pose_estimation_pytorch.models.detectors.base import DETECTORS +from dlclive.pose_estimation_pytorch.models.detectors.torchvision import ( + TorchvisionDetectorAdaptor, +) + + +@DETECTORS.register_module +class FasterRCNN(TorchvisionDetectorAdaptor): + """A FasterRCNN detector + + Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks + Ren, Shaoqing, Kaiming He, Ross Girshick, and Jian Sun. "Faster r-cnn: Towards + real-time object detection with region proposal networks." Advances in neural + information processing systems 28 (2015). + + This class is a wrapper of the torchvision implementation of a FasterRCNN (source: + https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py). + + Some of the available FasterRCNN variants (from fastest to most powerful): + - fasterrcnn_mobilenet_v3_large_fpn + - fasterrcnn_resnet50_fpn + - fasterrcnn_resnet50_fpn_v2 + + Args: + variant: The FasterRCNN variant to use (see all options at + https://pytorch.org/vision/stable/models.html#object-detection). + pretrained: Whether to load model weights pretrained on COCO + box_score_thresh: during inference, only return proposals with a classification + score greater than box_score_thresh + """ + + def __init__( + self, + freeze_bn_stats: bool = False, + freeze_bn_weights: bool = False, + variant: str = "fasterrcnn_mobilenet_v3_large_fpn", + pretrained: bool = False, + box_score_thresh: float = 0.01, + ) -> None: + if not variant.lower().startswith("fasterrcnn"): + raise ValueError( + "The version must start with `fasterrcnn`. See available models at " + "https://pytorch.org/vision/stable/models.html#object-detection" + ) + + super().__init__( + model=variant, + weights=("COCO_V1" if pretrained else None), + num_classes=None, + freeze_bn_stats=freeze_bn_stats, + freeze_bn_weights=freeze_bn_weights, + box_score_thresh=box_score_thresh, + ) + + # Modify the base predictor to output the correct number of classes + num_classes = 2 + in_features = self.model.roi_heads.box_predictor.cls_score.in_features + self.model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor( + in_features, num_classes + ) diff --git a/dlclive/pose_estimation_pytorch/models/detectors/ssd.py b/dlclive/pose_estimation_pytorch/models/detectors/ssd.py new file mode 100644 index 0000000..7140cd5 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/detectors/ssd.py @@ -0,0 +1,70 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import torchvision.models.detection as detection + +from dlclive.pose_estimation_pytorch.models.detectors.base import DETECTORS +from dlclive.pose_estimation_pytorch.models.detectors.torchvision import ( + TorchvisionDetectorAdaptor, +) + + +@DETECTORS.register_module +class SSDLite(TorchvisionDetectorAdaptor): + """An SSD object detection model""" + + def __init__( + self, + freeze_bn_stats: bool = False, + freeze_bn_weights: bool = False, + pretrained: bool = False, + pretrained_from_imagenet: bool = False, + box_score_thresh: float = 0.01, + ) -> None: + model_kwargs = dict(weights_backbone=None) + if pretrained_from_imagenet: + model_kwargs["weights_backbone"] = "IMAGENET1K_V2" + + super().__init__( + model="ssdlite320_mobilenet_v3_large", + weights=None, + num_classes=2, + freeze_bn_stats=freeze_bn_stats, + freeze_bn_weights=freeze_bn_weights, + box_score_thresh=box_score_thresh, + model_kwargs=model_kwargs, + ) + + if pretrained and not pretrained_from_imagenet: + weights = detection.SSDLite320_MobileNet_V3_Large_Weights.verify("COCO_V1") + state_dict = weights.get_state_dict(progress=False, check_hash=True) + for k, v in state_dict.items(): + key_parts = k.split(".") + if ( + len(key_parts) == 6 + and key_parts[0] == "head" + and key_parts[1] == "classification_head" + and key_parts[2] == "module_list" + and key_parts[4] == "1" + and key_parts[5] in ("weight", "bias") + ): + # number of COCO classes: 90 + background (91) + # number of DLC classes: 1 + background (2) + # -> only keep weights for the background + first class + + # future improvement: find best-suited class for the project + # and use those weights, instead of naively taking the first + all_classes_size = v.shape[0] + two_classes_size = 2 * (all_classes_size // 91) + state_dict[k] = v[:two_classes_size] + + self.model.load_state_dict(state_dict) diff --git a/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py new file mode 100644 index 0000000..72dd54b --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/detectors/torchvision.py @@ -0,0 +1,96 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Module to adapt torchvision detectors for DeepLabCut""" +from __future__ import annotations + +import torch +import torchvision.models.detection as detection + +from dlclive.pose_estimation_pytorch.models.detectors.base import BaseDetector + + +class TorchvisionDetectorAdaptor(BaseDetector): + """An adaptor for torchvision detectors + + This class is an adaptor for torchvision detectors to DeepLabCut detectors. Some of + the models (from fastest to most powerful) available are: + - ssdlite320_mobilenet_v3_large + - fasterrcnn_mobilenet_v3_large_fpn + - fasterrcnn_resnet50_fpn_v2 + + This class should not be used out-of-the-box. Subclasses (such as FasterRCNN or + SSDLite) should be used instead. + + The torchvision implementation does not allow to get both predictions and losses + with a single forward pass. Therefore, during evaluation only bounding box metrics + (mAP, mAR) are available for the test set. See validation loss issue: + - https://discuss.pytorch.org/t/compute-validation-loss-for-faster-rcnn/62333/12 + - https://stackoverflow.com/a/65347721 + + Args: + model: The torchvision model to use (see all options at + https://pytorch.org/vision/stable/models.html#object-detection). + weights: The weights to load for the model. If None, no pre-trained weights are + loaded. + num_classes: Number of classes that the model should output. If None, the number + of classes the model is pre-trained on is used. + freeze_bn_stats: Whether to freeze stats for BatchNorm layers. + freeze_bn_weights: Whether to freeze weights for BatchNorm layers. + box_score_thresh: during inference, only return proposals with a classification + score greater than box_score_thresh + """ + + def __init__( + self, + model: str, + weights: str | None = None, + num_classes: int | None = 2, + freeze_bn_stats: bool = False, + freeze_bn_weights: bool = False, + box_score_thresh: float = 0.01, + model_kwargs: dict | None = None, + ) -> None: + super().__init__( + freeze_bn_stats=freeze_bn_stats, + freeze_bn_weights=freeze_bn_weights, + pretrained=weights is not None, + ) + + # Load the model + model_fn = getattr(detection, model) + if model_kwargs is None: + model_kwargs = {} + + self.model = model_fn( + weights=weights, + box_score_thresh=box_score_thresh, + num_classes=num_classes, + **model_kwargs, + ) + + # See source: https://stackoverflow.com/a/65347721 + self.model.eager_outputs = lambda losses, detections: (losses, detections) + + def forward( + self, x: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None + ) -> list[dict[str, torch.Tensor]]: + """ + Forward pass of the torchvision detector + + Args: + x: images to be processed, of shape (b, c, h, w) + targets: ground-truth boxes present in the images + + Returns: + losses: {'loss_name': loss_value} + detections: for each of the b images, {"boxes": bounding_boxes} + """ + return self.model(x, targets)[1] diff --git a/dlclive/pose_estimation_pytorch/models/heads/__init__.py b/dlclive/pose_estimation_pytorch/models/heads/__init__.py new file mode 100644 index 0000000..5bf207e --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/heads/__init__.py @@ -0,0 +1,16 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from dlclive.pose_estimation_pytorch.models.heads.base import HEADS, BaseHead +from dlclive.pose_estimation_pytorch.models.heads.dekr import DEKRHead +from dlclive.pose_estimation_pytorch.models.heads.dlcrnet import DLCRNetHead +from dlclive.pose_estimation_pytorch.models.heads.rtmcc_head import RTMCCHead +from dlclive.pose_estimation_pytorch.models.heads.simple_head import HeatmapHead +from dlclive.pose_estimation_pytorch.models.heads.transformer import TransformerHead diff --git a/dlclive/pose_estimation_pytorch/models/heads/base.py b/dlclive/pose_estimation_pytorch/models/heads/base.py new file mode 100644 index 0000000..56e1f01 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/heads/base.py @@ -0,0 +1,57 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from dlclive.pose_estimation_pytorch.models.predictors import BasePredictor +from dlclive.pose_estimation_pytorch.models.registry import Registry, build_from_cfg + +HEADS = Registry("heads", build_func=build_from_cfg) + + +class BaseHead(ABC, nn.Module): + """A head for pose estimation models + + Attributes: + stride: The stride for the head (or neck + head pair), where positive values + indicate an increase in resolution while negative values a decrease. + Assuming that H and W are divisible by `stride`, this is the value such + that if a backbone outputs an encoding of shape (C, H, W), the head will + output heatmaps of shape: + (C, H * stride, W * stride) if stride > 0 + (C, -H/stride, -W/stride) if stride < 0 + predictor: an object to generate predictions from the head outputs + """ + + def __init__(self, stride: int | float, predictor: BasePredictor) -> None: + super().__init__() + if stride == 0: + raise ValueError(f"Stride must not be 0. Found {stride}.") + + self.stride = stride + self.predictor = predictor + + @abstractmethod + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + """ + Given the feature maps for an image () + + Args: + x: the feature maps, of shape (b, c, h, w) + + Returns: + the head outputs (e.g. "heatmap", "locref") + """ + pass diff --git a/dlclive/pose_estimation_pytorch/models/heads/dekr.py b/dlclive/pose_estimation_pytorch/models/heads/dekr.py new file mode 100644 index 0000000..1ef1ec1 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/heads/dekr.py @@ -0,0 +1,416 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import torch +import torch.nn as nn + +from dlclive.pose_estimation_pytorch.models.heads.base import HEADS, BaseHead +from dlclive.pose_estimation_pytorch.models.modules.conv_block import ( + AdaptBlock, + BaseBlock, + BasicBlock, +) +from dlclive.pose_estimation_pytorch.models.predictors import BasePredictor + + +@HEADS.register_module +class DEKRHead(BaseHead): + """ + DEKR head based on: + Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression + Zigang Geng, Ke Sun, Bin Xiao, Zhaoxiang Zhang, Jingdong Wang, CVPR 2021 + Code based on: + https://github.com/HRNet/DEKR + """ + + def __init__( + self, + predictor: BasePredictor, + heatmap_config: dict, + offset_config: dict, + stride: int | float = 1, # head stride - should always be 1 for DEKR + ) -> None: + super().__init__(stride, predictor) + self.heatmap_head = DEKRHeatmap(**heatmap_config) + self.offset_head = DEKROffset(**offset_config) + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + return {"heatmap": self.heatmap_head(x), "offset": self.offset_head(x)} + + +class DEKRHeatmap(nn.Module): + """ + DEKR head to compute the heatmaps corresponding to keypoints based on: + Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression + Zigang Geng, Ke Sun, Bin Xiao, Zhaoxiang Zhang, Jingdong Wang, CVPR 2021 + Code based on: + https://github.com/HRNet/DEKR + """ + + def __init__( + self, + channels: tuple[int], + num_blocks: int, + dilation_rate: int, + final_conv_kernel: int, + block: type(BaseBlock) = BasicBlock, + ) -> None: + """Summary: + Constructor of the HeatmapDEKRHead. + Loads the data. + + Args: + channels: tuple containing the number of channels for the head. + num_blocks: number of blocks in the head + dilation_rate: dilation rate for the head + final_conv_kernel: kernel size for the final convolution + block: type of block to use in the head. Defaults to BasicBlock. + + Returns: + None + + Examples: + channels = (64,128,17) + num_blocks = 3 + dilation_rate = 2 + final_conv_kernel = 3 + block = BasicBlock + """ + super().__init__() + self.bn_momentum = 0.1 + self.inp_channels = channels[0] + self.num_joints_with_center = channels[ + 2 + ] # Should account for the center being a joint + self.final_conv_kernel = final_conv_kernel + + self.transition_heatmap = self._make_transition_for_head( + self.inp_channels, channels[1] + ) + self.head_heatmap = self._make_heatmap_head( + block, num_blocks, channels[1], dilation_rate + ) + + def _make_transition_for_head( + self, in_channels: int, out_channels: int + ) -> nn.Sequential: + """Summary: + Construct the transition layer for the head. + + Args: + in_channels: number of input channels + out_channels: number of output channels + + Returns: + Transition layer consisting of Conv2d, BatchNorm2d, and ReLU + """ + transition_layer = [ + nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(True), + ] + return nn.Sequential(*transition_layer) + + def _make_heatmap_head( + self, + block: type(BaseBlock), + num_blocks: int, + num_channels: int, + dilation_rate: int, + ) -> nn.ModuleList: + """Summary: + Construct the heatmap head + + Args: + block: type of block to use in the head. + num_blocks: number of blocks in the head. + num_channels: number of input channels for the head. + dilation_rate: dilation rate for the head. + + Returns: + List of modules representing the heatmap head layers. + """ + heatmap_head_layers = [] + + feature_conv = self._make_layer( + block, num_channels, num_channels, num_blocks, dilation=dilation_rate + ) + heatmap_head_layers.append(feature_conv) + + heatmap_conv = nn.Conv2d( + in_channels=num_channels, + out_channels=self.num_joints_with_center, + kernel_size=self.final_conv_kernel, + stride=1, + padding=1 if self.final_conv_kernel == 3 else 0, + ) + heatmap_head_layers.append(heatmap_conv) + + return nn.ModuleList(heatmap_head_layers) + + def _make_layer( + self, + block: type(BaseBlock), + in_channels: int, + out_channels: int, + num_blocks: int, + stride: int = 1, + dilation: int = 1, + ) -> nn.Sequential: + """Summary: + Construct a layer in the head. + + Args: + block: type of block to use in the head. + in_channels: number of input channels for the layer. + out_channels: number of output channels for the layer. + num_blocks: number of blocks in the layer. + stride: stride for the convolutional layer. Defaults to 1. + dilation: dilation rate for the convolutional layer. Defaults to 1. + + Returns: + Sequential layer containing the specified num_blocks. + """ + downsample = None + if stride != 1 or in_channels != out_channels * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d( + out_channels * block.expansion, momentum=self.bn_momentum + ), + ) + + layers = [ + block(in_channels, out_channels, stride, downsample, dilation=dilation) + ] + in_channels = out_channels * block.expansion + for _ in range(1, num_blocks): + layers.append(block(in_channels, out_channels, dilation=dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + heatmap = self.head_heatmap[1](self.head_heatmap[0](self.transition_heatmap(x))) + + return heatmap + + +class DEKROffset(nn.Module): + """ + DEKR module to compute the offset from the center corresponding to each keypoints: + Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression + Zigang Geng, Ke Sun, Bin Xiao, Zhaoxiang Zhang, Jingdong Wang, CVPR 2021 + Code based on: + https://github.com/HRNet/DEKR + """ + + def __init__( + self, + channels: tuple[int, ...], + num_offset_per_kpt: int, + num_blocks: int, + dilation_rate: int, + final_conv_kernel: int, + block: type(BaseBlock) = AdaptBlock, + ) -> None: + """Args: + channels: tuple containing the number of input, offset, and output channels. + num_offset_per_kpt: number of offset values per keypoint. + num_blocks: number of blocks in the head. + dilation_rate: dilation rate for convolutional layers. + final_conv_kernel: kernel size for the final convolution. + block: type of block to use in the head. Defaults to AdaptBlock. + """ + super().__init__() + self.inp_channels = channels[0] + self.num_joints = channels[2] + self.num_joints_with_center = self.num_joints + 1 + + self.bn_momentum = 0.1 + self.offset_perkpt = num_offset_per_kpt + self.num_joints_without_center = self.num_joints + self.offset_channels = self.offset_perkpt * self.num_joints_without_center + assert self.offset_channels == channels[1] + + self.num_blocks = num_blocks + self.dilation_rate = dilation_rate + self.final_conv_kernel = final_conv_kernel + + self.transition_offset = self._make_transition_for_head( + self.inp_channels, self.offset_channels + ) + ( + self.offset_feature_layers, + self.offset_final_layer, + ) = self._make_separete_regression_head( + block, + num_blocks=num_blocks, + num_channels_per_kpt=self.offset_perkpt, + dilation_rate=self.dilation_rate, + ) + + def _make_layer( + self, + block: type(BaseBlock), + in_channels: int, + out_channels: int, + num_blocks: int, + stride: int = 1, + dilation: int = 1, + ) -> nn.Sequential: + """Summary: + Create a sequential layer with the specified block and number of num_blocks. + + Args: + block: block type to use in the layer. + in_channels: number of input channels. + out_channels: number of output channels. + num_blocks: number of blocks to be stacked in the layer. + stride: stride for the first block. Defaults to 1. + dilation: dilation rate for the blocks. Defaults to 1. + + Returns: + A sequential layer containing stacked num_blocks. + + Examples: + input: + block=BasicBlock + in_channels=64 + out_channels=128 + num_blocks=3 + stride=1 + dilation=1 + """ + downsample = None + if stride != 1 or in_channels != out_channels * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d( + out_channels * block.expansion, momentum=self.bn_momentum + ), + ) + + layers = [] + layers.append( + block(in_channels, out_channels, stride, downsample, dilation=dilation) + ) + in_channels = out_channels * block.expansion + for _ in range(1, num_blocks): + layers.append(block(in_channels, out_channels, dilation=dilation)) + + return nn.Sequential(*layers) + + def _make_transition_for_head( + self, in_channels: int, out_channels: int + ) -> nn.Sequential: + """Summary: + Create a transition layer for the head. + + Args: + in_channels: number of input channels + out_channels: number of output channels + + Returns: + Sequential layer containing the transition operations. + """ + transition_layer = [ + nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(True), + ] + return nn.Sequential(*transition_layer) + + def _make_separete_regression_head( + self, + block: type(BaseBlock), + num_blocks: int, + num_channels_per_kpt: int, + dilation_rate: int, + ) -> tuple: + """Summary: + + Args: + block: type of block to use in the head + num_blocks: number of blocks in the regression head + num_channels_per_kpt: number of channels per keypoint + dilation_rate: dilation rate for the regression head + + Returns: + A tuple containing two ModuleList objects. + The first ModuleList contains the feature convolution layers for each keypoint, + and the second ModuleList contains the final offset convolution layers. + """ + offset_feature_layers = [] + offset_final_layer = [] + + for _ in range(self.num_joints): + feature_conv = self._make_layer( + block, + num_channels_per_kpt, + num_channels_per_kpt, + num_blocks, + dilation=dilation_rate, + ) + offset_feature_layers.append(feature_conv) + + offset_conv = nn.Conv2d( + in_channels=num_channels_per_kpt, + out_channels=2, + kernel_size=self.final_conv_kernel, + stride=1, + padding=1 if self.final_conv_kernel == 3 else 0, + ) + offset_final_layer.append(offset_conv) + + return nn.ModuleList(offset_feature_layers), nn.ModuleList(offset_final_layer) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Summary: + Perform forward pass through the OffsetDEKRHead. + + Args: + x: input tensor to the head. + + Returns: + offset: Computed offsets from the center corresponding to each keypoint. + The tensor will have the shape (N, num_joints * 2, H, W), where N is the batch size, + num_joints is the number of keypoints, and H and W are the height and width of the output tensor. + """ + final_offset = [] + offset_feature = self.transition_offset(x) + + for j in range(self.num_joints): + final_offset.append( + self.offset_final_layer[j]( + self.offset_feature_layers[j]( + offset_feature[ + :, j * self.offset_perkpt : (j + 1) * self.offset_perkpt + ] + ) + ) + ) + + offset = torch.cat(final_offset, dim=1) + + return offset diff --git a/dlclive/pose_estimation_pytorch/models/heads/dlcrnet.py b/dlclive/pose_estimation_pytorch/models/heads/dlcrnet.py new file mode 100644 index 0000000..79cc315 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/heads/dlcrnet.py @@ -0,0 +1,137 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import torch +import torch.nn as nn + +from dlclive.pose_estimation_pytorch.models.heads.base import HEADS +from dlclive.pose_estimation_pytorch.models.heads.simple_head import ( + DeconvModule, + HeatmapHead, +) +from dlclive.pose_estimation_pytorch.models.predictors import BasePredictor + + +@HEADS.register_module +class DLCRNetHead(HeatmapHead): + """A head for DLCRNet models using Part-Affinity Fields to predict individuals""" + + def __init__( + self, + predictor: BasePredictor, + heatmap_config: dict, + locref_config: dict, + paf_config: dict, + num_stages: int = 5, + features_dim: int = 128, + ) -> None: + self.num_stages = num_stages + # FIXME Cleaner __init__ to avoid initializing unused layers + in_channels = heatmap_config["channels"][0] + num_keypoints = heatmap_config["channels"][-1] + num_limbs = paf_config["channels"][-1] # Already has the 2x multiplier + in_refined_channels = features_dim + num_keypoints + num_limbs + if num_stages > 0: + heatmap_config["channels"][0] = paf_config["channels"][ + 0 + ] = in_refined_channels + locref_config["channels"][0] = locref_config["channels"][-1] + + super().__init__(predictor, heatmap_config, locref_config) + if num_stages > 0: + self.stride *= 2 # extra deconv layer where it's multi-stage + + self.paf_head = DeconvModule(**paf_config) + + self.convt1 = self._make_layer_same_padding( + in_channels=in_channels, out_channels=num_keypoints + ) + self.convt2 = self._make_layer_same_padding( + in_channels=in_channels, out_channels=locref_config["channels"][-1] + ) + self.convt3 = self._make_layer_same_padding( + in_channels=in_channels, out_channels=num_limbs + ) + self.convt4 = self._make_layer_same_padding( + in_channels=in_channels, out_channels=features_dim + ) + self.hm_ref_layers = nn.ModuleList() + self.paf_ref_layers = nn.ModuleList() + for _ in range(num_stages): + self.hm_ref_layers.append( + self._make_refinement_layer( + in_channels=in_refined_channels, out_channels=num_keypoints + ) + ) + self.paf_ref_layers.append( + self._make_refinement_layer( + in_channels=in_refined_channels, out_channels=num_limbs + ) + ) + + def _make_layer_same_padding( + self, in_channels: int, out_channels: int + ) -> nn.ConvTranspose2d: + # FIXME There is no consensual solution to emulate TF behavior in pytorch + # see https://github.com/pytorch/pytorch/issues/3867 + return nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ) + + def _make_refinement_layer(self, in_channels: int, out_channels: int) -> nn.Conv2d: + """Summary: + Helper function to create a refinement layer. + + Args: + in_channels: number of input channels + out_channels: number of output channels + + Returns: + refinement_layer: the refinement layer. + """ + return nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding="same" + ) + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + if self.num_stages > 0: + stage1_hm_out = self.convt1(x) + stage1_paf_out = self.convt3(x) + features = self.convt4(x) + stage2_in = torch.cat((stage1_hm_out, stage1_paf_out, features), dim=1) + stage_in = stage2_in + stage_paf_out = stage1_paf_out + stage_hm_out = stage1_hm_out + for i, (hm_ref_layer, paf_ref_layer) in enumerate( + zip(self.hm_ref_layers, self.paf_ref_layers) + ): + pre_stage_hm_out = stage_hm_out + stage_hm_out = hm_ref_layer(stage_in) + stage_paf_out = paf_ref_layer(stage_in) + if i > 0: + stage_hm_out += pre_stage_hm_out + stage_in = torch.cat((stage_hm_out, stage_paf_out, features), dim=1) + return { + "heatmap": self.heatmap_head(stage_in), + "locref": self.locref_head(self.convt2(x)), + "paf": self.paf_head(stage_in), + } + return { + "heatmap": self.heatmap_head(x), + "locref": self.locref_head(x), + "paf": self.paf_head(x), + } diff --git a/dlclive/pose_estimation_pytorch/models/heads/rtmcc_head.py b/dlclive/pose_estimation_pytorch/models/heads/rtmcc_head.py new file mode 100644 index 0000000..53c112d --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/heads/rtmcc_head.py @@ -0,0 +1,139 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Modified SimCC head for the RTMPose model + +Based on the official ``mmpose`` RTMCC head implementation. For more information, see +. +""" +from __future__ import annotations + +import torch +import torch.nn as nn + +from dlclive.pose_estimation_pytorch.models.heads.base import ( + BaseHead, + HEADS, +) +from dlclive.pose_estimation_pytorch.models.modules import ( + GatedAttentionUnit, + ScaleNorm, +) +from dlclive.pose_estimation_pytorch.models.predictors import BasePredictor + + +@HEADS.register_module +class RTMCCHead(BaseHead): + """RTMPose Coordinate Classification head + + The RTMCC head is itself adapted from the SimCC head. For more information, see + "SimCC: a Simple Coordinate Classification Perspective for Human Pose Estimation" + () and "RTMPose: Real-Time Multi-Person Pose + Estimation based on MMPose" (). + + Args: + input_size: The size of images given to the pose estimation model. + in_channels: The number of input channels for the head. + out_channels: Number of channels output by the head (number of bodyparts). + in_featuremap_size: The size of the input feature map for the head. This is + equal to the input_size divided by the backbone stride. + simcc_split_ratio: The split ratio of pixels, as described in SimCC. + final_layer_kernel_size: Kernel size of the final convolutional layer. + gau_cfg: Configuration for the GatedAttentionUnit. + predictor: The predictor for the head. Should usually be a `SimCCPredictor`. + """ + + def __init__( + self, + input_size: tuple[int, int], + in_channels: int, + out_channels: int, + in_featuremap_size: tuple[int, int], + simcc_split_ratio: float, + final_layer_kernel_size: int, + gau_cfg: dict, + predictor: BasePredictor, + ) -> None: + super().__init__(1, predictor) + + self.input_size = input_size + self.in_channels = in_channels + self.out_channels = out_channels + + self.in_featuremap_size = in_featuremap_size + self.simcc_split_ratio = simcc_split_ratio + + flatten_dims = self.in_featuremap_size[0] * self.in_featuremap_size[1] + out_w = int(self.input_size[0] * self.simcc_split_ratio) + out_h = int(self.input_size[1] * self.simcc_split_ratio) + + self.gau = GatedAttentionUnit( + num_token=self.out_channels, + in_token_dims=gau_cfg["hidden_dims"], + out_token_dims=gau_cfg["hidden_dims"], + expansion_factor=gau_cfg["expansion_factor"], + s=gau_cfg["s"], + eps=1e-5, + dropout_rate=gau_cfg["dropout_rate"], + drop_path=gau_cfg["drop_path"], + attn_type="self-attn", + act_fn=gau_cfg["act_fn"], + use_rel_bias=gau_cfg["use_rel_bias"], + pos_enc=gau_cfg["pos_enc"], + ) + + self.final_layer = nn.Conv2d( + in_channels, + out_channels, + kernel_size=final_layer_kernel_size, + stride=1, + padding=final_layer_kernel_size // 2, + ) + self.mlp = nn.Sequential( + ScaleNorm(flatten_dims), + nn.Linear(flatten_dims, gau_cfg["hidden_dims"], bias=False), + ) + + self.cls_x = nn.Linear(gau_cfg["hidden_dims"], out_w, bias=False) + self.cls_y = nn.Linear(gau_cfg["hidden_dims"], out_h, bias=False) + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + feats = self.final_layer(x) # -> B, K, H, W + feats = torch.flatten(feats, start_dim=2) # -> B, K, hidden=HxW + feats = self.mlp(feats) # -> B, K, hidden + feats = self.gau(feats) + x, y = self.cls_x(feats), self.cls_y(feats) + return dict(x=x, y=y) + + @staticmethod + def update_input_size(model_cfg: dict, input_size: tuple[int, int]) -> None: + """Updates an RTMPose model configuration file for a new image input size + + Args: + model_cfg: The model configuration to update in-place. + input_size: The updated input (width, height). + """ + _sigmas = {192: 4.9, 256: 5.66, 288: 6, 384: 6.93} + + def _sigma(size: int) -> float: + sigma = _sigmas.get(size) + if sigma is None: + return 2.87 + 0.01 * size + + return sigma + + w, h = input_size + model_cfg["data"]["inference"]["top_down_crop"] = dict(width=w, height=h) + model_cfg["data"]["train"]["top_down_crop"] = dict(width=w, height=h) + head_cfg = model_cfg["model"]["heads"]["bodypart"] + head_cfg["input_size"] = input_size + head_cfg["in_featuremap_size"] = h // 32, w // 32 + head_cfg["target_generator"]["input_size"] = input_size + head_cfg["target_generator"]["sigma"] = (_sigma(w), _sigma(h)) diff --git a/dlclive/pose_estimation_pytorch/models/heads/simple_head.py b/dlclive/pose_estimation_pytorch/models/heads/simple_head.py new file mode 100644 index 0000000..545d854 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/heads/simple_head.py @@ -0,0 +1,224 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import torch +import torch.nn as nn + +from dlclive.pose_estimation_pytorch.models.heads.base import HEADS, BaseHead +from dlclive.pose_estimation_pytorch.models.predictors import BasePredictor + + +@HEADS.register_module +class HeatmapHead(BaseHead): + """Deconvolutional head to predict maps from the extracted features. + + This class implements a simple deconvolutional head to predict maps from the + extracted features. + + Args: + predictor: The predictor used to transform heatmaps into keypoints. + heatmap_config: The configuration for the heatmap outputs of the head. + locref_config: The configuration for the location refinement outputs (None if + no location refinement should be used). + """ + + def __init__( + self, + predictor: BasePredictor, + heatmap_config: dict, + locref_config: dict | None = None, + ) -> None: + heatmap_head = DeconvModule(**heatmap_config) + locref_head = None + if locref_config is not None: + locref_head = DeconvModule(**locref_config) + + # check that the heatmap and locref modules have the same stride + if heatmap_head.stride != locref_head.stride: + raise ValueError( + f"Invalid model config: Your heatmap and locref need to have the " + f"same stride (found {heatmap_head.stride}, " + f"{locref_head.stride}). Please check your config (found " + f"heatmap_config={heatmap_config}, locref_config={locref_config}" + ) + + super().__init__(heatmap_head.stride, predictor) + self.heatmap_head = heatmap_head + self.locref_head = locref_head + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + outputs = {"heatmap": self.heatmap_head(x)} + if self.locref_head is not None: + outputs["locref"] = self.locref_head(x) + return outputs + + @staticmethod + def convert_weights( + state_dict: dict[str, torch.Tensor], + module_prefix: str, + conversion: torch.Tensor, + ) -> dict[str, torch.Tensor]: + """Converts pre-trained weights to be fine-tuned on another dataset + + Args: + state_dict: the state dict for the pre-trained model + module_prefix: the prefix for weights in this head (e.g., 'heads.bodypart.') + conversion: the mapping of old indices to new indices + """ + state_dict = DeconvModule.convert_weights( + state_dict, + f"{module_prefix}heatmap_head.", + conversion, + ) + + locref_conversion = torch.stack( + [2 * conversion, 2 * conversion + 1], + dim=1, + ).reshape(-1) + state_dict = DeconvModule.convert_weights( + state_dict, + f"{module_prefix}locref_head.", + locref_conversion, + ) + return state_dict + + +class DeconvModule(nn.Module): + """ + Deconvolutional module to predict maps from the extracted features. + """ + + def __init__( + self, + channels: list[int], + kernel_size: list[int], + strides: list[int], + final_conv: dict | None = None, + ) -> None: + """ + Args: + channels: List containing the number of input and output channels for each + deconvolutional layer. + kernel_size: List containing the kernel size for each deconvolutional layer. + strides: List containing the stride for each deconvolutional layer. + final_conv: Configuration for a conv layer after the deconvolutional layers, + if one should be added. Must have keys "out_channels" and "kernel_size". + """ + super().__init__() + if not (len(channels) == len(kernel_size) + 1 == len(strides) + 1): + raise ValueError( + "Incorrect DeconvModule configuration: there should be one more number" + f" of channels than kernel_sizes and strides, found {len(channels)} " + f"channels, {len(kernel_size)} kernels and {len(strides)} strides." + ) + + in_channels = channels[0] + head_stride = 1 + self.deconv_layers = nn.Identity() + if len(kernel_size) > 0: + self.deconv_layers = nn.Sequential( + *self._make_layers(in_channels, channels[1:], kernel_size, strides) + ) + for s in strides: + head_stride *= s + + self.stride = head_stride + self.final_conv = nn.Identity() + if final_conv: + self.final_conv = nn.Conv2d( + in_channels=channels[-1], + out_channels=final_conv["out_channels"], + kernel_size=final_conv["kernel_size"], + stride=1, + ) + + @staticmethod + def _make_layers( + in_channels: int, + out_channels: list[int], + kernel_sizes: list[int], + strides: list[int], + ) -> list[nn.Module]: + """ + Helper function to create the deconvolutional layers. + + Args: + in_channels: number of input channels to the module + out_channels: number of output channels of each layer + kernel_sizes: size of the deconvolutional kernel + strides: stride for the convolution operation + + Returns: + the deconvolutional layers + """ + layers = [] + for out_channels, k, s in zip(out_channels, kernel_sizes, strides): + layers.append( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=s) + ) + layers.append(nn.ReLU()) + in_channels = out_channels + return layers[:-1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the HeatmapHead + + Args: + x: input tensor + + Returns: + out: output tensor + """ + x = self.deconv_layers(x) + x = self.final_conv(x) + return x + + @staticmethod + def convert_weights( + state_dict: dict[str, torch.Tensor], + module_prefix: str, + conversion: torch.Tensor, + ) -> dict[str, torch.Tensor]: + """Converts pre-trained weights to be fine-tuned on another dataset + + Args: + state_dict: the state dict for the pre-trained model + module_prefix: the prefix for weights in this head (e.g., 'heads.bodypart') + conversion: the mapping of old indices to new indices + """ + if f"{module_prefix}final_conv.weight" in state_dict: + # has final convolution + weight_key = f"{module_prefix}final_conv.weight" + bias_key = f"{module_prefix}final_conv.bias" + state_dict[weight_key] = state_dict[weight_key][conversion] + state_dict[bias_key] = state_dict[bias_key][conversion] + return state_dict + + # get the last deconv layer of the net + next_index = 0 + while f"{module_prefix}deconv_layers.{next_index}.weight" in state_dict: + next_index += 1 + last_index = next_index - 1 + + # if there are deconv layers for this module prefix (there might not be, + # e.g., when there are no location refinement layers in a heatmap head) + if last_index >= 0: + weight_key = f"{module_prefix}deconv_layers.{last_index}.weight" + bias_key = f"{module_prefix}deconv_layers.{last_index}.bias" + + # for ConvTranspose2d, the weight shape is (in_channels, out_channels, ...) + # while it's (out_channels, in_channels, ...) for Conv2d + state_dict[weight_key] = state_dict[weight_key][:, conversion] + state_dict[bias_key] = state_dict[bias_key][conversion] + + return state_dict diff --git a/dlclive/pose_estimation_pytorch/models/heads/transformer.py b/dlclive/pose_estimation_pytorch/models/heads/transformer.py new file mode 100644 index 0000000..cd64677 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/heads/transformer.py @@ -0,0 +1,94 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import torch +from einops import rearrange +from timm.layers import trunc_normal_ +from torch import nn as nn + +from dlclive.pose_estimation_pytorch.models.heads import HEADS, BaseHead +from dlclive.pose_estimation_pytorch.models.predictors import BasePredictor + + +@HEADS.register_module +class TransformerHead(BaseHead): + """ + Transformer Head module to predict heatmaps using a transformer-based approach + """ + + def __init__( + self, + predictor: BasePredictor, + dim: int, + hidden_heatmap_dim: int, + heatmap_dim: int, + apply_multi: bool, + heatmap_size: tuple[int, int], + apply_init: bool, + head_stride: int, + ): + """ + Args: + dim: Dimension of the input features. + hidden_heatmap_dim: Dimension of the hidden features in the MLP head. + heatmap_dim: Dimension of the output heatmaps. + apply_multi: If True, apply a multi-layer perceptron (MLP) with LayerNorm + to generate heatmaps. If False, directly apply a single linear + layer for heatmap prediction. + heatmap_size: Tuple (height, width) representing the size of the output + heatmaps. + apply_init: If True, apply weight initialization to the module's layers. + head_stride: The stride for the head (or neck + head pair), where positive + values indicate an increase in resolution while negative values a + decrease. Assuming that H and W are divisible by head_stride, this is + the value such that if a backbone outputs an encoding of shape + (C, H, W), the head will output heatmaps of shape: + (C, H * head_stride, W * head_stride) if head_stride > 0 + (C, -H/head_stride, -W/head_stride) if head_stride < 0 + """ + super().__init__(head_stride, predictor) + self.mlp_head = ( + nn.Sequential( + nn.LayerNorm(dim * 3), + nn.Linear(dim * 3, hidden_heatmap_dim), + nn.LayerNorm(hidden_heatmap_dim), + nn.Linear(hidden_heatmap_dim, heatmap_dim), + ) + if (dim * 3 <= hidden_heatmap_dim * 0.5 and apply_multi) + else nn.Sequential(nn.LayerNorm(dim * 3), nn.Linear(dim * 3, heatmap_dim)) + ) + self.heatmap_size = heatmap_size + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + x = self.mlp_head(x) + x = rearrange( + x, + "b c (p1 p2) -> b c p1 p2", + p1=self.heatmap_size[0], + p2=self.heatmap_size[1], + ) + return {"heatmap": x} + + def _init_weights(self, m: nn.Module) -> None: + """ + Custom weight initialization for linear and layer normalization layers. + + Args: + m: module to initialize + """ + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) diff --git a/dlclive/pose_estimation_pytorch/models/model.py b/dlclive/pose_estimation_pytorch/models/model.py new file mode 100644 index 0000000..1d8fce8 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/model.py @@ -0,0 +1,127 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import copy + +import torch +import torch.nn as nn + +from dlclive.pose_estimation_pytorch.models.backbones import BACKBONES, BaseBackbone +from dlclive.pose_estimation_pytorch.models.heads import HEADS, BaseHead +from dlclive.pose_estimation_pytorch.models.necks import NECKS, BaseNeck +from dlclive.pose_estimation_pytorch.models.predictors import PREDICTORS + + +class PoseModel(nn.Module): + """A pose estimation model + + A pose estimation model is composed of a backbone, optionally a neck, and an + arbitrary number of heads. Outputs are computed as follows: + """ + + def __init__( + self, + cfg: dict, + backbone: BaseBackbone, + heads: dict[str, BaseHead], + neck: BaseNeck | None = None, + ) -> None: + """ + Args: + cfg: configuration dictionary for the model. + backbone: backbone network architecture. + heads: the heads for the model + neck: neck network architecture (default is None). Defaults to None. + """ + super().__init__() + self.cfg = cfg + self.backbone = backbone + self.heads = nn.ModuleDict(heads) + self.neck = neck + + self._strides = { + name: _model_stride(self.backbone.stride, head.stride) + for name, head in heads.items() + } + + def forward(self, x: torch.Tensor) -> dict[str, dict[str, torch.Tensor]]: + """ + Forward pass of the PoseModel. + + Args: + x: input images + + Returns: + Outputs of head groups + """ + if x.dim() == 3: + x = x[None, :] + features = self.backbone(x) + if self.neck: + features = self.neck(features) + + outputs = {} + for head_name, head in self.heads.items(): + outputs[head_name] = head(features) + return outputs + + def get_predictions(self, outputs: dict[str, dict[str, torch.Tensor]]) -> dict: + """Abstract method for the forward pass of the Predictor. + + Args: + outputs: outputs of the model heads + + Returns: + A dictionary containing the predictions of each head group + """ + return { + name: head.predictor(self._strides[name], outputs[name]) + for name, head in self.heads.items() + } + + @staticmethod + def build(cfg: dict) -> "PoseModel": + """ + Args: + cfg: The configuration of the model to build. + + Returns: + the built pose model + """ + cfg["backbone"]["pretrained"] = False + backbone = BACKBONES.build(dict(cfg["backbone"])) + + neck = None + if cfg.get("neck"): + neck = NECKS.build(dict(cfg["neck"])) + + heads = {} + for name, head_cfg in cfg["heads"].items(): + head_cfg = copy.deepcopy(head_cfg) + + # Remove keys not needed for DLCLive inference + for k in ("target_generator", "criterion", "aggregator", "weight_init"): + if k in head_cfg: + head_cfg.pop(k) + + head_cfg["predictor"] = PREDICTORS.build(head_cfg["predictor"]) + heads[name] = HEADS.build(head_cfg) + + return PoseModel(cfg=cfg, backbone=backbone, neck=neck, heads=heads) + + +def _model_stride(backbone_stride: int | float, head_stride: int | float) -> float: + """Computes the model stride from a backbone and a head""" + if head_stride > 0: + return backbone_stride / head_stride + + return backbone_stride * -head_stride diff --git a/dlclive/pose_estimation_pytorch/models/modules/__init__.py b/dlclive/pose_estimation_pytorch/models/modules/__init__.py new file mode 100644 index 0000000..4974948 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/modules/__init__.py @@ -0,0 +1,24 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from dlclive.pose_estimation_pytorch.models.modules.conv_block import ( + AdaptBlock, + BasicBlock, + Bottleneck, +) +from dlclive.pose_estimation_pytorch.models.modules.conv_module import ( + HighResolutionModule, +) +from dlclive.pose_estimation_pytorch.models.modules.gated_attention_unit import ( + GatedAttentionUnit, +) +from dlclive.pose_estimation_pytorch.models.modules.norm import ( + ScaleNorm, +) diff --git a/dlclive/pose_estimation_pytorch/models/modules/conv_block.py b/dlclive/pose_estimation_pytorch/models/modules/conv_block.py new file mode 100644 index 0000000..f3fbb02 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/modules/conv_block.py @@ -0,0 +1,307 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""The code is based on DEKR: https://github.com/HRNet/DEKR/tree/main""" +from __future__ import annotations + +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn +import torchvision.ops as ops + +from dlclive.pose_estimation_pytorch.models.registry import Registry, build_from_cfg + + +BLOCKS = Registry("blocks", build_func=build_from_cfg) + + +class BaseBlock(ABC, nn.Module): + """Abstract Base class for defining custom blocks. + + This class defines an abstract base class for creating custom blocks used in the HigherHRNet for Human Pose Estimation. + + Attributes: + bn_momentum: Batch normalization momentum. + + Methods: + forward(x): Abstract method for defining the forward pass of the block. + """ + + def __init__(self): + super().__init__() + self.bn_momentum = 0.1 + + @abstractmethod + def forward(self, x: torch.Tensor): + """Abstract method for defining the forward pass of the block. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + pass + + def _init_weights(self, pretrained: str | None): + """Method for initializing block weights from pretrained models. + + Args: + pretrained: Path to pretrained model weights. + """ + if pretrained: + self.load_state_dict(torch.load(pretrained)) + + +@BLOCKS.register_module +class BasicBlock(BaseBlock): + """Basic Residual Block. + + This class defines a basic residual block used in HigherHRNet. + + Attributes: + expansion: The expansion factor used in the block. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + stride: Stride value for the convolutional layers. Default is 1. + downsample: Downsample layer to be used in the residual connection. Default is None. + dilation: Dilation rate for the convolutional layers. Default is 1. + """ + + expansion: int = 1 + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + downsample: nn.Module | None = None, + dilation: int = 1, + ): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn1 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn2 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the BasicBlock. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +@BLOCKS.register_module +class Bottleneck(BaseBlock): + """Bottleneck Residual Block. + + This class defines a bottleneck residual block used in HigherHRNet. + + Attributes: + expansion: The expansion factor used in the block. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + stride: Stride value for the convolutional layers. Default is 1. + downsample: Downsample layer to be used in the residual connection. Default is None. + dilation: Dilation rate for the convolutional layers. Default is 1. + """ + + expansion: int = 4 + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + downsample: nn.Module | None = None, + dilation: int = 1, + ): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + bias=False, + dilation=dilation, + ) + self.bn2 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum) + self.conv3 = nn.Conv2d( + out_channels, out_channels * self.expansion, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d( + out_channels * self.expansion, momentum=self.bn_momentum + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the Bottleneck block. + + Args: + x : Input tensor. + + Returns: + Output tensor. + """ + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +@BLOCKS.register_module +class AdaptBlock(BaseBlock): + """Adaptive Residual Block with Deformable Convolution. + + This class defines an adaptive residual block with deformable convolution used in HigherHRNet. + + Attributes: + expansion: The expansion factor used in the block. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + stride: Stride value for the convolutional layers. Default is 1. + downsample: Downsample layer to be used in the residual connection. Default is None. + dilation: Dilation rate for the convolutional layers. Default is 1. + deformable_groups: Number of deformable groups in the deformable convolution. Default is 1. + """ + + expansion: int = 1 + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + downsample: nn.Module | None = None, + dilation: int = 1, + deformable_groups: int = 1, + ): + super(AdaptBlock, self).__init__() + regular_matrix = torch.tensor( + [[-1, -1, -1, 0, 0, 0, 1, 1, 1], [-1, 0, 1, -1, 0, 1, -1, 0, 1]] + ) + self.register_buffer("regular_matrix", regular_matrix.float()) + self.downsample = downsample + self.transform_matrix_conv = nn.Conv2d(in_channels, 4, 3, 1, 1, bias=True) + self.translation_conv = nn.Conv2d(in_channels, 2, 3, 1, 1, bias=True) + self.adapt_conv = ops.DeformConv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False, + groups=deformable_groups, + ) + self.bn = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the AdaptBlock. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + residual = x + + N, _, H, W = x.shape + transform_matrix = self.transform_matrix_conv(x) + transform_matrix = transform_matrix.permute(0, 2, 3, 1).reshape( + (N * H * W, 2, 2) + ) + offset = torch.matmul(transform_matrix, self.regular_matrix) + offset = offset - self.regular_matrix + offset = offset.transpose(1, 2).reshape((N, H, W, 18)).permute(0, 3, 1, 2) + + translation = self.translation_conv(x) + offset[:, 0::2, :, :] += translation[:, 0:1, :, :] + offset[:, 1::2, :, :] += translation[:, 1:2, :, :] + + out = self.adapt_conv(x, offset) + out = self.bn(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out diff --git a/dlclive/pose_estimation_pytorch/models/modules/conv_module.py b/dlclive/pose_estimation_pytorch/models/modules/conv_module.py new file mode 100644 index 0000000..8f7241b --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/modules/conv_module.py @@ -0,0 +1,244 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""The code is based on DEKR: https://github.com/HRNet/DEKR/tree/main""" +import logging +from typing import List + +import torch.nn as nn + +from dlclive.pose_estimation_pytorch.models.modules import BasicBlock + + +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +class HighResolutionModule(nn.Module): + """High-Resolution Module. + + This class implements the High-Resolution Module used in HigherHRNet for Human Pose Estimation. + + Args: + num_branches: Number of branches in the module. + block: The block type used in each branch of the module. + num_blocks: List containing the number of blocks in each branch. + num_inchannels: List containing the number of input channels for each branch. + num_channels: List containing the number of output channels for each branch. + fuse_method: The fusion method used in the module. + multi_scale_output: Whether to output multi-scale features. Default is True. + """ + + def __init__( + self, + num_branches: int, + block: BasicBlock, + num_blocks: int, + num_inchannels: int, + num_channels: int, + fuse_method: str, + multi_scale_output: bool = True, + ): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, block, num_blocks, num_inchannels, num_channels + ) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, block, num_blocks, num_channels + ) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(True) + + def _check_branches( + self, + num_branches: int, + block: BasicBlock, + num_blocks: int, + num_inchannels: int, + num_channels: int, + ): + if num_branches != len(num_blocks): + error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format( + num_branches, len(num_blocks) + ) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format( + num_branches, len(num_channels) + ) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format( + num_branches, len(num_inchannels) + ) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch( + self, + branch_index: int, + block: BasicBlock, + num_blocks: int, + num_channels: int, + stride: int = 1, + ) -> nn.Sequential: + downsample = None + if ( + stride != 1 + or self.num_inchannels[branch_index] + != num_channels[branch_index] * block.expansion + ): + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d( + num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM + ), + ) + + layers = [] + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index], + stride, + downsample, + ) + ) + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block(self.num_inchannels[branch_index], num_channels[branch_index]) + ) + + return nn.Sequential(*layers) + + def _make_branches( + self, num_branches: int, block: BasicBlock, num_blocks: int, num_channels: int + ) -> nn.ModuleList: + branches = [] + + for i in range(num_branches): + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self) -> nn.ModuleList: + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False, + ), + nn.BatchNorm2d(num_inchannels[i]), + nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"), + ) + ) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False, + ), + nn.BatchNorm2d(num_outchannels_conv3x3), + ) + ) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False, + ), + nn.BatchNorm2d(num_outchannels_conv3x3), + nn.ReLU(True), + ) + ) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self) -> int: + return self.num_inchannels + + def forward(self, x) -> List: + """Forward pass through the HighResolutionModule. + + Args: + x: List of input tensors for each branch. + + Returns: + List of output tensors after processing through the module. + """ + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse diff --git a/dlclive/pose_estimation_pytorch/models/modules/csp.py b/dlclive/pose_estimation_pytorch/models/modules/csp.py new file mode 100644 index 0000000..3099eeb --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/modules/csp.py @@ -0,0 +1,387 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Implementation of modules needed for the CSPNeXt Backbone. Used in CSP-style models. + +Based on the building blocks used for the ``mmdetection`` CSPNeXt implementation. For +more information, see . +""" +import torch +import torch.nn as nn + + +def build_activation(activation_fn: str, *args, **kwargs) -> nn.Module: + if activation_fn == "SiLU": + return nn.SiLU(*args, **kwargs) + elif activation_fn == "ReLU": + return nn.ReLU(*args, **kwargs) + + raise NotImplementedError( + f"Unknown `CSPNeXT` activation: {activation_fn}. Must be one of 'SiLU', 'ReLU'" + ) + + +def build_norm(norm: str, *args, **kwargs) -> nn.Module: + if norm == "SyncBN": + return nn.SyncBatchNorm(*args, **kwargs) + elif norm == "BN": + return nn.BatchNorm2d(*args, **kwargs) + + raise NotImplementedError( + f"Unknown `CSPNeXT` norm_layer: {norm}. Must be one of 'SyncBN', 'BN'" + ) + + +class SPPBottleneck(nn.Module): + """Spatial pyramid pooling layer used in YOLOv3-SPP and (among others) CSPNeXt + + Args: + in_channels: input channels to the bottleneck + out_channels: output channels of the bottleneck + kernel_sizes: kernel sizes for the pooling layers + norm_layer: norm layer for the bottleneck + activation_fn: activation function for the bottleneck + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_sizes: tuple[int, ...] = (5, 9, 13), + norm_layer: str | None = "SyncBN", + activation_fn: str | None = "SiLU", + ): + super().__init__() + mid_channels = in_channels // 2 + self.conv1 = CSPConvModule( + in_channels, + mid_channels, + kernel_size=1, + stride=1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + + self.poolings = nn.ModuleList( + [ + nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) + for ks in kernel_sizes + ] + ) + conv2_channels = mid_channels * (len(kernel_sizes) + 1) + self.conv2 = CSPConvModule( + conv2_channels, + out_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + + def forward(self, x): + x = self.conv1(x) + with torch.amp.autocast("cuda", enabled=False): + x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1) + x = self.conv2(x) + return x + + +class ChannelAttention(nn.Module): + """Channel attention Module. + + Args: + channels: Number of input/output channels of the layer. + """ + + def __init__(self, channels: int) -> None: + super().__init__() + self.global_avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) + self.act = nn.Hardsigmoid(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + out = self.global_avgpool(x) + out = self.fc(out) + out = self.act(out) + return x * out + + +class CSPConvModule(nn.Module): + """Configurable convolution module used for CSPNeXT. + + Applies sequentially + - a convolution + - (optional) a norm layer + - (optional) an activation function + + Args: + in_channels: Input channels of the convolution. + out_channels: Output channels of the convolution. + kernel_size: Convolution kernel size. + stride: Convolution stride. + padding: Convolution padding. + dilation: Convolution dilation. + groups: Number of blocked connections from input to output channels. + norm_layer: Norm layer to apply, if any. + activation_fn: Activation function to apply, if any. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int | tuple[int, int] = 1, + padding: int | tuple[int, int] = 0, + dilation: int | tuple[int, int] = 1, + groups: int = 1, + norm_layer: str | None = None, + activation_fn: str | None = "ReLU", + ): + super().__init__() + + self.with_activation = activation_fn is not None + self.with_bias = norm_layer is None + self.with_norm = norm_layer is not None + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=self.with_bias, + ) + self.activate = None + self.norm = None + + if self.with_norm: + self.norm = build_norm(norm_layer, out_channels) + + if self.with_activation: + # Careful when adding activation functions: some should not be in-place + self.activate = build_activation(activation_fn, inplace=True) + + self._init_weights() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + if self.with_norm: + x = self.norm(x) + if self.with_activation: + x = self.activate(x) + return x + + def _init_weights(self) -> None: + """Same init as in convolutions""" + nn.init.kaiming_normal_(self.conv.weight, a=0, nonlinearity="relu") + if self.with_bias: + nn.init.constant_(self.conv.bias, 0) + + if self.with_norm: + nn.init.constant_(self.norm.weight, 1) + nn.init.constant_(self.norm.bias, 0) + + +class DepthwiseSeparableConv(nn.Module): + """Depth-wise separable convolution module used for CSPNeXT. + + Applies sequentially + - a depth-wise conv + - a point-wise conv + + Args: + in_channels: Input channels of the convolution. + out_channels: Output channels of the convolution. + kernel_size: Convolution kernel size. + stride: Convolution stride. + padding: Convolution padding. + dilation: Convolution dilation. + norm_layer: Norm layer to apply, if any. + activation_fn: Activation function to apply, if any. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int | tuple[int, int] = 1, + padding: int | tuple[int, int] = 0, + dilation: int | tuple[int, int] = 1, + norm_layer: str | None = None, + activation_fn: str | None = "ReLU", + ): + super().__init__() + + # depthwise convolution + self.depthwise_conv = CSPConvModule( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + + self.pointwise_conv = CSPConvModule( + in_channels, + out_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.depthwise_conv(x) + x = self.pointwise_conv(x) + return x + + +class CSPNeXtBlock(nn.Module): + """Basic bottleneck block used in CSPNeXt. + + Args: + in_channels: input channels for the block + out_channels: output channels for the block + expansion: expansion factor for the hidden channels + add_identity: add a skip-connection to the block + kernel_size: kernel size for the DepthwiseSeparableConv + norm_layer: Norm layer to apply, if any. + activation_fn: Activation function to apply, if any. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: float = 0.5, + add_identity: bool = True, + kernel_size: int = 5, + norm_layer: str | None = None, + activation_fn: str | None = "ReLU", + ) -> None: + super().__init__() + hidden_channels = int(out_channels * expansion) + self.conv1 = CSPConvModule( + in_channels, + hidden_channels, + 3, + stride=1, + padding=1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + self.conv2 = DepthwiseSeparableConv( + hidden_channels, + out_channels, + kernel_size, + stride=1, + padding=kernel_size // 2, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + self.add_identity = add_identity and in_channels == out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + identity = x + out = self.conv1(x) + out = self.conv2(out) + + if self.add_identity: + return out + identity + else: + return out + + +class CSPLayer(nn.Module): + """Cross Stage Partial Layer. + + Args: + in_channels: input channels for the layer + out_channels: output channels for the block + expand_ratio: expansion factor for the mid-channels + num_blocks: the number of blocks to use + add_identity: add a skip-connection to the blocks + channel_attention: whether to apply channel attention + norm_layer: Norm layer to apply, if any. + activation_fn: Activation function to apply, if any. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expand_ratio: float = 0.5, + num_blocks: int = 1, + add_identity: bool = True, + channel_attention: bool = False, + norm_layer: str | None = None, + activation_fn: str | None = "ReLU", + ) -> None: + super().__init__() + mid_channels = int(out_channels * expand_ratio) + self.channel_attention = channel_attention + self.main_conv = CSPConvModule( + in_channels, + mid_channels, + 1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + self.short_conv = CSPConvModule( + in_channels, + mid_channels, + 1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + self.final_conv = CSPConvModule( + 2 * mid_channels, + out_channels, + 1, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + + self.blocks = nn.Sequential( + *[ + CSPNeXtBlock( + mid_channels, + mid_channels, + 1.0, + add_identity, + norm_layer=norm_layer, + activation_fn=activation_fn, + ) + for _ in range(num_blocks) + ] + ) + if channel_attention: + self.attention = ChannelAttention(2 * mid_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function.""" + x_short = self.short_conv(x) + + x_main = self.main_conv(x) + x_main = self.blocks(x_main) + + x_final = torch.cat((x_main, x_short), dim=1) + + if self.channel_attention: + x_final = self.attention(x_final) + return self.final_conv(x_final) diff --git a/dlclive/pose_estimation_pytorch/models/modules/gated_attention_unit.py b/dlclive/pose_estimation_pytorch/models/modules/gated_attention_unit.py new file mode 100644 index 0000000..f26aa20 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/modules/gated_attention_unit.py @@ -0,0 +1,237 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Gated Attention Unit + +Based on the building blocks used for the ``mmdetection`` CSPNeXt implementation. For +more information, see . +""" +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import timm.layers as timm_layers + +from dlclive.pose_estimation_pytorch.models.modules.norm import ScaleNorm + + +def rope(x, dim): + """Applies Rotary Position Embedding to input tensor.""" + shape = x.shape + if isinstance(dim, int): + dim = [dim] + + spatial_shape = [shape[i] for i in dim] + total_len = 1 + for i in spatial_shape: + total_len *= i + + position = torch.reshape( + torch.arange(total_len, dtype=torch.int, device=x.device), spatial_shape + ) + + for i in range(dim[-1] + 1, len(shape) - 1, 1): + position = torch.unsqueeze(position, dim=-1) + + half_size = shape[-1] // 2 + freq_seq = -torch.arange(half_size, dtype=torch.int, device=x.device) / float( + half_size + ) + inv_freq = 10000**-freq_seq + + sinusoid = position[..., None] * inv_freq[None, None, :] + + sin = torch.sin(sinusoid) + cos = torch.cos(sinusoid) + x1, x2 = torch.chunk(x, 2, dim=-1) + + return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) + + +class Scale(nn.Module): + """Scale vector by element multiplications. + + Args: + dim: The dimension of the scale vector. + init_value: The initial value of the scale vector. + trainable: Whether the scale vector is trainable. + """ + + def __init__(self, dim, init_value=1.0, trainable=True): + super().__init__() + self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable) + + def forward(self, x): + return x * self.scale + + +class GatedAttentionUnit(nn.Module): + """Gated Attention Unit (GAU) in RTMBlock""" + + def __init__( + self, + num_token, + in_token_dims, + out_token_dims, + expansion_factor=2, + s=128, + eps=1e-5, + dropout_rate=0.0, + drop_path=0.0, + attn_type="self-attn", + act_fn="SiLU", + bias=False, + use_rel_bias=True, + pos_enc=False, + ): + super(GatedAttentionUnit, self).__init__() + self.s = s + self.num_token = num_token + self.use_rel_bias = use_rel_bias + self.attn_type = attn_type + self.pos_enc = pos_enc + + if drop_path > 0.0: + self.drop_path = timm_layers.DropPath(drop_path) + else: + self.drop_path = nn.Identity() + + self.e = int(in_token_dims * expansion_factor) + if use_rel_bias: + if attn_type == "self-attn": + self.w = nn.Parameter( + torch.rand([2 * num_token - 1], dtype=torch.float) + ) + else: + self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float)) + self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float)) + self.o = nn.Linear(self.e, out_token_dims, bias=bias) + + if attn_type == "self-attn": + self.uv = nn.Linear(in_token_dims, 2 * self.e + self.s, bias=bias) + self.gamma = nn.Parameter(torch.rand((2, self.s))) + self.beta = nn.Parameter(torch.rand((2, self.s))) + else: + self.uv = nn.Linear(in_token_dims, self.e + self.s, bias=bias) + self.k_fc = nn.Linear(in_token_dims, self.s, bias=bias) + self.v_fc = nn.Linear(in_token_dims, self.e, bias=bias) + nn.init.xavier_uniform_(self.k_fc.weight) + nn.init.xavier_uniform_(self.v_fc.weight) + + self.ln = ScaleNorm(in_token_dims, eps=eps) + + nn.init.xavier_uniform_(self.uv.weight) + + if act_fn == "SiLU" or act_fn == nn.SiLU: + self.act_fn = nn.SiLU(True) + elif act_fn == "ReLU" or act_fn == nn.ReLU: + self.act_fn = nn.ReLU(True) + else: + raise NotImplementedError + + if in_token_dims == out_token_dims: + self.shortcut = True + self.res_scale = Scale(in_token_dims) + else: + self.shortcut = False + + self.sqrt_s = math.sqrt(s) + + self.dropout_rate = dropout_rate + + if dropout_rate > 0.0: + self.dropout = nn.Dropout(dropout_rate) + + def rel_pos_bias(self, seq_len, k_len=None): + """Add relative position bias.""" + + if self.attn_type == "self-attn": + t = F.pad(self.w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len) + t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2) + r = (2 * seq_len - 1) // 2 + t = t[..., r:-r] + else: + a = rope(self.a.repeat(seq_len, 1), dim=0) + b = rope(self.b.repeat(k_len, 1), dim=0) + t = torch.bmm(a, b.permute(0, 2, 1)) + return t + + def _forward(self, inputs): + """GAU Forward function.""" + + if self.attn_type == "self-attn": + x = inputs + else: + x, k, v = inputs + + x = self.ln(x) + + # [B, K, in_token_dims] -> [B, K, e + e + s] + uv = self.uv(x) + uv = self.act_fn(uv) + + if self.attn_type == "self-attn": + # [B, K, e + e + s] -> [B, K, e], [B, K, e], [B, K, s] + u, v, base = torch.split(uv, [self.e, self.e, self.s], dim=2) + # [B, K, 1, s] * [1, 1, 2, s] + [2, s] -> [B, K, 2, s] + base = base.unsqueeze(2) * self.gamma[None, None, :] + self.beta + + if self.pos_enc: + base = rope(base, dim=1) + # [B, K, 2, s] -> [B, K, s], [B, K, s] + q, k = torch.unbind(base, dim=2) + + else: + # [B, K, e + s] -> [B, K, e], [B, K, s] + u, q = torch.split(uv, [self.e, self.s], dim=2) + + k = self.k_fc(k) # -> [B, K, s] + v = self.v_fc(v) # -> [B, K, e] + + if self.pos_enc: + q = rope(q, 1) + k = rope(k, 1) + + # [B, K, s].permute() -> [B, s, K] + # [B, K, s] x [B, s, K] -> [B, K, K] + qk = torch.bmm(q, k.permute(0, 2, 1)) + + if self.use_rel_bias: + if self.attn_type == "self-attn": + bias = self.rel_pos_bias(q.size(1)) + else: + bias = self.rel_pos_bias(q.size(1), k.size(1)) + qk += bias[:, : q.size(1), : k.size(1)] + # [B, K, K] + kernel = torch.square(F.relu(qk / self.sqrt_s)) + + if self.dropout_rate > 0.0: + kernel = self.dropout(kernel) + # [B, K, K] x [B, K, e] -> [B, K, e] + x = u * torch.bmm(kernel, v) + + # [B, K, e] -> [B, K, out_token_dims] + x = self.o(x) + + return x + + def forward(self, x): + if self.shortcut: + if self.attn_type == "cross-attn": + res_shortcut = x[0] + else: + res_shortcut = x + main_branch = self.drop_path(self._forward(x)) + return self.res_scale(res_shortcut) + main_branch + else: + return self.drop_path(self._forward(x)) diff --git a/dlclive/pose_estimation_pytorch/models/modules/norm.py b/dlclive/pose_estimation_pytorch/models/modules/norm.py new file mode 100644 index 0000000..9bf839b --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/modules/norm.py @@ -0,0 +1,41 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Normalization layers""" +from __future__ import annotations + +import torch +import torch.nn as nn + + +class ScaleNorm(nn.Module): + """Implementation of ScaleNorm + + ScaleNorm was introduced in "Transformers without Tears: Improving the Normalization + of Self-Attention". + + Code based on the `mmpose` implementation. See https://github.com/open-mmlab/mmpose + for more details. + + Args: + dim: The dimension of the scale vector. + eps: The minimum value in clamp. + """ + + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.linalg.norm(x, dim=-1, keepdim=True) + norm = norm * self.scale + return x / norm.clamp(min=self.eps) * self.g diff --git a/dlclive/pose_estimation_pytorch/models/necks/__init__.py b/dlclive/pose_estimation_pytorch/models/necks/__init__.py new file mode 100644 index 0000000..a0bcb2e --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/necks/__init__.py @@ -0,0 +1,12 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from dlclive.pose_estimation_pytorch.models.necks.base import NECKS, BaseNeck +from dlclive.pose_estimation_pytorch.models.necks.transformer import Transformer diff --git a/dlclive/pose_estimation_pytorch/models/necks/base.py b/dlclive/pose_estimation_pytorch/models/necks/base.py new file mode 100644 index 0000000..0b3cb6e --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/necks/base.py @@ -0,0 +1,48 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from abc import ABC, abstractmethod + +import torch + +from dlclive.pose_estimation_pytorch.models.registry import Registry, build_from_cfg + +NECKS = Registry("necks", build_func=build_from_cfg) + + +class BaseNeck(ABC, torch.nn.Module): + """Base Neck class for pose estimation""" + + def __init__(self): + super().__init__() + + @abstractmethod + def forward(self, x: torch.Tensor): + """Abstract method for the forward pass through the Neck. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + pass + + def _init_weights(self, pretrained: str): + """Initialize the Neck with pretrained weights. + + Args: + pretrained: Path to the pretrained weights. + + Returns: + None + """ + if pretrained: + self.model.load_state_dict(torch.load(pretrained)) diff --git a/dlclive/pose_estimation_pytorch/models/necks/layers.py b/dlclive/pose_estimation_pytorch/models/necks/layers.py new file mode 100644 index 0000000..a25ad50 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/necks/layers.py @@ -0,0 +1,287 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class Residual(torch.nn.Module): + """Residual block module. + + This module implements a residual block for the transformer layers. + + Attributes: + fn: The function to apply in the residual block. + """ + + def __init__(self, fn: torch.nn.Module): + """Initialize the Residual block. + + Args: + fn: The function to apply in the residual block. + """ + super().__init__() + self.fn = fn + + def forward(self, x: torch.Tensor, **kwargs): + """Forward pass through the Residual block. + + Args: + x: Input tensor. + **kwargs: Additional keyword arguments for the function. + + Returns: + Output tensor. + """ + return self.fn(x, **kwargs) + x + + +class PreNorm(torch.nn.Module): + """PreNorm block module. + + This module implements pre-normalization for the transformer layers. + + Attributes: + dim: Dimension of the input tensor. + fn: The function to apply after normalization. + fusion_factor: Fusion factor for layer normalization. + Defaults to 1. + """ + + def __init__(self, dim: int, fn: torch.nn.Module, fusion_factor: int = 1): + """Initialize the PreNorm block. + + Args: + dim: Dimension of the input tensor. + fn: The function to apply after normalization. + fusion_factor: Fusion factor for layer normalization. + Defaults to 1. + """ + super().__init__() + self.norm = torch.nn.LayerNorm(dim * fusion_factor) + self.fn = fn + + def forward(self, x, **kwargs): + """Forward pass through the PreNorm block. + + Args: + x: Input tensor. + **kwargs: Additional keyword arguments for the function. + + Returns: + Output tensor. + """ + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(torch.nn.Module): + """FeedForward block module. + + This module implements the feedforward layer in the transformer layers. + + Attributes: + dim: Dimension of the input tensor. + hidden_dim: Dimension of the hidden layer. + dropout: Dropout rate. Defaults to 0.0. + """ + + def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0): + """Initialize the FeedForward block. + + Args: + dim: Dimension of the input tensor. + hidden_dim: Dimension of the hidden layer. + dropout: Dropout rate. Defaults to 0.0. + """ + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Linear(dim, hidden_dim), + torch.nn.GELU(), + torch.nn.Dropout(dropout), + torch.nn.Linear(hidden_dim, dim), + torch.nn.Dropout(dropout), + ) + + def forward(self, x: torch.Tensor): + """Forward pass through the FeedForward block. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + return self.net(x) + + +class Attention(torch.nn.Module): + """Attention block module. + + This module implements the attention mechanism in the transformer layers. + + Attributes: + dim: Dimension of the input tensor. + heads: Number of attention heads. Defaults to 8. + dropout: Dropout rate. Defaults to 0.0. + num_keypoints: Number of keypoints. Defaults to None. + scale_with_head: Scale attention with the number of heads. + Defaults to False. + """ + + def __init__( + self, + dim: int, + heads: int = 8, + dropout: float = 0.0, + num_keypoints: int = None, + scale_with_head: bool = False, + ): + """Initialize the Attention block. + + Args: + dim: Dimension of the input tensor. + heads: Number of attention heads. Defaults to 8. + dropout: Dropout rate. Defaults to 0.0. + num_keypoints: Number of keypoints. Defaults to None. + scale_with_head: Scale attention with the number of heads. + Defaults to False. + """ + super().__init__() + self.heads = heads + self.scale = (dim // heads) ** -0.5 if scale_with_head else dim**-0.5 + + self.to_qkv = torch.nn.Linear(dim, dim * 3, bias=False) + self.to_out = torch.nn.Sequential( + torch.nn.Linear(dim, dim), torch.nn.Dropout(dropout) + ) + self.num_keypoints = num_keypoints + + def forward(self, x: torch.Tensor, mask: torch.Tensor = None): + """Forward pass through the Attention block. + + Args: + x: Input tensor. + mask: Attention mask. Defaults to None. + + Returns: + Output tensor. + """ + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv) + + dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale + mask_value = -torch.finfo(dots.dtype).max + + if mask is not None: + mask = F.pad(mask.flatten(1), (1, 0), value=True) + assert mask.shape[-1] == dots.shape[-1], "mask has incorrect dimensions" + mask = mask[:, None, :] * mask[:, :, None] + dots.masked_fill_(~mask, mask_value) + del mask + + attn = dots.softmax(dim=-1) + + out = torch.einsum("bhij,bhjd->bhid", attn, v) + + out = rearrange(out, "b h n d -> b n (h d)") + out = self.to_out(out) + return out + + +class TransformerLayer(torch.nn.Module): + """TransformerLayer block module. + + This module implements the Transformer layer in the transformer model. + + Attributes: + dim: Dimension of the input tensor. + depth: Depth of the transformer layer. + heads: Number of attention heads. + mlp_dim: Dimension of the MLP layer. + dropout: Dropout rate. + num_keypoints: Number of keypoints. Defaults to None. + all_attn: Apply attention to all keypoints. + Defaults to False. + scale_with_head: Scale attention with the number of heads. + Defaults to False. + """ + + def __init__( + self, + dim: int, + depth: int, + heads: int, + mlp_dim: int, + dropout: float, + num_keypoints: int = None, + all_attn: bool = False, + scale_with_head: bool = False, + ): + """Initialize the TransformerLayer block. + + Args: + dim: Dimension of the input tensor. + depth: Depth of the transformer layer. + heads: Number of attention heads. + mlp_dim: Dimension of the MLP layer. + dropout: Dropout rate. + num_keypoints: Number of keypoints. Defaults to None. + all_attn: Apply attention to all keypoints. Defaults to False. + scale_with_head: Scale attention with the number of heads. Defaults to False. + """ + super().__init__() + self.layers = torch.nn.ModuleList([]) + self.all_attn = all_attn + self.num_keypoints = num_keypoints + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + Residual( + PreNorm( + dim, + Attention( + dim, + heads=heads, + dropout=dropout, + num_keypoints=num_keypoints, + scale_with_head=scale_with_head, + ), + ) + ), + Residual( + PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)) + ), + ] + ) + ) + + def forward( + self, x: torch.Tensor, mask: torch.Tensor = None, pos: torch.Tensor = None + ): + """Forward pass through the TransformerLayer block. + + Args: + x: Input tensor. + mask: Attention mask. Defaults to None. + pos: Positional encoding. Defaults to None. + + Returns: + Output tensor. + """ + for idx, (attn, ff) in enumerate(self.layers): + if idx > 0 and self.all_attn: + x[:, self.num_keypoints :] += pos + x = attn(x, mask=mask) + x = ff(x) + return x diff --git a/dlclive/pose_estimation_pytorch/models/necks/transformer.py b/dlclive/pose_estimation_pytorch/models/necks/transformer.py new file mode 100644 index 0000000..eae5118 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/necks/transformer.py @@ -0,0 +1,276 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from typing import Tuple + +import torch +from einops import rearrange, repeat +from timm.layers import trunc_normal_ + +from dlclive.pose_estimation_pytorch.models.necks.base import NECKS, BaseNeck +from dlclive.pose_estimation_pytorch.models.necks.layers import TransformerLayer +from dlclive.pose_estimation_pytorch.models.necks.utils import ( + make_sine_position_embedding, +) + +MIN_NUM_PATCHES = 16 +BN_MOMENTUM = 0.1 + + +@NECKS.register_module +class Transformer(BaseNeck): + """Transformer Neck for pose estimation. + title={TokenPose: Learning Keypoint Tokens for Human Pose Estimation}, + author={Yanjie Li and Shoukui Zhang and Zhicheng Wang and Sen Yang and Wankou Yang and Shu-Tao Xia and Erjin Zhou}, + booktitle={IEEE/CVF International Conference on Computer Vision (ICCV)}, + year={2021} + + Args: + feature_size: Size of the input feature map (height, width). + patch_size: Size of each patch used in the transformer. + num_keypoints: Number of keypoints in the pose estimation task. + dim: Dimension of the transformer. + depth: Number of transformer layers. + heads: Number of self-attention heads in the transformer. + mlp_dim: Dimension of the MLP used in the transformer. + Defaults to 3. + apply_init: Whether to apply weight initialization. + Defaults to False. + heatmap_size: Size of the heatmap. Defaults to [64, 64]. + channels: Number of channels in each patch. Defaults to 32. + dropout: Dropout rate for embeddings. Defaults to 0.0. + emb_dropout: Dropout rate for transformer layers. + Defaults to 0.0. + pos_embedding_type: Type of positional embedding. + Either 'sine-full', 'sine', or 'learnable'. + Defaults to "sine-full". + + Examples: + # Creating a Transformer neck with sine positional embedding + transformer = Transformer( + feature_size=(128, 128), + patch_size=(16, 16), + num_keypoints=17, + dim=256, + depth=6, + heads=8, + pos_embedding_type="sine" + ) + + # Creating a Transformer neck with learnable positional embedding + transformer = Transformer( + feature_size=(256, 256), + patch_size=(32, 32), + num_keypoints=17, + dim=512, + depth=12, + heads=16, + pos_embedding_type="learnable" + ) + """ + + def __init__( + self, + *, + feature_size: Tuple[int, int], + patch_size: Tuple[int, int], + num_keypoints: int, + dim: int, + depth: int, + heads: int, + mlp_dim: int = 3, + apply_init: bool = False, + heatmap_size: Tuple[int, int] = (64, 64), + channels: int = 32, + dropout: float = 0.0, + emb_dropout: float = 0.0, + pos_embedding_type: str = "sine-full" + ): + super().__init__() + + num_patches = (feature_size[0] // (patch_size[0])) * ( + feature_size[1] // (patch_size[1]) + ) + patch_dim = channels * patch_size[0] * patch_size[1] + + self.inplanes = 64 + self.patch_size = patch_size + self.heatmap_size = heatmap_size + self.num_keypoints = num_keypoints + self.num_patches = num_patches + self.pos_embedding_type = pos_embedding_type + self.all_attn = self.pos_embedding_type == "sine-full" + + self.keypoint_token = torch.nn.Parameter( + torch.zeros(1, self.num_keypoints, dim) + ) + h, w = ( + feature_size[0] // (self.patch_size[0]), + feature_size[1] // (self.patch_size[1]), + ) + + self._make_position_embedding(w, h, dim, pos_embedding_type) + + self.patch_to_embedding = torch.nn.Linear(patch_dim, dim) + self.dropout = torch.nn.Dropout(emb_dropout) + + self.transformer1 = TransformerLayer( + dim, + depth, + heads, + mlp_dim, + dropout, + num_keypoints=num_keypoints, + scale_with_head=True, + ) + self.transformer2 = TransformerLayer( + dim, + depth, + heads, + mlp_dim, + dropout, + num_keypoints=num_keypoints, + all_attn=self.all_attn, + scale_with_head=True, + ) + self.transformer3 = TransformerLayer( + dim, + depth, + heads, + mlp_dim, + dropout, + num_keypoints=num_keypoints, + all_attn=self.all_attn, + scale_with_head=True, + ) + + self.to_keypoint_token = torch.nn.Identity() + + if apply_init: + self.apply(self._init_weights) + + def _make_position_embedding( + self, w: int, h: int, d_model: int, pe_type="learnable" + ): + """Create position embeddings for the transformer. + + Args: + w: Width of the input feature map. + h: Height of the input feature map. + d_model: Dimension of the transformer encoder. + pe_type: Type of position embeddings. + Either "learnable" or "sine". Defaults to "learnable". + """ + with torch.no_grad(): + self.pe_h = h + self.pe_w = w + length = h * w + if pe_type != "learnable": + self.pos_embedding = torch.nn.Parameter( + make_sine_position_embedding(h, w, d_model), requires_grad=False + ) + else: + self.pos_embedding = torch.nn.Parameter( + torch.zeros(1, self.num_patches + self.num_keypoints, d_model) + ) + + def _make_layer( + self, block: torch.nn.Module, planes: int, blocks: int, stride: int = 1 + ) -> torch.nn.Sequential: + """Create a layer of the transformer encoder. + + Args: + block: The basic building block of the layer. + planes: Number of planes in the layer. + blocks: Number of blocks in the layer. + stride: Stride value. Defaults to 1. + + Returns: + The layer of the transformer encoder. + """ + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = torch.nn.Sequential( + torch.nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + torch.nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return torch.nn.Sequential(*layers) + + def _init_weights(self, m: torch.nn.Module): + """Initialize the weights of the model. + + Args: + m: A module of the model. + """ + print("Initialization...") + if isinstance(m, torch.nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, torch.nn.Linear) and m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + elif isinstance(m, torch.nn.LayerNorm): + torch.nn.init.constant_(m.bias, 0) + torch.nn.init.constant_(m.weight, 1.0) + + def forward(self, feature: torch.Tensor, mask=None) -> torch.Tensor: + """Forward pass through the Transformer neck. + + Args: + feature: Input feature map. + mask: Mask to apply to the transformer. + Defaults to None. + + Returns: + Output tensor from the transformer neck. + + Examples: + # Assuming feature is a torch.Tensor of shape (batch_size, channels, height, width) + output = transformer(feature) + """ + p = self.patch_size + + x = rearrange( + feature, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p[0], p2=p[1] + ) + x = self.patch_to_embedding(x) + + b, n, _ = x.shape + + keypoint_tokens = repeat(self.keypoint_token, "() n d -> b n d", b=b) + if self.pos_embedding_type in ["sine", "sine-full"]: + x += self.pos_embedding[:, :n] + x = torch.cat((keypoint_tokens, x), dim=1) + else: + x = torch.cat((keypoint_tokens, x), dim=1) + x += self.pos_embedding[:, : (n + self.num_keypoints)] + x = self.dropout(x) + + x1 = self.transformer1(x, mask, self.pos_embedding) + x2 = self.transformer2(x1, mask, self.pos_embedding) + x3 = self.transformer3(x2, mask, self.pos_embedding) + + x1_out = self.to_keypoint_token(x1[:, 0 : self.num_keypoints]) + x2_out = self.to_keypoint_token(x2[:, 0 : self.num_keypoints]) + x3_out = self.to_keypoint_token(x3[:, 0 : self.num_keypoints]) + + x = torch.cat((x1_out, x2_out, x3_out), dim=2) + return x diff --git a/dlclive/pose_estimation_pytorch/models/necks/utils.py b/dlclive/pose_estimation_pytorch/models/necks/utils.py new file mode 100644 index 0000000..028078b --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/necks/utils.py @@ -0,0 +1,60 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# + +import math + +import torch + + +def make_sine_position_embedding( + h: int, w: int, d_model: int, temperature: int = 10000, scale: float = 2 * math.pi +) -> torch.Tensor: + """Generate sine position embeddings for a given height, width, and model dimension. + + Args: + h: Height of the embedding. + w: Width of the embedding. + d_model: Dimension of the model. + temperature: Temperature parameter for position embedding calculation. + Defaults to 10000. + scale: Scaling factor for position embedding. Defaults to 2 * math.pi. + + Returns: + Sine position embeddings with shape (batch_size, d_model, h * w). + + Example: + >>> h, w, d_model = 10, 20, 512 + >>> pos_emb = make_sine_position_embedding(h, w, d_model) + >>> print(pos_emb.shape) # Output: torch.Size([1, 512, 200]) + """ + area = torch.ones(1, h, w) + y_embed = area.cumsum(1, dtype=torch.float32) + x_embed = area.cumsum(2, dtype=torch.float32) + one_direction_feats = d_model // 2 + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale + + dim_t = torch.arange(one_direction_feats, dtype=torch.float32) + dim_t = temperature ** (2 * (dim_t // 2) / one_direction_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = pos.flatten(2).permute(0, 2, 1) + + return pos diff --git a/dlclive/pose_estimation_pytorch/models/predictors/__init__.py b/dlclive/pose_estimation_pytorch/models/predictors/__init__.py new file mode 100644 index 0000000..0662ffa --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/predictors/__init__.py @@ -0,0 +1,24 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from dlclive.pose_estimation_pytorch.models.predictors.base import ( + PREDICTORS, + BasePredictor, +) +from dlclive.pose_estimation_pytorch.models.predictors.dekr_predictor import ( + DEKRPredictor, +) +from dlclive.pose_estimation_pytorch.models.predictors.sim_cc import SimCCPredictor +from dlclive.pose_estimation_pytorch.models.predictors.single_predictor import ( + HeatmapPredictor, +) +from dlclive.pose_estimation_pytorch.models.predictors.paf_predictor import ( + PartAffinityFieldPredictor, +) diff --git a/dlclive/pose_estimation_pytorch/models/predictors/base.py b/dlclive/pose_estimation_pytorch/models/predictors/base.py new file mode 100644 index 0000000..f8b8b98 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/predictors/base.py @@ -0,0 +1,64 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import torch +from torch import nn + +from dlclive.pose_estimation_pytorch.models.registry import Registry, build_from_cfg + +PREDICTORS = Registry("predictors", build_func=build_from_cfg) + + +class BasePredictor(ABC, nn.Module): + """The base Predictor class. + + This class is an abstract base class (ABC) for defining predictors used in the + DeepLabCut Toolbox. All predictor classes should inherit from this base class and + implement the forward method. Regresses keypoint coordinates from a models output + maps + + Attributes: + num_animals: Number of animals in the project. Should be set in subclasses. + + Example: + # Create a subclass that inherits from BasePredictor + class MyPredictor(BasePredictor): + def __init__(self, num_animals): + super().__init__() + self.num_animals = num_animals + + def forward(self, outputs): + # Implement the forward pass of your custom predictor here. + pass + """ + + def __init__(self): + super().__init__() + self.num_animals = None + + @abstractmethod + def forward( + self, stride: float, outputs: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Abstract method for the forward pass of the Predictor. + + Args: + stride: the stride of the model + outputs: outputs of the model heads + + Returns: + A dictionary containing a "poses" key with the output tensor as value, and + optionally a "unique_bodyparts" with the unique bodyparts tensor as value. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + pass diff --git a/dlclive/pose_estimation_pytorch/models/predictors/dekr_predictor.py b/dlclive/pose_estimation_pytorch/models/predictors/dekr_predictor.py new file mode 100644 index 0000000..7261f29 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/predictors/dekr_predictor.py @@ -0,0 +1,408 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# + +from __future__ import annotations + +import torch +import torch.nn.functional as F + +from dlclive.pose_estimation_pytorch.models.predictors.base import ( + PREDICTORS, + BasePredictor, +) + + +@PREDICTORS.register_module +class DEKRPredictor(BasePredictor): + """DEKR Predictor class for multi-animal pose estimation. + + This class regresses keypoints and assembles them (if multianimal project) + from the output of DEKR (Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression). + Based on: + Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression + Zigang Geng, Ke Sun, Bin Xiao, Zhaoxiang Zhang, Jingdong Wang + CVPR + 2021 + Code based on: + https://github.com/HRNet/DEKR + + Args: + num_animals (int): Number of animals in the project. + detection_threshold (float, optional): Threshold for detection. Defaults to 0.01. + apply_sigmoid (bool, optional): Apply sigmoid to heatmaps. Defaults to True. + use_heatmap (bool, optional): Use heatmap to refine keypoint predictions. Defaults to True. + keypoint_score_type (str): Type of score to compute for keypoints. "heatmap" applies the heatmap + score to each keypoint. "center" applies the score of the center of each individual to + all of its keypoints. "combined" multiplies the score of the heatmap and individual + center for each keypoint. + + Attributes: + num_animals (int): Number of animals in the project. + detection_threshold (float): Threshold for detection. + apply_sigmoid (bool): Apply sigmoid to heatmaps. + use_heatmap (bool): Use heatmap. + keypoint_score_type (str): Type of score to compute for keypoints. "heatmap" applies the heatmap + score to each keypoint. "center" applies the score of the center of each individual to + all of its keypoints. "combined" multiplies the score of the heatmap and individual + center for each keypoint. + + Example: + # Create a DEKRPredictor instance with 2 animals. + predictor = DEKRPredictor(num_animals=2) + + # Make a forward pass with outputs and scale factors. + outputs = (heatmaps, offsets) # tuple of heatmaps and offsets + scale_factors = (0.5, 0.5) # tuple of scale factors for the poses + poses_with_scores = predictor.forward(outputs, scale_factors) + """ + + default_init = {"apply_sigmoid": True, "detection_threshold": 0.01} + + def __init__( + self, + num_animals: int, + detection_threshold: float = 0.01, + apply_sigmoid: bool = True, + clip_scores: bool = False, + use_heatmap: bool = True, + keypoint_score_type: str = "combined", + max_absorb_distance: int = 75, + ): + """ + Args: + num_animals: Number of animals in the project. + detection_threshold: Threshold for detection + apply_sigmoid: Apply sigmoid to heatmaps + clip_scores: If a sigmoid is not applied, this can be used to clip scores + for predicted keypoints to values in [0, 1]. + use_heatmap: Use heatmap to refine the keypoint predictions. + keypoint_score_type: Type of score to compute for keypoints. "heatmap" + applies the heatmap score to each keypoint. "center" applies the score + of the center of each individual to all of its keypoints. "combined" + multiplies the score of the heatmap and individual for each keypoint. + """ + super().__init__() + self.num_animals = num_animals + self.detection_threshold = detection_threshold + self.apply_sigmoid = apply_sigmoid + self.clip_scores = clip_scores + self.use_heatmap = use_heatmap + self.keypoint_score_type = keypoint_score_type + if self.keypoint_score_type not in ("heatmap", "center", "combined"): + raise ValueError(f"Unknown keypoint score type: {self.keypoint_score_type}") + + # TODO: Set as in HRNet/DEKR configs. Define as a constant. + self.max_absorb_distance = max_absorb_distance + + def forward( + self, stride: float, outputs: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Forward pass of DEKRPredictor. + + Args: + stride: the stride of the model + outputs: outputs of the model heads (heatmap, locref) + + Returns: + A dictionary containing a "poses" key with the output tensor as value, and + optionally a "unique_bodyparts" with the unique bodyparts tensor as value. + + Example: + # Assuming you have 'outputs' (heatmaps and offsets) and 'scale_factors' for poses + poses_with_scores = predictor.forward(outputs, scale_factors) + """ + heatmaps, offsets = outputs["heatmap"], outputs["offset"] + scale_factors = stride, stride + + if self.apply_sigmoid: + heatmaps = F.sigmoid(heatmaps) + + posemap = self.offset_to_pose(offsets) + + batch_size, num_joints_with_center, h, w = heatmaps.shape + num_joints = num_joints_with_center - 1 + + center_heatmaps = heatmaps[:, -1] + pose_ind, ctr_scores = self.get_top_values(center_heatmaps) + + posemap = posemap.permute(0, 2, 3, 1).view(batch_size, h * w, -1, 2) + poses = torch.zeros(batch_size, pose_ind.shape[1], num_joints, 2).to( + ctr_scores.device + ) + for i in range(batch_size): + pose = posemap[i, pose_ind[i]] + poses[i] = pose + + if self.use_heatmap: + poses = self._update_pose_with_heatmaps(poses, heatmaps[:, :-1]) + + if self.keypoint_score_type == "center": + score = ( + ctr_scores.unsqueeze(-1) + .expand(batch_size, -1, num_joints) + .unsqueeze(-1) + ) + elif self.keypoint_score_type == "heatmap": + score = self.get_heat_value(poses, heatmaps).unsqueeze(-1) + elif self.keypoint_score_type == "combined": + center_score = ( + ctr_scores.unsqueeze(-1) + .expand(batch_size, -1, num_joints) + .unsqueeze(-1) + ) + htmp_score = self.get_heat_value(poses, heatmaps).unsqueeze(-1) + score = center_score * htmp_score + else: + raise ValueError(f"Unknown keypoint score type: {self.keypoint_score_type}") + + poses[:, :, :, 0] = ( + poses[:, :, :, 0] * scale_factors[1] + 0.5 * scale_factors[1] + ) + poses[:, :, :, 1] = ( + poses[:, :, :, 1] * scale_factors[0] + 0.5 * scale_factors[0] + ) + + if self.clip_scores: + score = torch.clip(score, min=0, max=1) + + poses_w_scores = torch.cat([poses, score], dim=3) + # self.pose_nms(heatmaps, poses_w_scores) + return {"poses": poses_w_scores} + + def get_locations( + self, height: int, width: int, device: torch.device + ) -> torch.Tensor: + """Get locations for offsets. + + Args: + height: Height of the offsets. + width: Width of the offsets. + device: Device to use. + + Returns: + Offset locations. + + Example: + # Assuming you have 'height', 'width', and 'device' + locations = predictor.get_locations(height, width, device) + """ + shifts_x = torch.arange(0, width, step=1, dtype=torch.float32).to(device) + shifts_y = torch.arange(0, height, step=1, dtype=torch.float32).to(device) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij") + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + locations = torch.stack((shift_x, shift_y), dim=1) + return locations + + def get_reg_poses(self, offsets: torch.Tensor, num_joints: int) -> torch.Tensor: + """Get the regression poses from offsets. + + Args: + offsets: Offsets tensor. + num_joint: Number of joints. + + Returns: + Regression poses. + + Example: + # Assuming you have 'offsets' tensor and 'num_joints' + regression_poses = predictor.get_reg_poses(offsets, num_joints) + """ + batch_size, _, h, w = offsets.shape + offsets = offsets.permute(0, 2, 3, 1).reshape(batch_size, h * w, num_joints, 2) + locations = self.get_locations(h, w, offsets.device) + locations = locations[None, :, None, :].expand(batch_size, -1, num_joints, -1) + poses = locations - offsets + + return poses + + def offset_to_pose(self, offsets: torch.Tensor) -> torch.Tensor: + """Convert offsets to poses. + + Args: + offsets: Offsets tensor. + + Returns: + Poses from offsets. + + Example: + # Assuming you have 'offsets' tensor + poses = predictor.offset_to_pose(offsets) + """ + batch_size, num_offset, h, w = offsets.shape + num_joints = int(num_offset / 2) + reg_poses = self.get_reg_poses(offsets, num_joints) + + reg_poses = ( + reg_poses.contiguous() + .view(batch_size, h * w, 2 * num_joints) + .permute(0, 2, 1) + ) + reg_poses = reg_poses.contiguous().view(batch_size, -1, h, w).contiguous() + + return reg_poses + + def max_pool(self, heatmap: torch.Tensor) -> torch.Tensor: + """Apply max pooling to the heatmap. + + Args: + heatmap: Heatmap tensor. + + Returns: + Max pooled heatmap. + + Example: + # Assuming you have 'heatmap' tensor + max_pooled_heatmap = predictor.max_pool(heatmap) + """ + pool1 = torch.nn.MaxPool2d(3, 1, 1) + pool2 = torch.nn.MaxPool2d(5, 1, 2) + pool3 = torch.nn.MaxPool2d(7, 1, 3) + map_size = (heatmap.shape[1] + heatmap.shape[2]) / 2.0 + maxm = pool2( + heatmap + ) # Here I think pool 2 is a good match for default 17 pos_dist_tresh + + return maxm + + def get_top_values( + self, heatmap: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Get top values from the heatmap. + + Args: + heatmap: Heatmap tensor. + + Returns: + Position indices and scores. + + Example: + # Assuming you have 'heatmap' tensor + positions, scores = predictor.get_top_values(heatmap) + """ + maximum = self.max_pool(heatmap) + maximum = torch.eq(maximum, heatmap) + heatmap *= maximum + + batchsize, ny, nx = heatmap.shape + heatmap_flat = heatmap.reshape(batchsize, nx * ny) + + scores, pos_ind = torch.topk(heatmap_flat, self.num_animals, dim=1) + + return pos_ind, scores + + ########## WIP to take heatmap into account for scoring ########## + def _update_pose_with_heatmaps( + self, _poses: torch.Tensor, kpt_heatmaps: torch.Tensor + ): + """If a heatmap center is close enough from the regressed point, the final prediction is the center of this heatmap + + Args: + poses: poses tensor, shape (batch_size, num_animals, num_keypoints, 2) + kpt_heatmaps: heatmaps (does not contain the center heatmap), shape (batch_size, num_keypoints, h, w) + """ + poses = _poses.clone() + maxm = self.max_pool(kpt_heatmaps) + maxm = torch.eq(maxm, kpt_heatmaps).float() + kpt_heatmaps *= maxm + batch_size, num_keypoints, h, w = kpt_heatmaps.shape + kpt_heatmaps = kpt_heatmaps.view(batch_size, num_keypoints, -1) + val_k, ind = kpt_heatmaps.topk(self.num_animals, dim=2) + + x = ind % w + y = (ind / w).long() + heats_ind = torch.stack((x, y), dim=3) + + for b in range(batch_size): + for i in range(num_keypoints): + heat_ind = heats_ind[b, i].float() + pose_ind = poses[b, :, i] + pose_heat_diff = pose_ind[:, None, :] - heat_ind + pose_heat_diff.pow_(2) + pose_heat_diff = pose_heat_diff.sum(2) + pose_heat_diff.sqrt_() + keep_ind = torch.argmin(pose_heat_diff, dim=1) + + for p in range(keep_ind.shape[0]): + if pose_heat_diff[p, keep_ind[p]] < self.max_absorb_distance: + poses[b, p, i] = heat_ind[keep_ind[p]] + + return poses + + def get_heat_value( + self, pose_coords: torch.Tensor, heatmaps: torch.Tensor + ) -> torch.Tensor: + """Get heat values for pose coordinates and heatmaps. + + Args: + pose_coords: Pose coordinates tensor (batch_size, num_animals, num_joints, 2) + heatmaps: Heatmaps tensor (batch_size, 1+num_joints, h, w). + + Returns: + Heat values. + + Example: + # Assuming you have 'pose_coords' and 'heatmaps' tensors + heat_values = predictor.get_heat_value(pose_coords, heatmaps) + """ + h, w = heatmaps.shape[2:] + heatmaps_nocenter = heatmaps[:, :-1].flatten( + 2, 3 + ) # (batch_size, num_joints, h*w) + + # Predicted poses based on the offset can be outside of the image + x = torch.clamp(torch.floor(pose_coords[:, :, :, 0]), 0, w - 1).long() + y = torch.clamp(torch.floor(pose_coords[:, :, :, 1]), 0, h - 1).long() + keypoint_poses = (y * w + x).mT # (batch, num_joints, num_individuals) + heatscores = torch.gather(heatmaps_nocenter, 2, keypoint_poses) + return heatscores.mT # (batch, num_individuals, num_joints) + + def pose_nms(self, heatmaps: torch.Tensor, poses: torch.Tensor): + """Non-Maximum Suppression (NMS) for regressed poses. + + Args: + heatmaps: Heatmaps tensor. + poses: Pose proposals. + + Returns: + None + + Example: + # Assuming you have 'heatmaps' and 'poses' tensors + predictor.pose_nms(heatmaps, poses) + """ + pose_scores = poses[:, :, :, 2] + pose_coords = poses[:, :, :, :2] + + if pose_coords.shape[1] == 0: + return [], [] + + batch_size, num_people, num_joints, _ = pose_coords.shape + heatvals = self.get_heat_value(pose_coords, heatmaps) + heat_score = (torch.sum(heatvals, dim=1) / num_joints)[:, 0] + + # return heat_score + # pose_score = pose_score*heatvals + # poses = torch.cat([pose_coord.cpu(), pose_score.cpu()], dim=2) + + # keep_pose_inds = nms_core(cfg, pose_coord, heat_score) + # poses = poses[keep_pose_inds] + # heat_score = heat_score[keep_pose_inds] + + # if len(keep_pose_inds) > cfg.DATASET.MAX_NUM_PEOPLE: + # heat_score, topk_inds = torch.topk(heat_score, + # cfg.DATASET.MAX_NUM_PEOPLE) + # poses = poses[topk_inds] + + # poses = [poses.numpy()] + # scores = [i[:, 2].mean() for i in poses[0]] + + # return poses, scores diff --git a/dlclive/pose_estimation_pytorch/models/predictors/identity_predictor.py b/dlclive/pose_estimation_pytorch/models/predictors/identity_predictor.py new file mode 100644 index 0000000..a4837c5 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/predictors/identity_predictor.py @@ -0,0 +1,69 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Predictor to generate identity maps from head outputs""" +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from dlclive.pose_estimation_pytorch.models.predictors.base import ( + PREDICTORS, + BasePredictor, +) + + +@PREDICTORS.register_module +class IdentityPredictor(BasePredictor): + """Predictor to generate identity maps from head outputs + + Attributes: + apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True. + """ + + def __init__(self, apply_sigmoid: bool = True): + """ + Args: + apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True. + """ + super().__init__() + self.apply_sigmoid = apply_sigmoid + self.sigmoid = nn.Sigmoid() + + def forward( + self, stride: float, outputs: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """ + Swaps the dimensions so the heatmap are (batch_size, h, w, num_individuals), + optionally applies a sigmoid to the heatmaps, and rescales it to be the size + of the original image (so that the identity scores of keypoints can be computed) + + Args: + stride: the stride of the model + outputs: output of the model identity head, of shape (b, num_idv, w', h') + + Returns: + A dictionary containing a "heatmap" key with the identity heatmap tensor as + value. + """ + heatmaps = outputs["heatmap"] + h_out, w_out = heatmaps.shape[2:] + h_in, w_in = int(h_out * stride), int(w_out * stride) + heatmaps = F.resize( + heatmaps, + size=[h_in, w_in], + interpolation=F.InterpolationMode.BILINEAR, + antialias=True, + ) + if self.apply_sigmoid: + heatmaps = self.sigmoid(heatmaps) + + # permute to have shape (batch_size, h, w, num_individuals) + heatmaps = heatmaps.permute((0, 2, 3, 1)) + return {"heatmap": heatmaps} diff --git a/dlclive/pose_estimation_pytorch/models/predictors/paf_predictor.py b/dlclive/pose_estimation_pytorch/models/predictors/paf_predictor.py new file mode 100644 index 0000000..de83636 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/predictors/paf_predictor.py @@ -0,0 +1,368 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import numpy as np +import torch +import torch.nn.functional as F +from numpy.typing import NDArray + +from dlclive.pose_estimation_pytorch.models.predictors.base import ( + BasePredictor, + PREDICTORS, +) +from dlclive.core import inferenceutils + +Graph = list[tuple[int, int]] + + +@PREDICTORS.register_module +class PartAffinityFieldPredictor(BasePredictor): + """Predictor class for multiple animal pose estimation with part affinity fields. + + TODO: INSTALL scipy-1.14.1 + + Args: + num_animals: Number of animals in the project. + num_multibodyparts: Number of animal's body parts (ignoring unique body parts). + num_uniquebodyparts: Number of unique body parts. # FIXME - should not be needed here if we separate the unique bodypart head + graph: Part affinity field graph edges. + edges_to_keep: List of indices in `graph` of the edges to keep. + locref_stdev: Standard deviation for location refinement. + nms_radius: Radius of the Gaussian kernel. + sigma: Width of the 2D Gaussian distribution. + min_affinity: Minimal edge affinity to add a body part to an Assembly. + + Returns: + Regressed keypoints from heatmaps, locref_maps and part affinity fields, as in Tensorflow maDLC. + """ + + default_init = { + "locref_stdev": 7.2801, + "nms_radius": 5, + "sigma": 1, + "min_affinity": 0.05, + } + + def __init__( + self, + num_animals: int, + num_multibodyparts: int, + num_uniquebodyparts: int, + graph: Graph, + edges_to_keep: list[int], + locref_stdev: float, + nms_radius: int, + sigma: float, + min_affinity: float, + add_discarded: bool = False, + apply_sigmoid: bool = True, + clip_scores: bool = False, + force_fusion: bool = False, + return_preds: bool = False, + ): + """Initialize the PartAffinityFieldPredictor class. + + Args: + num_animals: Number of animals in the project. + num_multibodyparts: Number of animal's body parts (ignoring unique body parts). + num_uniquebodyparts: Number of unique body parts. + graph: Part affinity field graph edges. + edges_to_keep: List of indices in `graph` of the edges to keep. + locref_stdev: Standard deviation for location refinement. + nms_radius: Radius of the Gaussian kernel. + sigma: Width of the 2D Gaussian distribution. + min_affinity: Minimal edge affinity to add a body part to an Assembly. + return_preds: Whether to return predictions alongside the animals' poses + + Returns: + None + """ + super().__init__() + self.num_animals = num_animals + self.num_multibodyparts = num_multibodyparts + self.num_uniquebodyparts = num_uniquebodyparts + self.graph = graph + self.edges_to_keep = edges_to_keep + self.locref_stdev = locref_stdev + self.nms_radius = nms_radius + self.return_preds = return_preds + self.sigma = sigma + self.apply_sigmoid = apply_sigmoid + self.clip_scores = clip_scores + self.sigmoid = torch.nn.Sigmoid() + self.assembler = inferenceutils.Assembler.empty( + num_animals, + n_multibodyparts=num_multibodyparts, + n_uniquebodyparts=num_uniquebodyparts, + graph=graph, + paf_inds=edges_to_keep, + min_affinity=min_affinity, + add_discarded=add_discarded, + force_fusion=force_fusion, + ) + + def forward( + self, stride: float, outputs: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Forward pass of PartAffinityFieldPredictor. Gets predictions from model output. + + Args: + stride: the stride of the model + outputs: Output tensors from previous layers. + output = heatmaps, locref, pafs + heatmaps: torch.Tensor([batch_size, num_joints, height, width]) + locref: torch.Tensor([batch_size, num_joints, height, width]) + + Returns: + A dictionary containing a "poses" key with the output tensor as value. + + Example: + >>> predictor = PartAffinityFieldPredictor(num_animals=3, location_refinement=True, locref_stdev=7.2801) + >>> output = (torch.rand(32, 17, 64, 64), torch.rand(32, 34, 64, 64), torch.rand(32, 136, 64, 64)) + >>> stride = 8 + >>> poses = predictor.forward(stride, output) + """ + heatmaps = outputs["heatmap"] + locrefs = outputs["locref"] + pafs = outputs["paf"] + scale_factors = stride, stride + batch_size, n_channels, height, width = heatmaps.shape + + if self.apply_sigmoid: + heatmaps = self.sigmoid(heatmaps) + + # Filter predicted heatmaps with a 2D Gaussian kernel as in: + # https://openaccess.thecvf.com/content_CVPR_2020/papers/Huang_The_Devil_Is_in_the_Details_Delving_Into_Unbiased_Data_CVPR_2020_paper.pdf + kernel = self.make_2d_gaussian_kernel( + sigma=self.sigma, size=self.nms_radius * 2 + 1 + )[None, None] + kernel = kernel.repeat(n_channels, 1, 1, 1).to(heatmaps.device) + heatmaps = F.conv2d( + heatmaps, kernel, stride=1, padding="same", groups=n_channels + ) + + peaks = self.find_local_peak_indices_maxpool_nms( + heatmaps, self.nms_radius, threshold=0.01 + ) + if ~torch.any(peaks): + return { + "poses": -torch.ones( + (batch_size, self.num_animals, self.num_multibodyparts, 5) + ) + } + + locrefs = locrefs.reshape(batch_size, n_channels, 2, height, width) + locrefs = locrefs * self.locref_stdev + pafs = pafs.reshape(batch_size, -1, 2, height, width) + + graph = [self.graph[ind] for ind in self.edges_to_keep] + preds = self.compute_peaks_and_costs( + heatmaps, + locrefs, + pafs, + peaks, + graph, + self.edges_to_keep, + scale_factors, + n_id_channels=0, # FIXME Handle identity training + ) + poses = -torch.ones((batch_size, self.num_animals, self.num_multibodyparts, 5)) + poses_unique = -torch.ones((batch_size, 1, self.num_uniquebodyparts, 4)) + for i, data_dict in enumerate(preds): + assemblies, unique = self.assembler._assemble(data_dict, ind_frame=0) + if assemblies is not None: + for j, assembly in enumerate(assemblies): + poses[i, j, :, :4] = torch.from_numpy(assembly.data) + poses[i, j, :, 4] = assembly.affinity + if unique is not None: + poses_unique[i, 0, :, :4] = torch.from_numpy(unique) + + if self.clip_scores: + poses[..., 2] = torch.clip(poses[..., 2], min=0, max=1) + + out = {"poses": poses} + if self.return_preds: + out["preds"] = preds + return out + + @staticmethod + def find_local_peak_indices_maxpool_nms( + input_: torch.Tensor, radius: int, threshold: float + ) -> torch.Tensor: + pooled = F.max_pool2d(input_, kernel_size=radius, stride=1, padding=radius // 2) + maxima = input_ * torch.eq(input_, pooled).float() + peak_indices = torch.nonzero(maxima >= threshold, as_tuple=False) + return peak_indices.int() + + @staticmethod + def make_2d_gaussian_kernel(sigma: float, size: int) -> torch.Tensor: + k = torch.arange(-size // 2 + 1, size // 2 + 1, dtype=torch.float32) ** 2 + k = F.softmax(-k / (2 * (sigma**2)), dim=0) + return torch.einsum("i,j->ij", k, k) + + @staticmethod + def calc_peak_locations( + locrefs: torch.Tensor, + peak_inds_in_batch: torch.Tensor, + strides: tuple[float, float], + ) -> torch.Tensor: + s, b, r, c = peak_inds_in_batch.T + stride_y, stride_x = strides + strides = torch.Tensor((stride_x, stride_y)).to(locrefs.device) + off = locrefs[s, b, :, r, c] + loc = strides * peak_inds_in_batch[:, [3, 2]] + strides // 2 + off + return loc + + @staticmethod + def compute_edge_costs( + pafs: NDArray, + peak_inds_in_batch: NDArray, + graph: Graph, + paf_inds: list[int], + n_bodyparts: int, + n_points: int = 10, + n_decimals: int = 3, + ) -> list[dict[int, NDArray]]: + # Clip peak locations to PAFs dimensions + h, w = pafs.shape[-2:] + peak_inds_in_batch[:, 2] = np.clip(peak_inds_in_batch[:, 2], 0, h - 1) + peak_inds_in_batch[:, 3] = np.clip(peak_inds_in_batch[:, 3], 0, w - 1) + + n_samples = pafs.shape[0] + sample_inds = [] + edge_inds = [] + all_edges = [] + all_peaks = [] + for i in range(n_samples): + samples_i = peak_inds_in_batch[:, 0] == i + peak_inds = peak_inds_in_batch[samples_i, 1:] + if not np.any(peak_inds): + continue + peaks = peak_inds[:, 1:] + bpt_inds = peak_inds[:, 0] + idx = np.arange(peaks.shape[0]) + idx_per_bpt = {j: idx[bpt_inds == j].tolist() for j in range(n_bodyparts)} + edges = [] + for k, (s, t) in zip(paf_inds, graph): + inds_s = idx_per_bpt[s] + inds_t = idx_per_bpt[t] + if not (inds_s and inds_t): + continue + candidate_edges = ((i, j) for i in inds_s for j in inds_t) + edges.extend(candidate_edges) + edge_inds.extend([k] * len(inds_s) * len(inds_t)) + if not edges: + continue + sample_inds.extend([i] * len(edges)) + all_edges.extend(edges) + all_peaks.append(peaks[np.asarray(edges)]) + if not all_peaks: + return [dict() for _ in range(n_samples)] + + sample_inds = np.asarray(sample_inds, dtype=np.int32) + edge_inds = np.asarray(edge_inds, dtype=np.int32) + all_edges = np.asarray(all_edges, dtype=np.int32) + all_peaks = np.concatenate(all_peaks) + vecs_s = all_peaks[:, 0] + vecs_t = all_peaks[:, 1] + vecs = vecs_t - vecs_s + lengths = np.linalg.norm(vecs, axis=1).astype(np.float32) + lengths += np.spacing(1, dtype=np.float32) + xy = np.linspace(vecs_s, vecs_t, n_points, axis=1, dtype=np.int32) + y = pafs[ + sample_inds.reshape((-1, 1)), + edge_inds.reshape((-1, 1)), + :, + xy[..., 0], + xy[..., 1], + ] + integ = np.trapz(y, xy[..., ::-1], axis=1) + affinities = np.linalg.norm(integ, axis=1).astype(np.float32) + affinities /= lengths + np.round(affinities, decimals=n_decimals, out=affinities) + np.round(lengths, decimals=n_decimals, out=lengths) + + # Form cost matrices + all_costs = [] + for i in range(n_samples): + samples_i_mask = sample_inds == i + costs = dict() + for k in paf_inds: + edges_k_mask = edge_inds == k + idx = np.flatnonzero(samples_i_mask & edges_k_mask) + s, t = all_edges[idx].T + n_sources = np.unique(s).size + n_targets = np.unique(t).size + costs[k] = dict() + costs[k]["m1"] = affinities[idx].reshape((n_sources, n_targets)) + costs[k]["distance"] = lengths[idx].reshape((n_sources, n_targets)) + all_costs.append(costs) + + return all_costs + + @staticmethod + def _linspace(start: torch.Tensor, stop: torch.Tensor, num: int) -> torch.Tensor: + # Taken from https://github.com/pytorch/pytorch/issues/61292#issue-937937159 + steps = torch.linspace(0, 1, num, dtype=torch.float32, device=start.device) + steps = steps.reshape([-1, *([1] * start.ndim)]) + out = start[None] + steps * (stop - start)[None] + return out.swapaxes(0, 1) + + def compute_peaks_and_costs( + self, + heatmaps: torch.Tensor, + locrefs: torch.Tensor, + pafs: torch.Tensor, + peak_inds_in_batch: torch.Tensor, + graph: Graph, + paf_inds: list[int], + strides: tuple[float, float], + n_id_channels: int, + n_points: int = 10, + n_decimals: int = 3, + ) -> list[dict[str, NDArray]]: + n_samples, n_channels = heatmaps.shape[:2] + n_bodyparts = n_channels - n_id_channels + pos = self.calc_peak_locations(locrefs, peak_inds_in_batch, strides) + pos = np.round(pos.detach().cpu().numpy(), decimals=n_decimals) + heatmaps = heatmaps.detach().cpu().numpy() + pafs = pafs.detach().cpu().numpy() + peak_inds_in_batch = peak_inds_in_batch.detach().cpu().numpy() + costs = self.compute_edge_costs( + pafs, peak_inds_in_batch, graph, paf_inds, n_bodyparts, n_points, n_decimals + ) + s, b, r, c = peak_inds_in_batch.T + prob = np.round(heatmaps[s, b, r, c], n_decimals).reshape((-1, 1)) + if n_id_channels: + ids = np.round(heatmaps[s, -n_id_channels:, r, c], n_decimals) + + peaks_and_costs = [] + for i in range(n_samples): + xy = [] + p = [] + id_ = [] + samples_i_mask = peak_inds_in_batch[:, 0] == i + for j in range(n_bodyparts): + bpts_j_mask = peak_inds_in_batch[:, 1] == j + idx = np.flatnonzero(samples_i_mask & bpts_j_mask) + xy.append(pos[idx]) + p.append(prob[idx]) + if n_id_channels: + id_.append(ids[idx]) + dict_ = {"coordinates": (xy,), "confidence": p} + if costs is not None: + dict_["costs"] = costs[i] + if n_id_channels: + dict_["identity"] = id_ + peaks_and_costs.append(dict_) + + return peaks_and_costs diff --git a/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py b/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py new file mode 100644 index 0000000..e4ec134 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/predictors/sim_cc.py @@ -0,0 +1,163 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""SimCC predictor for the RTMPose model + +Based on the official ``mmpose`` SimCC codec and RTMCC head implementation. For more +information, see . +""" +from __future__ import annotations + +import numpy as np +import torch + +from dlclive.pose_estimation_pytorch.models.predictors.base import ( + BasePredictor, + PREDICTORS, +) + + +@PREDICTORS.register_module +class SimCCPredictor(BasePredictor): + """Class used to make pose predictions from RTMPose head outputs + + The RTMPose model uses coordinate classification for pose estimation. For more + information, see "SimCC: a Simple Coordinate Classification Perspective for Human + Pose Estimation" () and "RTMPose: Real-Time + Multi-Person Pose Estimation based on MMPose" (). + + Args: + simcc_split_ratio: The split ratio of pixels, as described in SimCC. + apply_softmax: Whether to apply softmax on the scores. + normalize_outputs: Whether to normalize the outputs before predicting maximums. + """ + + def __init__( + self, + simcc_split_ratio: float = 2.0, + apply_softmax: bool = False, + normalize_outputs: bool = False, + ) -> None: + super().__init__() + self.simcc_split_ratio = simcc_split_ratio + self.apply_softmax = apply_softmax + self.normalize_outputs = normalize_outputs + + def forward( + self, stride: float, outputs: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + x, y = outputs["x"].detach(), outputs["y"].detach() + if self.normalize_outputs: + x = get_simcc_normalized(x) + y = get_simcc_normalized(y) + + keypoints, scores = get_simcc_maximum( + x.cpu().numpy(), y.cpu().numpy(), self.apply_softmax + ) + + if keypoints.ndim == 2: + keypoints = keypoints[None, :] + scores = scores[None, :] + + keypoints /= self.simcc_split_ratio + scores = scores.reshape((*scores.shape, -1)) + keypoints_with_score = np.concatenate([keypoints, scores], axis=-1) + keypoints_with_score = torch.tensor(keypoints_with_score).unsqueeze(1) + return dict(poses=keypoints_with_score) + + +def get_simcc_maximum( + simcc_x: np.ndarray, + simcc_y: np.ndarray, + apply_softmax: bool = False, +) -> tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from SimCC representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + apply_softmax (bool): whether to apply softmax on the heatmap. + Defaults to False. + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + + assert isinstance(simcc_x, np.ndarray), "simcc_x should be numpy.ndarray" + assert isinstance(simcc_y, np.ndarray), "simcc_y should be numpy.ndarray" + assert simcc_x.ndim == 2 or simcc_x.ndim == 3, f"Invalid shape {simcc_x.shape}" + assert simcc_y.ndim == 2 or simcc_y.ndim == 3, f"Invalid shape {simcc_y.shape}" + assert simcc_x.ndim == simcc_y.ndim, f"{simcc_x.shape} != {simcc_y.shape}" + + if simcc_x.ndim == 3: + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + else: + N = None + + if apply_softmax: + simcc_x = simcc_x - np.max(simcc_x, axis=1, keepdims=True) + simcc_y = simcc_y - np.max(simcc_y, axis=1, keepdims=True) + ex, ey = np.exp(simcc_x), np.exp(simcc_y) + simcc_x = ex / np.sum(ex, axis=1, keepdims=True) + simcc_y = ey / np.sum(ey, axis=1, keepdims=True) + + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.0] = -1 + + if N: + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def get_simcc_normalized(pred: torch.Tensor) -> torch.Tensor: + """Normalize the predicted SimCC. + + See: + github.com/open-mmlab/mmpose/blob/main/mmpose/codecs/utils/post_processing.py#L12 + + Args: + pred: The predicted output. + + Returns: + The normalized output. + """ + b, k, _ = pred.shape + pred = pred.clamp(min=0) + + # Compute the binary mask + mask = (pred.amax(dim=-1) > 1).reshape(b, k, 1) + + # Normalize the tensor using the maximum value + norm = pred / pred.amax(dim=-1).reshape(b, k, 1) + + # return the normalized tensor + return torch.where(mask, norm, pred) diff --git a/dlclive/pose_estimation_pytorch/models/predictors/single_predictor.py b/dlclive/pose_estimation_pytorch/models/predictors/single_predictor.py new file mode 100644 index 0000000..c622cf9 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/predictors/single_predictor.py @@ -0,0 +1,164 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +from typing import Tuple + +import torch + +from dlclive.pose_estimation_pytorch.models.predictors.base import ( + BasePredictor, + PREDICTORS, +) + + +@PREDICTORS.register_module +class HeatmapPredictor(BasePredictor): + """Predictor class for pose estimation from heatmaps (and optionally locrefs). + + Args: + location_refinement: Enable location refinement. + locref_std: Standard deviation for location refinement. + apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True. + + Returns: + Regressed keypoints from heatmaps and locref_maps of baseline DLC model (ResNet + Deconv). + """ + + def __init__( + self, + apply_sigmoid: bool = True, + clip_scores: bool = False, + location_refinement: bool = True, + locref_std: float = 7.2801, + ): + """ + Args: + apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True. + clip_scores: If a sigmoid is not applied, this can be used to clip scores + for predicted keypoints to values in [0, 1]. + location_refinement : Enable location refinement. + locref_std: Standard deviation for location refinement. + """ + super().__init__() + self.apply_sigmoid = apply_sigmoid + self.clip_scores = clip_scores + self.sigmoid = torch.nn.Sigmoid() + self.location_refinement = location_refinement + self.locref_std = locref_std + + def forward( + self, stride: float, outputs: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Forward pass of SinglePredictor. Gets predictions from model output. + + Args: + stride: the stride of the model + outputs: output of the model heads (heatmap, locref) + + Returns: + A dictionary containing a "poses" key with the output tensor as value. + + Example: + >>> predictor = HeatmapPredictor(location_refinement=True, locref_std=7.2801) + >>> stride = 8 + >>> output = {"heatmap": torch.rand(32, 17, 64, 64), "locref": torch.rand(32, 17, 64, 64)} + >>> poses = predictor.forward(stride, output) + """ + heatmaps = outputs["heatmap"] + scale_factors = stride, stride + + if self.apply_sigmoid: + heatmaps = self.sigmoid(heatmaps) + + heatmaps = heatmaps.permute(0, 2, 3, 1) + batch_size, height, width, num_joints = heatmaps.shape + + locrefs = None + if self.location_refinement: + locrefs = outputs["locref"] + locrefs = locrefs.permute(0, 2, 3, 1).reshape( + batch_size, height, width, num_joints, 2 + ) + locrefs = locrefs * self.locref_std + + poses = self.get_pose_prediction(heatmaps, locrefs, scale_factors) + + if self.clip_scores: + poses[..., 2] = torch.clip(poses[..., 2], min=0, max=1) + + return {"poses": poses} + + def get_top_values( + self, heatmap: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the top values from the heatmap. + + Args: + heatmap: Heatmap tensor. + + Returns: + Y and X indices of the top values. + + Example: + >>> predictor = HeatmapPredictor(location_refinement=True, locref_std=7.2801) + >>> heatmap = torch.rand(32, 17, 64, 64) + >>> Y, X = predictor.get_top_values(heatmap) + """ + batchsize, ny, nx, num_joints = heatmap.shape + heatmap_flat = heatmap.reshape(batchsize, nx * ny, num_joints) + heatmap_top = torch.argmax(heatmap_flat, dim=1) + y, x = heatmap_top // nx, heatmap_top % nx + return y, x + + def get_pose_prediction( + self, heatmap: torch.Tensor, locref: torch.Tensor | None, scale_factors + ) -> torch.Tensor: + """Gets the pose prediction given the heatmaps and locref. + + Args: + heatmap: Heatmap tensor of shape (batch_size, height, width, num_joints) + locref: Locref tensor of shape (batch_size, height, width, num_joints, 2) + scale_factors: Scale factors for the poses. + + Returns: + Pose predictions of the format: (batch_size, num_people = 1, num_joints, 3) + + Example: + >>> predictor = HeatmapPredictor( + >>> location_refinement=True, locref_std=7.2801 + >>> ) + >>> heatmap = torch.rand(32, 17, 64, 64) + >>> locref = torch.rand(32, 17, 64, 64, 2) + >>> scale_factors = (0.5, 0.5) + >>> poses = predictor.get_pose_prediction(heatmap, locref, scale_factors) + """ + y, x = self.get_top_values(heatmap) + + batch_size, num_joints = x.shape + + dz = torch.zeros((batch_size, 1, num_joints, 3)).to(x.device) + for b in range(batch_size): + for j in range(num_joints): + dz[b, 0, j, 2] = heatmap[b, y[b, j], x[b, j], j] + if locref is not None: + dz[b, 0, j, :2] = locref[b, y[b, j], x[b, j], j, :] + + x, y = torch.unsqueeze(x, 1), torch.unsqueeze(y, 1) + + x = x * scale_factors[1] + 0.5 * scale_factors[1] + dz[:, :, :, 0] + y = y * scale_factors[0] + 0.5 * scale_factors[0] + dz[:, :, :, 1] + + pose = torch.empty((batch_size, 1, num_joints, 3)) + pose[:, :, :, 0] = x + pose[:, :, :, 1] = y + pose[:, :, :, 2] = dz[:, :, :, 2] + return pose diff --git a/dlclive/pose_estimation_pytorch/models/registry.py b/dlclive/pose_estimation_pytorch/models/registry.py new file mode 100644 index 0000000..45ed735 --- /dev/null +++ b/dlclive/pose_estimation_pytorch/models/registry.py @@ -0,0 +1,330 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" + +import inspect +from functools import partial +from typing import Any, Dict, Optional + + +def build_from_cfg( + cfg: Dict, registry: "Registry", default_args: Optional[Dict] = None +) -> Any: + """Builds a module from the configuration dictionary when it represents a class configuration, + or call a function from the configuration dictionary when it represents a function configuration. + + Args: + cfg: Configuration dictionary. It should at least contain the key "type". + registry: The registry to search the type from. + default_args: Default initialization arguments. + Defaults to None. + + Returns: + Any: The constructed object. + + Example: + >>> from dlclive.models.registry import Registry, build_from_cfg + >>> class Model: + >>> def __init__(self, param): + >>> self.param = param + >>> cfg = {"type": "Model", "param": 10} + >>> registry = Registry("models") + >>> registry.register_module(Model) + >>> obj = build_from_cfg(cfg, registry) + >>> assert isinstance(obj, Model) + >>> assert obj.param == 10 + """ + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + obj_type = args.pop("type") + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError(f"{obj_type} is not in the {registry.name} registry") + elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): + obj_cls = obj_type + else: + raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}") + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f"{obj_cls.__name__}: {e}") + + +class Registry: + """A registry to map strings to classes or functions. + Registered objects could be built from the registry. Meanwhile, registered + functions could be called from the registry. + + Args: + name: Registry name. + build_func: Builds function to construct an instance from + the Registry. If neither ``parent`` nor + ``build_func`` is specified, the ``build_from_cfg`` + function is used. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be + inherited from ``parent``. Default: None. + parent: Parent registry. The class registered in + children's registry could be built from the parent. + Default: None. + scope: The scope of the registry. It is the key to search + for children's registry. If not specified, scope will be the + name of the package where the class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + + Attributes: + name: Registry name. + module_dict: The dictionary containing registered modules. + children: The dictionary containing children registries. + scope: The scope of the registry. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = "." + + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = ( + self.__class__.__name__ + f"(name={self._name}, " + f"items={self._module_dict})" + ) + return format_str + + @staticmethod + def split_scope_key(key): + """Split scope and key. + The first scope will be split from key. + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + Return: + tuple[str | None, str]: The former element is the first scope of + the key, which can be ``None``. The latter is the remaining key. + """ + split_index = key.find(".") + if split_index != -1: + return key[:split_index], key[split_index + 1 :] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key: The class name in string format. + + Returns: + class: The corresponding class. + + Example: + >>> from dlclive.models.registry import Registry + >>> registry = Registry("models") + >>> class Model: + >>> pass + >>> registry.register_module(Model, "Model") + >>> assert registry.get("Model") == Model + """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + """Builds an instance from the registry. + + Args: + *args: Arguments passed to the build function. + **kwargs: Keyword arguments passed to the build function. + + Returns: + Any: The constructed object. + + Example: + >>> from dlclive.models.registry import Registry, build_from_cfg + >>> class Model: + >>> def __init__(self, param): + >>> self.param = param + >>> cfg = {"type": "Model", "param": 10} + >>> registry = Registry("models") + >>> registry.register_module(Model) + >>> obj = registry.build(cfg, param=20) + >>> assert isinstance(obj, Model) + >>> assert obj.param == 20 + """ + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + Args: + registry: The registry to be added as children based on its scope. + + Returns: + None + + Example: + >>> from dlclive.models.registry import Registry + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> class Model: + >>> pass + >>> mmdet_models.register_module(Model) + >>> obj = models.build(dict(type='mmdet.Model')) + >>> assert isinstance(obj, Model) + """ + assert isinstance(registry, Registry) + assert registry.scope is not None + assert ( + registry.scope not in self.children + ), f"scope {registry.scope} exists in {self.name} registry" + self.children[registry.scope] = registry + + def _register_module(self, module, module_name=None, force=False): + """Register a module. + + Args: + module: Module class or function to be registered. + module_name: The module name(s) to be registered. + If not specified, the class name will be used. + force: Whether to override an existing class with the same name. + Default: False. + + Returns: + None + + Example: + >>> from dlclive.models.registry import Registry + >>> registry = Registry("models") + >>> class Model: + >>> pass + >>> registry._register_module(Model, "Model") + >>> assert registry.get("Model") == Model + """ + if not inspect.isclass(module) and not inspect.isfunction(module): + raise TypeError( + "module must be a class or a function, " f"but got {type(module)}" + ) + + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f"{name} is already registered " f"in {self.name}") + self._module_dict[name] = module + + def deprecated_register_module(self, cls=None, force=False): + """Decorator to register a class in the registry. + + Args: + cls: The class to be registered. + force: Whether to override an existing class with the same name. + Default: False. + + Returns: + type: The input class. + + Example: + >>> from dlclive.models.registry import Registry + >>> registry = Registry("models") + >>> @registry.deprecated_register_module() + >>> class Model: + >>> pass + >>> assert registry.get("Model") == Model + """ + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + Args: + name: The module name to be registered. If not + specified, the class name will be used. + force: Whether to override an existing class with + the same name. Default: False. + module: Module class or function to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + # NOTE: This is a walkaround to be compatible with the old api, + # while it may introduce unexpected bugs. + if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + self._register_module(module=module, module_name=name, force=force) + return module + + return diff --git a/dlclive/pose_estimation_pytorch/runner.py b/dlclive/pose_estimation_pytorch/runner.py new file mode 100644 index 0000000..5e1d89e --- /dev/null +++ b/dlclive/pose_estimation_pytorch/runner.py @@ -0,0 +1,370 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""PyTorch and ONNX runners for DeepLabCut-Live""" +import copy +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import numpy as np +import torch +from torchvision.transforms import v2 + +import dlclive.pose_estimation_pytorch.data as data +import dlclive.pose_estimation_pytorch.models as models +import dlclive.pose_estimation_pytorch.dynamic_cropping as dynamic_cropping +from dlclive.core.runner import BaseRunner + + +@dataclass +class SkipFrames: + """Configuration for skip frames. + + Skip-frames can be used for top-down models running with a detector. If skip > 0, + then the detector will only be run every `skip` frames. Between frames where the + detector is run, bounding boxes will be computed from the pose estimated in the + previous frame. + + Every `N` frames, the detector will be run to detect bounding boxes for individuals. + In the "skipped" frames between the frames where the object detector is run, the + bounding boxes will be computed from the poses estimated in the previous frame (with + some margin added around the poses). + + Attributes: + skip: The number of frames to skip between each run of the detector. + margin: The margin (in pixels) to use when generating bboxes + """ + + skip: int + margin: int + _age: int = 0 + _detections: dict[str, torch.Tensor] | None = None + + def get_detections(self) -> dict[str, torch.Tensor] | None: + return self._detections + + def update(self, pose: torch.Tensor, w: int, h: int) -> None: + """Generates bounding boxes from a pose. + + Args: + pose: The pose from which to generate bounding boxes. + w: The width of the image. + h: The height of the image. + + Returns: + A dictionary containing the bounding boxes and scores for each detection. + """ + if self._age >= self.skip: + self._age = 0 + self._detections = None + return + + num_det, num_kpts = pose.shape[:2] + size = max(w, h) + + bboxes = torch.zeros((num_det, 4)) + bboxes[:, :2] = ( + torch.min(torch.nan_to_num(pose, size)[..., :2], dim=1)[0] - self.margin + ) + bboxes[:, 2:4] = ( + torch.max(torch.nan_to_num(pose, 0)[..., :2], dim=1)[0] + self.margin + ) + bboxes = torch.clip(bboxes, min=torch.zeros(4), max=torch.tensor([w, h, w, h])) + self._detections = dict(boxes=bboxes, scores=torch.ones(num_det)) + self._age += 1 + + +@dataclass +class TopDownConfig: + """Configuration for top-down models. + + Attributes: + bbox_cutoff: The minimum score required for a bounding box to be considered. + max_detections: The maximum number of detections to keep in a frame. If None, + the `max_detections` will be set to the number of individuals in the model + configuration file when `read_config` is called. + skip_frames: If defined, the detector will only be run every + `skip_frames.skip` frames. + """ + + bbox_cutoff: float = 0.6 + max_detections: int | None = 30 + crop_size: tuple[int, int] = (256, 256) + skip_frames: SkipFrames | None = None + + def read_config(self, model_cfg: dict) -> None: + crop = model_cfg.get("data", {}).get("inference", {}).get("top_down_crop") + if crop is not None: + self.crop_size = (crop["width"], crop["height"]) + + if self.max_detections is None: + individuals = model_cfg.get("metadata", {}).get("individuals", []) + self.max_detections = len(individuals) + + +class PyTorchRunner(BaseRunner): + """PyTorch runner for live pose estimation using DeepLabCut-Live. + + Args: + path: The path to the model to run inference with. + device: The device on which to run inference, e.g. "cpu", "cuda", "cuda:0" + precision: The precision of the model. One of "FP16" or "FP32". + single_animal: This option is only available for single-animal pose estimation + models. It makes the code behave in exactly the same way as DeepLabCut-Live + with version < 3.0.0. This ensures backwards compatibility with any + Processors that were implemented. + dynamic: Whether to use dynamic cropping. + top_down_config: Only for top-down models running with a detector. + """ + + def __init__( + self, + path: str | Path, + device: str = "auto", + precision: Literal["FP16", "FP32"] = "FP32", + single_animal: bool = True, + dynamic: dict | dynamic_cropping.DynamicCropper | None = None, + top_down_config: dict | TopDownConfig | None = None, + ) -> None: + super().__init__(path) + self.device = _parse_device(device) + self.precision = precision + self.single_animal = single_animal + + self.cfg = None + self.detector = None + self.model = None + self.transform = None + + # Parse Dynamic Cropping parameters + if isinstance(dynamic, dict): + dynamic_type = dynamic.get("type", "DynamicCropper") + if dynamic_type == "DynamicCropper": + cropper_cls = dynamic_cropping.DynamicCropper + else: + cropper_cls = dynamic_cropping.TopDownDynamicCropper + dynamic_params = dynamic.copy() + dynamic_params.pop("type") + dynamic = cropper_cls(**dynamic_params) + + # Parse Top-Down config + if isinstance(top_down_config, dict): + skip_frame_cfg = top_down_config.get("skip_frames") + if skip_frame_cfg is not None: + top_down_config["skip_frames"] = SkipFrames(**skip_frame_cfg) + top_down_config = TopDownConfig(**top_down_config) + + self.dynamic = dynamic + self.top_down_config = top_down_config + + def close(self) -> None: + """Clears any resources used by the runner.""" + pass + + @torch.inference_mode() + def get_pose(self, frame: np.ndarray) -> np.ndarray: + c, h, w = frame.shape + frame = ( + self.transform(torch.from_numpy(frame).permute(2, 0, 1)) + .unsqueeze(0) + .to(self.device) + ) + if self.precision == "FP16": + frame = frame.half() + + offsets_and_scales = None + if self.detector is not None: + detections = None + if self.top_down_config.skip_frames is not None: + detections = self.top_down_config.skip_frames.get_detections() + + if detections is None: + detections = self.detector(frame)[0] + + frame_batch, offsets_and_scales = self._prepare_top_down(frame, detections) + if len(frame_batch) == 0: + offsets_and_scales = [(0, 0), 1] + else: + frame = frame_batch.to(self.device) + + if self.dynamic is not None: + frame = self.dynamic.crop(frame) + + outputs = self.model(frame) + batch_pose = self.model.get_predictions(outputs)["bodypart"]["poses"] + + if self.dynamic is not None: + batch_pose = self.dynamic.update(batch_pose) + + if self.detector is None: + pose = batch_pose[0] + else: + pose = self._postprocess_top_down(batch_pose, offsets_and_scales) + if self.top_down_config.skip_frames is not None: + self.top_down_config.skip_frames.update(pose, w, h) + + if self.single_animal: + if len(pose) == 0: + bodyparts, coords = pose.shape[-2:] + return np.zeros((bodyparts, coords)) + + pose = pose[0] + + return pose.cpu().numpy() + + def init_inference(self, frame: np.ndarray, **kwargs) -> np.ndarray: + """ + Initializes inference process on the provided frame. + + This method serves as an abstract base method, meant to be implemented by + subclasses. It takes an input image frame and optional additional parameters + to set up and perform inference. The method must return a processed result + as a numpy array. + + Parameters + ---------- + frame : np.ndarray + The input image frame for which inference needs to be set up. + kwargs : dict, optional + Additional parameters that may be required for specific implementation + of the inference initialization. + + Returns + ------- + np.ndarray + The result of the inference after being initialized and processed. + """ + self.load_model() + return self.get_pose(frame) + + def load_model(self) -> None: + """Loads the model from the provided path.""" + raw_data = torch.load(self.path, map_location="cpu", weights_only=True) + + self.cfg = raw_data["config"] + self.model = models.PoseModel.build(self.cfg["model"]) + self.model.load_state_dict(raw_data["pose"]) + self.model = self.model.to(self.device) + self.model.eval() + + if self.precision == "FP16": + self.model = self.model.half() + + self.detector = None + if self.dynamic is None and raw_data.get("detector") is not None: + self.detector = models.DETECTORS.build(self.cfg["detector"]["model"]) + self.detector.to(self.device) + self.detector.load_state_dict(raw_data["detector"]) + self.detector.eval() + + if self.precision == "FP16": + self.detector = self.detector.half() + + if self.top_down_config is None: + self.top_down_config = TopDownConfig() + + self.top_down_config.read_config(self.cfg) + + if isinstance(self.dynamic, dynamic_cropping.TopDownDynamicCropper): + crop = self.cfg["data"]["inference"].get("top_down_crop", {}) + w, h = crop.get("width", 256), crop.get("height", 256) + self.dynamic.top_down_crop_size = w, h + + if ( + self.cfg["method"] == "td" + and self.detector is None + and self.dynamic is None + ): + raise ValueError( + "Top-down models must either use a detector or a TopDownDynamicCropper." + ) + + self.transform = v2.Compose( + [ + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + def read_config(self) -> dict: + """Reads the configuration file""" + if self.cfg is not None: + return copy.deepcopy(self.cfg) + + raw_data = torch.load(self.path, map_location="cpu", weights_only=True) + return raw_data["config"] + + def _prepare_top_down( + self, frame: torch.Tensor, detections: dict[str, torch.Tensor] + ): + """Prepares a frame for top-down pose estimation.""" + bboxes, scores = detections["boxes"], detections["scores"] + bboxes = bboxes[scores >= self.top_down_config.bbox_cutoff] + if len(bboxes) > 0 and self.top_down_config.max_detections is not None: + bboxes = bboxes[: self.top_down_config.max_detections] + + crops = [] + offsets_and_scales = [] + for bbox in bboxes: + x1, y1, x2, y2 = bbox + cropped_frame, offset, scale = data.top_down_crop_torch( + frame[0], + (x1, y1, x2 - x1, y2 - y1), + output_size=self.top_down_config.crop_size, + margin=0, + ) + crops.append(cropped_frame) + offsets_and_scales.append((offset, scale)) + + if len(crops) > 0: + frame_batch = torch.stack(crops, dim=0) + else: + crop_w, crop_h = self.top_down_config.crop_size + frame_batch = torch.zeros((0, 3, crop_h, crop_w), device=frame.device) + offsets_and_scales = [(0, 0), 1] + + return frame_batch, offsets_and_scales + + def _postprocess_top_down( + self, + batch_pose: torch.Tensor, + offsets_and_scales: list[tuple[tuple[int, int], tuple[float, float]]], + ) -> torch.Tensor: + """Post-processes pose for top-down models.""" + if len(batch_pose) == 0: + bodyparts, coords = batch_pose.shape[-2:] + return torch.zeros((0, bodyparts, coords)) + + poses = [] + for pose, (offset, scale) in zip(batch_pose, offsets_and_scales): + poses.append( + torch.cat( + [ + pose[..., :2] * torch.tensor(scale) + torch.tensor(offset), + pose[..., 2:3], + ], + dim=-1, + ) + ) + + return torch.cat(poses) + + +def _parse_device(device: str | None) -> str: + if device is None: + device = "auto" + + if device == "auto": + if torch.cuda.is_available(): + return "cuda" + return "cpu" + + return device diff --git a/dlclive/pose_estimation_tensorflow/__init__.py b/dlclive/pose_estimation_tensorflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dlclive/pose_estimation_tensorflow/graph.py b/dlclive/pose_estimation_tensorflow/graph.py new file mode 100644 index 0000000..72b3b76 --- /dev/null +++ b/dlclive/pose_estimation_tensorflow/graph.py @@ -0,0 +1,139 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" + + +import tensorflow as tf + +vers = (tf.__version__).split(".") +if int(vers[0]) == 2 or int(vers[0]) == 1 and int(vers[1]) > 12: + tf = tf.compat.v1 +else: + tf = tf + + +def read_graph(file): + """ + Loads the graph from a protobuf file + + Parameters + ----------- + file : string + path to the protobuf file + + Returns + -------- + graph_def :class:`tensorflow.tf.compat.v1.GraphDef` + The graph definition of the DeepLabCut model found at the object's path + """ + + with tf.io.gfile.GFile(file, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + return graph_def + + +def finalize_graph(graph_def): + """ + Finalize the graph and get inputs to model + + Parameters + ----------- + graph_def :class:`tensorflow.compat.v1.GraphDef` + The graph of the DeepLabCut model, read using the :func:`read_graph` method + + Returns + -------- + graph :class:`tensorflow.compat.v1.GraphDef` + The finalized graph of the DeepLabCut model + inputs :class:`tensorflow.Tensor` + Input tensor(s) for the model + """ + + graph = tf.Graph() + with graph.as_default(): + tf.import_graph_def(graph_def, name="DLC") + graph.finalize() + + return graph + + +def get_output_nodes(graph): + """ + Get the output node names from a graph + + Parameters + ----------- + graph :class:`tensorflow.Graph` + The graph of the DeepLabCut model + + Returns + -------- + output : list + the output node names as a list of strings + """ + + op_names = [str(op.name) for op in graph.get_operations()] + if "concat_1" in op_names[-1]: + output = [op_names[-1]] + else: + output = [op_names[-1], op_names[-2]] + + return output + + +def get_output_tensors(graph): + """ + Get the names of the output tensors from a graph + + Parameters + ----------- + graph :class:`tensorflow.Graph` + The graph of the DeepLabCut model + + Returns + -------- + output : list + the output tensor names as a list of strings + """ + + output_nodes = get_output_nodes(graph) + output_tensor = [out + ":0" for out in output_nodes] + return output_tensor + + +def get_input_tensor(graph): + input_tensor = str(graph.get_operations()[0].name) + ":0" + return input_tensor + + +def extract_graph( + graph, tf_config=None +) -> tuple[tf.Session, tf.Tensor, list[tf.Tensor]]: + """ + Initializes a tensorflow session with the specified graph and extracts the model's inputs and outputs + + Parameters + ----------- + graph :class:`tensorflow.Graph` + a tensorflow graph containing the desired model + tf_config :class:`tensorflow.ConfigProto` + + Returns + -------- + sess :class:`tensorflow.Session` + a tensorflow session with the specified graph definition + outputs :class:`tensorflow.Tensor` + the output tensor(s) for the model + """ + + input_tensor = get_input_tensor(graph) + output_tensor = get_output_tensors(graph) + sess = tf.Session(graph=graph, config=tf_config) + inputs = graph.get_tensor_by_name(input_tensor) + outputs = [graph.get_tensor_by_name(out) for out in output_tensor] + + return sess, inputs, outputs diff --git a/dlclive/pose_estimation_tensorflow/pose.py b/dlclive/pose_estimation_tensorflow/pose.py new file mode 100644 index 0000000..3e69bb9 --- /dev/null +++ b/dlclive/pose_estimation_tensorflow/pose.py @@ -0,0 +1,120 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" + + +import numpy as np + + +def extract_cnn_output(outputs, cfg): + """ + Extract location refinement and score map from DeepLabCut network + + Parameters + ----------- + outputs : list + List of outputs from DeepLabCut network. + Requires 2 entries: + index 0 is output from Sigmoid + index 1 is output from pose/locref_pred/block4/BiasAdd + + cfg : dict + Dictionary read from the pose_cfg.yaml file for the network. + + Returns + -------- + scmap : ? + score map + + locref : ? + location refinement + """ + + scmap = outputs[0] + scmap = np.squeeze(scmap) + locref = None + if cfg["location_refinement"]: + locref = np.squeeze(outputs[1]) + shape = locref.shape + locref = np.reshape(locref, (shape[0], shape[1], -1, 2)) + locref *= cfg["locref_stdev"] + if len(scmap.shape) == 2: # for single body part! + scmap = np.expand_dims(scmap, axis=2) + return scmap, locref + + +def argmax_pose_predict(scmap, offmat, stride): + """ + Combines score map and offsets to the final pose + + Parameters + ----------- + scmap : ? + score map + + offmat : ? + offsets + + stride : ? + ? + + Returns + -------- + pose :class:`numpy.ndarray` + pose as a numpy array + """ + + num_joints = scmap.shape[2] + pose = [] + for joint_idx in range(num_joints): + maxloc = np.unravel_index( + np.argmax(scmap[:, :, joint_idx]), scmap[:, :, joint_idx].shape + ) + offset = np.array(offmat[maxloc][joint_idx])[::-1] + pos_f8 = np.array(maxloc).astype("float") * stride + 0.5 * stride + offset + pose.append(np.hstack((pos_f8[::-1], [scmap[maxloc][joint_idx]]))) + return np.array(pose) + + +def get_top_values(scmap, n_top=5): + batchsize, ny, nx, num_joints = scmap.shape + scmap_flat = scmap.reshape(batchsize, nx * ny, num_joints) + if n_top == 1: + scmap_top = np.argmax(scmap_flat, axis=1)[None] + else: + scmap_top = np.argpartition(scmap_flat, -n_top, axis=1)[:, -n_top:] + for ix in range(batchsize): + vals = scmap_flat[ix, scmap_top[ix], np.arange(num_joints)] + arg = np.argsort(-vals, axis=0) + scmap_top[ix] = scmap_top[ix, arg, np.arange(num_joints)] + scmap_top = scmap_top.swapaxes(0, 1) + + Y, X = np.unravel_index(scmap_top, (ny, nx)) + return Y, X + + +def multi_pose_predict(scmap, locref, stride, num_outputs): + Y, X = get_top_values(scmap[None], num_outputs) + Y, X = Y[:, 0], X[:, 0] + num_joints = scmap.shape[2] + DZ = np.zeros((num_outputs, num_joints, 3)) + for m in range(num_outputs): + for k in range(num_joints): + x = X[m, k] + y = Y[m, k] + DZ[m, k, :2] = locref[y, x, k, :] + DZ[m, k, 2] = scmap[y, x, k] + + X = X.astype("float32") * stride + 0.5 * stride + DZ[:, :, 0] + Y = Y.astype("float32") * stride + 0.5 * stride + DZ[:, :, 1] + P = DZ[:, :, 2] + + pose = np.empty((num_joints, num_outputs * 3), dtype="float32") + pose[:, 0::3] = X.T + pose[:, 1::3] = Y.T + pose[:, 2::3] = P.T + + return pose diff --git a/dlclive/pose_estimation_tensorflow/runner.py b/dlclive/pose_estimation_tensorflow/runner.py new file mode 100644 index 0000000..fa05f8e --- /dev/null +++ b/dlclive/pose_estimation_tensorflow/runner.py @@ -0,0 +1,205 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""TensorFlow runners for DeepLabCut-Live""" +import glob +import os +from pathlib import Path +from typing import Any + +import numpy as np +import tensorflow as tf + +from dlclive.core.config import read_yaml +from dlclive.core.runner import BaseRunner +from dlclive.exceptions import DLCLiveError +from dlclive.pose_estimation_tensorflow.graph import ( + extract_graph, + finalize_graph, + get_output_nodes, + get_output_tensors, + read_graph, +) +from dlclive.pose_estimation_tensorflow.pose import ( + argmax_pose_predict, + extract_cnn_output, + multi_pose_predict, +) + + +class TensorFlowRunner(BaseRunner): + """TensorFlow runner for live pose estimation using DeepLabCut-Live.""" + + def __init__( + self, + path: str | Path, + model_type: str = "base", + tf_config: Any = None, + precision: str = "FP32", + ) -> None: + super().__init__(path) + self.cfg = self.read_config() + self.model_type = model_type + self.tf_config = tf_config + self.precision = precision + self.sess = None + self.inputs = None + self.outputs = None + self.tflite_interpreter = None + + def close(self) -> None: + """Clears any resources used by the runner.""" + if self.sess is not None: + self.sess.close() + self.sess = None + + def get_pose(self, frame: np.ndarray, **kwargs) -> np.ndarray: + if self.model_type in ["base", "tensorrt"]: + pose_output = self.sess.run( + self.outputs, feed_dict={self.inputs: np.expand_dims(frame, axis=0)} + ) + + elif self.model_type == "tflite": + self.tflite_interpreter.set_tensor( + self.inputs[0]["index"], + np.expand_dims(frame, axis=0).astype(np.float32), + ) + self.tflite_interpreter.invoke() + + if len(self.outputs) > 1: + pose_output = [ + self.tflite_interpreter.get_tensor(self.outputs[0]["index"]), + self.tflite_interpreter.get_tensor(self.outputs[1]["index"]), + ] + else: + pose_output = self.tflite_interpreter.get_tensor( + self.outputs[0]["index"] + ) + + else: + raise DLCLiveError( + f"model_type={self.model_type} is not supported. model_type must be " + f"'base', 'tflite', or 'tensorrt'" + ) + + # check if using TFGPUinference flag + # if not, get pose from network output + if len(pose_output) > 1: + scmap, locref = extract_cnn_output(pose_output, self.cfg) + num_outputs = self.cfg.get("num_outputs", 1) + if num_outputs > 1: + pose = multi_pose_predict( + scmap, locref, self.cfg["stride"], num_outputs + ) + else: + pose = argmax_pose_predict(scmap, locref, self.cfg["stride"]) + else: + pose = np.array(pose_output[0]) + pose = pose[:, [1, 0, 2]] + + return pose + + def init_inference(self, frame: np.ndarray, **kwargs) -> np.ndarray: + model_file = glob.glob(os.path.normpath(str(self.path) + "/*.pb"))[0] + + tf_ver = tf.__version__ + tf_version_2 = tf_ver[0] == "2" + + # load model + if self.model_type == "base": + graph_def = read_graph(model_file) + graph = finalize_graph(graph_def) + self.sess, self.inputs, self.outputs = extract_graph( + graph, tf_config=self.tf_config + ) + + elif self.model_type == "tflite": + ### + # the frame size needed to initialize the tflite model as + # tflite does not support saving a model with dynamic input size + ### + + # get input and output tensor names from graph_def + graph_def = read_graph(model_file) + graph = finalize_graph(graph_def) + output_nodes = get_output_nodes(graph) + output_nodes = [on.replace("DLC/", "") for on in output_nodes] + placeholder_shape = [1, frame.shape[0], frame.shape[1], 3] + + if tf_version_2: + converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( + model_file, + ["Placeholder"], + output_nodes, + input_shapes={"Placeholder": placeholder_shape}, + ) + else: + converter = tf.lite.TFLiteConverter.from_frozen_graph( + model_file, + ["Placeholder"], + output_nodes, + input_shapes={"Placeholder": placeholder_shape}, + ) + + try: + tflite_model = converter.convert() + except Exception: + raise DLCLiveError( + ( + "This model cannot be converted to tensorflow lite format. " + "To use tensorflow lite for live inference, " + "make sure to set TFGPUinference=False " + "when exporting the model from DeepLabCut" + ) + ) + + self.tflite_interpreter = tf.lite.Interpreter(model_content=tflite_model) + self.tflite_interpreter.allocate_tensors() + self.inputs = self.tflite_interpreter.get_input_details() + self.outputs = self.tflite_interpreter.get_output_details() + + elif self.model_type == "tensorrt": + graph_def = read_graph(model_file) + graph = finalize_graph(graph_def) + output_tensors = get_output_tensors(graph) + output_tensors = [ot.replace("DLC/", "") for ot in output_tensors] + + if (tf_ver[0] > 1) | (tf_ver[0] == 1 & tf_ver[1] >= 14): + converter = trt.TrtGraphConverter( + input_graph_def=graph_def, + nodes_blacklist=output_tensors, + is_dynamic_op=True, + ) + graph_def = converter.convert() + else: + graph_def = trt.create_inference_graph( + input_graph_def=graph_def, + outputs=output_tensors, + max_batch_size=1, + precision_mode=self.precision, + is_dynamic_op=True, + ) + + graph = finalize_graph(graph_def) + self.sess, self.inputs, self.outputs = extract_graph( + graph, tf_config=self.tf_config + ) + + else: + raise DLCLiveError( + f"model_type={self.model_type} is not supported. model_type must be " + "'base', 'tflite', or 'tensorrt'" + ) + + return self.get_pose(frame, **kwargs) + + def read_config(self) -> dict: + """Reads the configuration file""" + return read_yaml(self.path / "pose_cfg.yaml") diff --git a/dlclive/predictor/__init__.py b/dlclive/predictor/__init__.py new file mode 100644 index 0000000..3f6777c --- /dev/null +++ b/dlclive/predictor/__init__.py @@ -0,0 +1 @@ +from dlclive.predictor.single_predictor import HeatmapPredictor diff --git a/dlclive/predictor/base.py b/dlclive/predictor/base.py new file mode 100644 index 0000000..f8b8b98 --- /dev/null +++ b/dlclive/predictor/base.py @@ -0,0 +1,64 @@ +""" +DeepLabCut Toolbox (deeplabcut.org) +© A. & M. Mathis Labs + +Licensed under GNU Lesser General Public License v3.0 +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import torch +from torch import nn + +from dlclive.pose_estimation_pytorch.models.registry import Registry, build_from_cfg + +PREDICTORS = Registry("predictors", build_func=build_from_cfg) + + +class BasePredictor(ABC, nn.Module): + """The base Predictor class. + + This class is an abstract base class (ABC) for defining predictors used in the + DeepLabCut Toolbox. All predictor classes should inherit from this base class and + implement the forward method. Regresses keypoint coordinates from a models output + maps + + Attributes: + num_animals: Number of animals in the project. Should be set in subclasses. + + Example: + # Create a subclass that inherits from BasePredictor + class MyPredictor(BasePredictor): + def __init__(self, num_animals): + super().__init__() + self.num_animals = num_animals + + def forward(self, outputs): + # Implement the forward pass of your custom predictor here. + pass + """ + + def __init__(self): + super().__init__() + self.num_animals = None + + @abstractmethod + def forward( + self, stride: float, outputs: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Abstract method for the forward pass of the Predictor. + + Args: + stride: the stride of the model + outputs: outputs of the model heads + + Returns: + A dictionary containing a "poses" key with the output tensor as value, and + optionally a "unique_bodyparts" with the unique bodyparts tensor as value. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + pass diff --git a/dlclive/predictor/single_predictor.py b/dlclive/predictor/single_predictor.py new file mode 100644 index 0000000..c622cf9 --- /dev/null +++ b/dlclive/predictor/single_predictor.py @@ -0,0 +1,164 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +from typing import Tuple + +import torch + +from dlclive.pose_estimation_pytorch.models.predictors.base import ( + BasePredictor, + PREDICTORS, +) + + +@PREDICTORS.register_module +class HeatmapPredictor(BasePredictor): + """Predictor class for pose estimation from heatmaps (and optionally locrefs). + + Args: + location_refinement: Enable location refinement. + locref_std: Standard deviation for location refinement. + apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True. + + Returns: + Regressed keypoints from heatmaps and locref_maps of baseline DLC model (ResNet + Deconv). + """ + + def __init__( + self, + apply_sigmoid: bool = True, + clip_scores: bool = False, + location_refinement: bool = True, + locref_std: float = 7.2801, + ): + """ + Args: + apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True. + clip_scores: If a sigmoid is not applied, this can be used to clip scores + for predicted keypoints to values in [0, 1]. + location_refinement : Enable location refinement. + locref_std: Standard deviation for location refinement. + """ + super().__init__() + self.apply_sigmoid = apply_sigmoid + self.clip_scores = clip_scores + self.sigmoid = torch.nn.Sigmoid() + self.location_refinement = location_refinement + self.locref_std = locref_std + + def forward( + self, stride: float, outputs: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Forward pass of SinglePredictor. Gets predictions from model output. + + Args: + stride: the stride of the model + outputs: output of the model heads (heatmap, locref) + + Returns: + A dictionary containing a "poses" key with the output tensor as value. + + Example: + >>> predictor = HeatmapPredictor(location_refinement=True, locref_std=7.2801) + >>> stride = 8 + >>> output = {"heatmap": torch.rand(32, 17, 64, 64), "locref": torch.rand(32, 17, 64, 64)} + >>> poses = predictor.forward(stride, output) + """ + heatmaps = outputs["heatmap"] + scale_factors = stride, stride + + if self.apply_sigmoid: + heatmaps = self.sigmoid(heatmaps) + + heatmaps = heatmaps.permute(0, 2, 3, 1) + batch_size, height, width, num_joints = heatmaps.shape + + locrefs = None + if self.location_refinement: + locrefs = outputs["locref"] + locrefs = locrefs.permute(0, 2, 3, 1).reshape( + batch_size, height, width, num_joints, 2 + ) + locrefs = locrefs * self.locref_std + + poses = self.get_pose_prediction(heatmaps, locrefs, scale_factors) + + if self.clip_scores: + poses[..., 2] = torch.clip(poses[..., 2], min=0, max=1) + + return {"poses": poses} + + def get_top_values( + self, heatmap: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the top values from the heatmap. + + Args: + heatmap: Heatmap tensor. + + Returns: + Y and X indices of the top values. + + Example: + >>> predictor = HeatmapPredictor(location_refinement=True, locref_std=7.2801) + >>> heatmap = torch.rand(32, 17, 64, 64) + >>> Y, X = predictor.get_top_values(heatmap) + """ + batchsize, ny, nx, num_joints = heatmap.shape + heatmap_flat = heatmap.reshape(batchsize, nx * ny, num_joints) + heatmap_top = torch.argmax(heatmap_flat, dim=1) + y, x = heatmap_top // nx, heatmap_top % nx + return y, x + + def get_pose_prediction( + self, heatmap: torch.Tensor, locref: torch.Tensor | None, scale_factors + ) -> torch.Tensor: + """Gets the pose prediction given the heatmaps and locref. + + Args: + heatmap: Heatmap tensor of shape (batch_size, height, width, num_joints) + locref: Locref tensor of shape (batch_size, height, width, num_joints, 2) + scale_factors: Scale factors for the poses. + + Returns: + Pose predictions of the format: (batch_size, num_people = 1, num_joints, 3) + + Example: + >>> predictor = HeatmapPredictor( + >>> location_refinement=True, locref_std=7.2801 + >>> ) + >>> heatmap = torch.rand(32, 17, 64, 64) + >>> locref = torch.rand(32, 17, 64, 64, 2) + >>> scale_factors = (0.5, 0.5) + >>> poses = predictor.get_pose_prediction(heatmap, locref, scale_factors) + """ + y, x = self.get_top_values(heatmap) + + batch_size, num_joints = x.shape + + dz = torch.zeros((batch_size, 1, num_joints, 3)).to(x.device) + for b in range(batch_size): + for j in range(num_joints): + dz[b, 0, j, 2] = heatmap[b, y[b, j], x[b, j], j] + if locref is not None: + dz[b, 0, j, :2] = locref[b, y[b, j], x[b, j], j, :] + + x, y = torch.unsqueeze(x, 1), torch.unsqueeze(y, 1) + + x = x * scale_factors[1] + 0.5 * scale_factors[1] + dz[:, :, :, 0] + y = y * scale_factors[0] + 0.5 * scale_factors[0] + dz[:, :, :, 1] + + pose = torch.empty((batch_size, 1, num_joints, 3)) + pose[:, :, :, 0] = x + pose[:, :, :, 1] = y + pose[:, :, :, 2] = dz[:, :, :, 2] + return pose diff --git a/dlclive/processor/__init__.py b/dlclive/processor/__init__.py index 67e14db..657b405 100644 --- a/dlclive/processor/__init__.py +++ b/dlclive/processor/__init__.py @@ -4,6 +4,4 @@ Licensed under GNU Lesser General Public License v3.0 """ - from dlclive.processor.processor import Processor -from dlclive.processor.kalmanfilter import KalmanFilterPredictor diff --git a/dlclive/processor/kalmanfilter.py b/dlclive/processor/kalmanfilter.py index 447bcae..ff46805 100644 --- a/dlclive/processor/kalmanfilter.py +++ b/dlclive/processor/kalmanfilter.py @@ -5,9 +5,10 @@ Licensed under GNU Lesser General Public License v3.0 """ - import time + import numpy as np + from dlclive.processor import Processor @@ -25,7 +26,6 @@ def __init__( lik_thresh=0, **kwargs, ): - super().__init__(**kwargs) self.adapt = adapt @@ -41,16 +41,14 @@ def __init__( self.last_pose_time = 0 def _get_forward_model(self, dt): - F = np.zeros((self.n_states, self.n_states)) for d in range(self.nderiv + 1): for i in range(self.n_states - (d * self.bp * 2)): - F[i, i + (2 * self.bp * d)] = (dt ** d) / max(1, d) + F[i, i + (2 * self.bp * d)] = (dt**d) / max(1, d) return F def _init_kf(self, pose): - # get number of body parts self.bp = pose.shape[0] self.n_states = self.bp * 2 * (self.nderiv + 1) @@ -75,7 +73,6 @@ def _init_kf(self, pose): self.is_initialized = True def _predict(self): - F = self._get_forward_model(time.time() - self.last_pose_time) Pd = np.diag(self.P).reshape(self.P.shape[0], 1) @@ -85,7 +82,6 @@ def _predict(self): self.Pp = np.dot(np.dot(F, self.P), F.T) + self.Q def _get_residuals(self, pose): - z = np.zeros((self.n_states, 1)) z[: (self.bp * 2)] = pose[: self.bp, :2].reshape(self.bp * 2, 1) for i in range(self.bp * 2, self.n_states): @@ -93,7 +89,6 @@ def _get_residuals(self, pose): self.y = z - np.dot(self.H, self.Xp) def _update(self, liks): - S = np.dot(self.H, np.dot(self.Pp, self.H.T)) + self.R K = np.dot(np.dot(self.Pp, self.H.T), np.linalg.inv(S)) self.X = self.Xp + np.dot(K, self.y) @@ -101,7 +96,6 @@ def _update(self, liks): self.P = np.dot(self.I - np.dot(K, self.H), self.Pp) def _get_future_pose(self, dt): - Ff = self._get_forward_model(dt) Xf = np.dot(Ff, self.X) future_pose = Xf[: (self.bp * 2)].reshape(self.bp, 2) @@ -109,7 +103,6 @@ def _get_future_pose(self, dt): return future_pose def _get_state_likelihood(self, pose): - liks = pose[:, 2] liks_xy = np.repeat(liks, 2) liks_xy_deriv = np.tile(liks_xy, self.nderiv + 1) @@ -117,15 +110,12 @@ def _get_state_likelihood(self, pose): return liks_state def process(self, pose, **kwargs): - if not self.is_initialized: - self._init_kf(pose) self.last_pose_time = time.time() return pose else: - self._predict() self._get_residuals(pose) liks = self._get_state_likelihood(pose) diff --git a/dlclive/processor/processor.py b/dlclive/processor/processor.py index 8a52f5f..8bd28de 100644 --- a/dlclive/processor/processor.py +++ b/dlclive/processor/processor.py @@ -12,7 +12,7 @@ """ -class Processor(object): +class Processor: def __init__(self, **kwargs): pass diff --git a/dlclive/utils.py b/dlclive/utils.py index 4b0deaa..94a3dba 100644 --- a/dlclive/utils.py +++ b/dlclive/utils.py @@ -5,23 +5,24 @@ Licensed under GNU Lesser General Public License v3.0 """ +import warnings import numpy as np -import warnings + from dlclive.exceptions import DLCLiveWarning try: import skimage SK_IM = True -except Exception: +except ImportError as e: SK_IM = False try: import cv2 OPEN_CV = True -except Exception: +except ImportError as e: from PIL import Image OPEN_CV = False @@ -31,18 +32,18 @@ ) -def convert_to_ubyte(frame): - """ Converts an image to unsigned 8-bit integer numpy array. +def convert_to_ubyte(frame: np.ndarray) -> np.ndarray: + """Converts an image to unsigned 8-bit integer numpy array. If scikit-image is installed, uses skimage.img_as_ubyte, otherwise, uses a similar custom function. Parameters ---------- - image : :class:`numpy.ndarray` + frame: an image as a numpy array Returns ------- - :class:`numpy.ndarray` + :class: `numpy.ndarray` image converted to uint8 """ @@ -52,36 +53,33 @@ def convert_to_ubyte(frame): return _img_as_ubyte_np(frame) -def resize_frame(frame, resize=None): - """ Resizes an image. Uses OpenCV if installed, otherwise, uses pillow +def resize_frame(frame: np.ndarray, resize=None) -> np.ndarray: + """Resizes an image. Uses OpenCV if installed, otherwise, uses pillow Parameters ---------- - image : :class:`numpy.ndarray` + frame: an image as a numpy array """ if (resize is not None) and (resize != 1): + new_x = int(frame.shape[0] * resize) + new_y = int(frame.shape[1] * resize) if OPEN_CV: - - new_x = int(frame.shape[0] * resize) - new_y = int(frame.shape[1] * resize) return cv2.resize(frame, (new_y, new_x)) else: - img = Image.fromarray(frame) img = img.resize((new_y, new_x)) return np.asarray(img) else: - return frame -def img_to_rgb(frame): - """ Convert an image to RGB. Uses OpenCV is installed, otherwise uses pillow. +def img_to_rgb(frame: np.ndarray) -> np.ndarray: + """Convert an image to RGB. Uses OpenCV is installed, otherwise uses pillow. Parameters ---------- @@ -90,15 +88,12 @@ def img_to_rgb(frame): """ if frame.ndim == 2: - return gray_to_rgb(frame) elif frame.ndim == 3: - return bgr_to_rgb(frame) else: - warnings.warn( f"Image has {frame.ndim} dimensions. Must be 2 or 3 dimensions to convert to RGB", DLCLiveWarning, @@ -106,8 +101,8 @@ def img_to_rgb(frame): return frame -def gray_to_rgb(frame): - """ Convert an image from grayscale to RGB. Uses OpenCV is installed, otherwise uses pillow. +def gray_to_rgb(frame: np.ndarray) -> np.ndarray: + """Convert an image from grayscale to RGB. Uses OpenCV is installed, otherwise uses pillow. Parameters ---------- @@ -116,18 +111,16 @@ def gray_to_rgb(frame): """ if OPEN_CV: - return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) else: - img = Image.fromarray(frame) img = img.convert("RGB") return np.asarray(img) -def bgr_to_rgb(frame): - """ Convert an image from BGR to RGB. Uses OpenCV is installed, otherwise uses pillow. +def bgr_to_rgb(frame: np.ndarray) -> np.ndarray: + """Convert an image from BGR to RGB. Uses OpenCV is installed, otherwise uses pillow. Parameters ---------- @@ -136,18 +129,16 @@ def bgr_to_rgb(frame): """ if OPEN_CV: - return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) else: - img = Image.fromarray(frame) img = img.convert("RGB") return np.asarray(img) -def _img_as_ubyte_np(frame): - """ Converts an image as a numpy array to unsinged 8-bit integer. +def _img_as_ubyte_np(frame: np.ndarray) -> np.ndarray: + """Converts an image as a numpy array to unsinged 8-bit integer. As in scikit-image img_as_ubyte, converts negative pixels to 0 and converts range to [0, 255] Parameters @@ -166,12 +157,10 @@ def _img_as_ubyte_np(frame): # check if already ubyte if np.issubdtype(im_type, np.uint8): - return frame # if floating elif np.issubdtype(im_type, np.floating): - if (np.min(frame) < -1) or (np.max(frame) > 1): raise ValueError("Images of type float must be between -1 and 1.") @@ -182,14 +171,12 @@ def _img_as_ubyte_np(frame): # if integer elif np.issubdtype(im_type, np.integer): - im_type_info = np.iinfo(im_type) frame *= 255 / im_type_info.max frame[frame < 0] = 0 return frame.astype(np.uint8) else: - raise TypeError( "image of type {} could not be converted to ubyte".format(im_type) ) diff --git a/dlclive/version.py b/dlclive/version.py index 7996e03..a2047a2 100644 --- a/dlclive/version.py +++ b/dlclive/version.py @@ -6,6 +6,5 @@ Licensed under GNU Lesser General Public License v3.0 """ - -__version__ = "1.0.4" +__version__ = "3.0.0a0" VERSION = __version__ diff --git a/docs/DLC Live Benchmark.md b/docs/DLC Live Benchmark.md new file mode 100755 index 0000000..583e9f2 --- /dev/null +++ b/docs/DLC Live Benchmark.md @@ -0,0 +1,32 @@ +## Inference time + +| System | Model type | Runtime | Device type | Precision | Video | Video length (s) - # Frames | FPS | Frame size | Display settings | Pose model backbone | Avg Inference time ± Std
*(including 1st inference)* | Avg Inference time ± Std | Average FPS ± Std | Model size | +| ------ | ---------- | -------- | ----------- | -------------------------------------- | ------------ | --------------------------- | --- | ---------- | ---------------- | ---------------------------------- | -------------------------------------------------------- | ------------------------ | ----------------- | ---------- | +| Linux | ONNX | ONNX | CUDA | Full precision (FP32) | Ventral gait | 10s - 1.5k | 150 | (658,302) | None | `ResNet50` (bu) | 29.02ms ± 47.59ms | 27.8ms ± 2.32ms | 36 ± 3 | 92.12 MB | +| Linux | ONNX | ONNX | CPU | Full precision (FP32) | Ventral gait | 10s - 1.5k | 150 | (658,302) | None | `ResNet50` (bu) | 146.12ms ± 13.26ms | 146.11 ± 13.25 | 7 ± 1 | 92.12 MB | +| Linux | PyTorch | PyTorch | CUDA | Full precision (FP32) | Ventral gait | 10s - 1.5k | 150 | (658,302) | None | `ResNet50` (bu) | 6.04ms ± 7.37ms | 5.97ms ± 6.8ms | 271 ± 112 | 96.5 MB | +| Linux | PyTorch | PyTorch | CPU | Full precision (FP32) | Ventral gait | 10s - 1.5k | 150 | (658,302) | None | `ResNet50` (bu) | 365.26ms ± 13.88ms | 365.17ms ± 13.44ms | 3 ± 0 | 96.5 MB | +| Linux | ONNX | TensorRT | CUDA | Full precision (FP32) - no caching | Ventral gait | 10s - 1.5k | 150 | (658,302) | None | `ResNet50` (bu) | 55.32ms ± 1254.16ms | 22.93ms ± 0.88 | 44 ± 2 | 92.12 MB | +| Linux | ONNX | TensorRT | CUDA | Full precision (FP32) - engine caching | Ventral gait | 10s - 1.5k | 150 | (658,302) | None | `ResNet50` (bu) | 20.8ms ± 3.4ms | 20.72ms ± 1.25ms | 48 ± 3 | 92.12 MB | +| Linux | ONNX | TensorRT | CUDA | FP16 | Ventral gait | 10s - 1.5k | 150 | (658,302) | None | `ResNet50` (bu) | 34.37ms ± 858.96ms | 12.19ms ± 0.87 | 82 ± 6 | 46.16 MB | +| Linux | ONNX | ONNX | CUDA | FP16 | Ventral gait | 10s - 1.5k | 150 | (658,302) | None | `ResNet50` (bu) | 21.74ms ± 43.24ms | 20.62ms ± 2.5ms | 49 ± 5 | 46.16 MB | +| Linux | PyTorch | PyTorch | CUDA | FP32 | Ventral gait | 10s - 1.5k | 150 | (164,75) | Resize=0.25 | `ResNet50` (bu) | 22.27ms ± 12.5ms | 22.16ms ± 11.65ms | 70 ± 68 | 96.5 MB | +| Linux | ONNX | ONNX | CUDA | (FP32) | Ventral gait | 10s - 1.5k | 150 | (164,75) | Resize=0.25 | `ResNet50` (bu) | 6.18ms ± 37.03ms | 5.22ms ± 0.86ms | 195 ± 25 | | +| Linux | ONNX | ONNX | CPU | (FP32) | Ventral gait | 10s - 1.5k | 150 | (164,75) | Resize=0.25 | `ResNet50` (bu) | 13.17ms ± 1.25ms | 13.17ms ± 1.23ms | 76 ± 4 | | +| Linux | ONNX | TensorRT | CUDA | (FP32) | Ventral gait | 10s - 1.5k | 150 | (164,75) | Resize=0.25 | `ResNet50` (bu) | 15.12ms ± 458.27ms | 3.28ms ± 0.24ms | 306 ± 23 | | +| Linux | ONNX | ONNX | CUDA | FP16 | Ventral gait | 10s - 1.5k | 150 | (164,75) | Resize=0.25 | `ResNet50` (bu) | 5.83ms ± 33.27ms | 4.97ms ± 1.5ms | 214 ± 45 | | +| Linux | ONNX | ONNX | CUDA | FP16 | Pigeon | 36s - ~1k | 30 | (480, 270) | Resize=0.25 | `ResNet50 + SSDLite detector` (td) | 17.08 ms ± 139.91ms | 12.82 ms ± 1.52 | 79 ± 8 | 45.50 MB | +| Linux | ONNX | ONNX | CUDA | FP32 | Pigeon | 36s - ~1k | 30 | (480, 270) | Resize=0.25 | `ResNet50 + SSDLite detector` (td) | 25.06 ms ± 129.74ms | 21.1 ms ± 0.82ms | 47 ± 2 | 90.79 MB | +| Linux | ONNX | TensorRT | CUDA | FP32 | Pigeon | 36s - ~1k | 30 | (480, 270) | Resize=0.25 | `ResNet50 + SSDLite detector` (td) | 6.18 ms ± 1376.44 ms | 14.22 ms ± 0.48ms | 70 ± 3 | | +| Linux | ONNX | TensorRT | CUDA | FP16 | Pigeon | 36s - ~1k | 30 | (480, 270) | Resize=0.25 | `ResNet50 + SSDLite detector` (td) | 49.81 ms ± 1361.7ms | 8.3 ms ± 0.75ms | 121 ± 11 | | +| Linux | PyTorch | PyTorch | CUDA | FP32 | Pigeon | 36s - ~1k | 30 | (480, 270) | Resize=0.25 | `ResNet50 + SSDLite detector` (td) | 7.7 ms ± 5.38 ms | 7.78 ms ± 6.0 ms | 185 ± 96 | | +| Linux | PyTorch | PyTorch | CPU | FP32 | Pigeon | 36s - ~1k | 30 | (480, 270) | Resize=0.25 | `ResNet50 + SSDLite detector` (td) | 167.33 ms ± 21.0 ms | 167.32 ms ± 21.01 ms | 6 ± 1 | | +| Linux | ONNX | ONNX | CPU | FP32 | Pigeon | 36s - ~1k | 30 | (480, 270) | Resize=0.25 | `ResNet50 + SSDLite detector` (td) | 85.64 ms ± 8.23 ms | 85.65 ms ± 8.23 | 12 ± 1 | | +| Linux | ONNX | ONNX | CPU | FP16 | Pigeon | 36s - ~1k | 30 | (480, 270) | Resize=0.25 | `ResNet50 + SSDLite detector` (td) | 161.32 ms ± 18.29ms | 161.3 ms ± 18.29ms | 6 ± 1 | | + +** **CUDA: NVIDIA GeForce RTX 3050 (6GB)** +** **CPU: 13th Gen Intel Core i7-13620H × 16** +** **Linux: Ubuntu 24.04 LTS** + +^ *Startup time at inference for a TensorRT engine takes between 30 and 50 seconds, +which skews the inference time measurement. Caching is used to reduce that time.* diff --git a/docs/assets/select_dlc.png b/docs/assets/select_dlc.png new file mode 100644 index 0000000..1848884 Binary files /dev/null and b/docs/assets/select_dlc.png differ diff --git a/docs/install_desktop.md b/docs/install_desktop.md index 5eb710c..2b81976 100755 --- a/docs/install_desktop.md +++ b/docs/install_desktop.md @@ -1,18 +1,32 @@ ### Install DeepLabCut-live on a desktop (Windows/Ubuntu) -We recommend that you install DeepLabCut-live in a conda environment (It is a standard python package though, and other distributions will also likely work). In this case, please install Anaconda: - -- [Windows](https://docs.anaconda.com/anaconda/install/windows/) -- [Linux](https://docs.anaconda.com/anaconda/install/linux/) - -Create a conda environment with python 3.7 and tensorflow: +We recommend that you install DeepLabCut-live in a conda environment (It is a standard +python package though, and other distributions will also likely work). In this case, +please install [Miniconda](https://docs.anaconda.com/miniconda/miniconda-install/) +(recommended) or Anaconda. + +If you have an Nvidia GPU and want to use its capabilities, you'll need to [install CUDA +](https://developer.nvidia.com/cuda-downloads) first (check that CUDA is installed - +checkout the installation guide for [linux]( +https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) or [Windows]( +https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html). + +Create a conda environment with python 3.10 or 3.11, and install +[`pytables`](https://www.pytables.org/usersguide/installation.html), `torch` and +`torchvision`. Make sure you [install the correct `torch` and `torchvision` versions +for your compute platform](https://pytorch.org/get-started/locally/)! ``` -conda create -n dlc-live python=3.7 tensorflow-gpu==1.13.1 # if using GPU -conda create -n dlc-live python=3.7 tensorflow==1.13.1 # if not using GPU +conda create -n dlc-live python=3.11 +conda activate dlc-live +conda install -c conda-forge pytables==3.8.0 + +# Installs PyTorch on Linux with CUDA 12.4 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 ``` -Activate the conda environment, install the DeepLabCut-live package, then test the installation: +Activate the conda environment, install the DeepLabCut-live package, then test the +installation: ``` conda activate dlc-live @@ -24,6 +38,10 @@ Note, you can also just run the test: `dlc-live-test` -If installed properly, this script will i) create a temporary folder ii) download the full_dog model from the [DeepLabCut Model Zoo](http://www.mousemotorlab.org/dlc-modelzoo), iii) download a short video clip of a dog, and iv) run inference while displaying keypoints. v) remove the temporary folder. +If installed properly, this script will i) create a temporary folder ii) download the +full_dog model from the [DeepLabCut Model Zoo]( +http://www.mousemotorlab.org/dlc-modelzoo), iii) download a short video clip of +a dog, and iv) run inference while displaying keypoints. v) remove the temporary folder. -Please note, you also should have curl installed on your computer (typically this is already installed on your system), but just in case, just run `sudo apt install curl` +Please note, you also should have curl installed on your computer (typically this is +already installed on your system), but just in case, just run `sudo apt install curl` diff --git a/docs/install_jetson.md b/docs/install_jetson.md index 33f6ee3..2db456e 100755 --- a/docs/install_jetson.md +++ b/docs/install_jetson.md @@ -1,17 +1,23 @@ ### Install DeepLabCut-live on a NVIDIA Jetson Development Kit -First, please follow NVIDIA's specific instructions to setup your Jetson Development Kit (see [Jetson Development Kit User Guides](https://developer.nvidia.com/embedded/learn/getting-started-jetson)). Once you have installed the NVIDIA Jetpack on your Jetson Development Kit, make sure all system libraries are up-to-date. In a terminal, run: +First, please follow NVIDIA's specific instructions to setup your Jetson Development Kit +(see [Jetson Development Kit User Guides](https://developer.nvidia.com/embedded/learn/getting-started-jetson)). Once you have installed the NVIDIA +Jetpack on your Jetson Development Kit, make sure all system libraries are up-to-date. +In a terminal, run: ``` sudo apt-get update sudo apt-get upgrade ``` -Lastly, please test that CUDA is installed properly by running: `nvcc --version`. The output should say the version of CUDA installed on your Jetson. +Lastly, please test that CUDA is installed properly by running: `nvcc --version`. The +output should say the version of CUDA installed on your Jetson. #### Install python, virtualenv, and tensorflow -We highly recommend installing DeepLabCut-live in a virtual environment. Please run the following command to install system dependencies needed to run python, to create virtual environments, and to run tensorflow: +We highly recommend installing DeepLabCut-live in a virtual environment. Please run the +following command to install system dependencies needed to run python, to create virtual +environments, and to run tensorflow: ``` sudo apt-get update @@ -32,7 +38,8 @@ sudo apt-get install libhdf5-serial-dev \ #### Create a virtual environment -Next, create a virtual environment called `dlc-live`, activate the `dlc-live` environment, and update it's package manger: +Next, create a virtual environment called `dlc-live`, activate the `dlc-live` +environment, and update it's package manager: ``` python3 -m venv dlc-live @@ -42,7 +49,10 @@ pip install -U pip testresources setuptools #### Install DeepLabCut-live dependencies -First, install python dependencies to run tensorflow (from [NVIDIA instructions to install tensorflow on Jetson platforms](https://docs.nvidia.com/deeplearning/frameworks/install-tf-jetson-platform/index.html)). _This may take ~15-30 minutes._ +First, install `python` dependencies to run `PyTorch` (from [NVIDIA instructions to +install PyTorch for Jetson Platform]( +https://docs.nvidia.com/deeplearning/frameworks/install-pytorch-jetson-platform/index.html)). +_This may take ~15-30 minutes._ ``` pip3 install numpy==1.16.1 \ @@ -57,25 +67,35 @@ pip3 install numpy==1.16.1 \ pybind11 ``` -Next, install tensorflow 1.x. This command will depend on the version of Jetpack you are using. If you are uncertain, please refer to [NVIDIA's instructions](https://docs.nvidia.com/deeplearning/frameworks/install-tf-jetson-platform/index.html#install). To install tensorflow 1.x on the latest version of NVIDIA Jetpack (version 4.4 as of 8/2/2020), please the command below. _This step will also take 15-30 mins_. +Next, install PyTorch >= 2.0. This command will depend on the version of Jetpack you are +using. If you are uncertain, please refer to [NVIDIA's instructions]( +https://docs.nvidia.com/deeplearning/frameworks/install-pytorch-jetson-platform/index.html). +To install PyTorch >= 2.0 ``` -pip3 install --pre --extra-index-url https://developer.download.nvidia.com/compute/redist/jp/v44 'tensorflow<2' +pip3 install --no-cache https://developer.download.nvidia.com/compute/redist/jp/v51/pytorch/ ``` +Currently, the only available PyTorch version that can be used is +`torch-2.0.0a0+8aa34602.nv23.03-cp38-cp38-linux_aarch64.whl`. + + Lastly, copy the opencv-python bindings into your virtual environment: ``` -cp -r /usr/lib/python3.6/dist-packages ~/dlc-live/lib/python3.6/dist-packages +cp -r /usr/lib/python3.12/dist-packages ~/dlc-live/lib/python3.12/dist-packages ``` #### Install the DeepLabCut-live package -Finally, please install DeepLabCut-live from PyPi (_this will take 3-5 mins_), then test the installation: +Finally, please install DeepLabCut-live from PyPi (_this will take 3-5 mins_), then +test the installation: ``` pip install deeplabcut-live dlc-live-test ``` -If installed properly, this script will i) download the full_dog model from the DeepLabCut Model Zoo, ii) download a short video clip of a dog, and iii) run inference while displaying keypoints. +If installed properly, this script will i) download the full_dog model from the +DeepLabCut Model Zoo, ii) download a short video clip of a dog, and iii) run inference +while displaying keypoints. diff --git a/pyproject.toml b/pyproject.toml index 0566071..200a5d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "deeplabcut-live" -version = "1.0.4" +version = "3.0.0a0" description = "Class to load exported DeepLabCut networks and perform pose estimation on single frames (from a camera feed)" authors = ["A. & M. Mathis Labs "] license = "AGPL-3.0-or-later" @@ -9,10 +9,8 @@ homepage = "https://github.com/DeepLabCut/DeepLabCut-live" repository = "https://github.com/DeepLabCut/DeepLabCut-live" classifiers = [ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", "Operating System :: OS Independent" ] @@ -26,18 +24,34 @@ dlc-live-test = "dlclive.check_install.check_install:main" dlc-live-benchmark = "dlclive.benchmark:main" [tool.poetry.dependencies] -python = ">=3.7.1,<3.11" -numpy = "^1.20" +python = ">=3.10,<3.12" +numpy = ">=1.20,<2" "ruamel.yaml" = "^0.17.20" colorcet = "^3.0.0" +einops = ">=0.6.1" Pillow = ">=8.0.0" py-cpuinfo = ">=5.0.0" tqdm = "^4.62.3" -tensorflow = "^2.7.0,<=2.12" -pandas = "^1.3" -tables = "^3.6" +pandas = ">=1.0.1,!=1.5.0" +tables = "^3.8" opencv-python-headless = "^4.5" -dlclibrary = ">=0.0.2" +dlclibrary = ">=0.0.6" +# PyTorch models +scipy = ">=1.9" +timm = { version = ">=1.0.7", optional = true } +torch = { version = ">=2.0.0", optional = true } +torchvision = { version = ">=0.15", optional = true } +# TensorFlow models +tensorflow = [ + { version = "^2.7.0,<=2.10", optional = true, platform = "win32" }, + { version = "^2.7.0,<=2.12", optional = true, platform = "linux" }, +] +tensorflow-macos = { version = "^2.7.0,<=2.12", optional = true, markers = "sys_platform == 'darwin'" } +tensorflow-metal = { version = "<1.3.0", optional = true, markers = "sys_platform == 'darwin'" } + +[tool.poetry.extras] +tf = [ "tensorflow", "tensorflow-macos", "tensorflow-metal"] +pytorch = ["scipy", "timm", "torch", "torchvision"] [tool.poetry.dev-dependencies] diff --git a/scripts/export.py b/scripts/export.py new file mode 100644 index 0000000..320ada0 --- /dev/null +++ b/scripts/export.py @@ -0,0 +1,81 @@ +"""Exports DeepLabCut models for DeepLabCut-Live""" +import warnings +from pathlib import Path + +import torch +from ruamel.yaml import YAML + + +def read_config_as_dict(config_path: str | Path) -> dict: + """ + Args: + config_path: the path to the configuration file to load + + Returns: + The configuration file with pure Python classes + """ + with open(config_path, "r") as f: + cfg = YAML(typ='safe', pure=True).load(f) + + return cfg + + +def export_dlc3_model( + export_path: Path, + model_config_path: Path, + pose_snapshot: Path, + detector_snapshot: Path | None = None, +) -> None: + """Exports a DLC3 model + + Args: + export_path: + model_config_path: + pose_snapshot: + detector_snapshot: + """ + model_cfg = read_config_as_dict(model_config_path) + + load_kwargs = dict(map_location="cpu", weights_only=True) + pose_weights = torch.load(pose_snapshot, **load_kwargs)["model"] + detector_weights = None + if detector_snapshot is None: + if model_cfg["method"].lower() == "td": + warnings.warn( + "The model is a top-down model but no detector snapshot was given." + "The configuration will be changed to run the model in bottom-up mode." + ) + model_cfg["method"] = "bu" + + else: + if model_cfg["method"].lower() == "bu": + raise ValueError(f"Cannot use a detector with a bottom-up model!") + detector_weights = torch.load(detector_snapshot, **load_kwargs)["model"] + + torch.save( + dict(config=model_cfg, detector=detector_weights, pose=pose_weights), + export_path, + ) + + +if __name__ == "__main__": + root = Path("/Users/john/Documents") + project_dir = root / "2024-10-14-my-model" + + # Exporting a top-down model + model_dir = project_dir / "top-down-resnet-50" / "model" + export_dlc3_model( + export_path=model_dir / "dlclive-export-fasterrcnnMobilenet-resnet50.pt", + model_config_path=model_dir / "pytorch_config.yaml", + pose_snapshot=model_dir / "snapshot-50.pt", + detector_snapshot=model_dir / "snapshot-detector-100.pt", + ) + + # Exporting a bottom-up model + model_dir = project_dir / "resnet-50" / "model" + export_dlc3_model( + export_path=model_dir / "dlclive-export-bu-resnet50.pt", + model_config_path=model_dir / "pytorch_config.yaml", + pose_snapshot=model_dir / "snapshot-50.pt", + detector_snapshot=None, + ) diff --git a/scripts/fix_deeplabcut_imports.py b/scripts/fix_deeplabcut_imports.py new file mode 100644 index 0000000..e73b4b0 --- /dev/null +++ b/scripts/fix_deeplabcut_imports.py @@ -0,0 +1,82 @@ +"""Script to update DeepLabCut imports when copying predictors""" +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class RecursiveImportFixer: + """Recursively fixes imports in python files""" + + import_prefix: str + new_import_prefix: str + dry_run: bool = False + + def fix_imports(self, target: Path) -> None: + if target.is_dir(): + self._walk_folder(target) + elif target.suffix == ".py": + self._fix_imports(target) + else: + raise ValueError(f"Oops! You can only fix `.py` files (not {target})") + + def _walk_folder(self, folder: Path) -> None: + if not folder.is_dir(): + raise ValueError(f"Oops! Something went wrong (not a folder): {folder}") + + for file in folder.iterdir(): + if file.suffix == ".py": + self._fix_imports(file) + elif file.is_dir(): + self._walk_folder(file) + + def _fix_imports(self, file: Path) -> None: + if not file.suffix == ".py": + raise ValueError(f"Oops! Something went wrong: {file}") + + print(f"Fixing file {file}") + with open(file, "r") as f: + file_content = f.readlines() + + fixed_lines = [] + for index, line in enumerate(file_content): + parsed = line + if self.import_prefix in line: + parsed = line.replace(self.import_prefix, self.new_import_prefix) + print(f" Found import on line {index}") + print(f" original: ```{line}```") + print(f" fixed: ```{parsed}```") + + fixed_lines.append(parsed) + + if not self.dry_run: + with open(file, "w") as f: + f.writelines(fixed_lines) + + +def main( + target: Path, + import_prefix: str, + new_import_prefix: str, + dry_run: bool, +) -> None: + print( + f"Replacing all imports of {import_prefix}.* in {target} with an import of " + f"{new_import_prefix}.*" + ) + fixer = RecursiveImportFixer(import_prefix, new_import_prefix, dry_run=dry_run) + fixer.fix_imports(target) + + +if __name__ == "__main__": + main( + target=Path("../dlclive/models").resolve(), + import_prefix="deeplabcut.pose_estimation_pytorch.models", + new_import_prefix="dlclive.models", + dry_run=True, + ) + main( + target=Path("../dlclive/models").resolve(), + import_prefix="deeplabcut.pose_estimation_pytorch.registry", + new_import_prefix="dlclive.models.registry", + dry_run=True, + )