
Commit f1cf34c

Merge branch 'main' into feat-torchmetrics-eval
2 parents fdb8fed + 6bf5ecb


46 files changed, +2869 -1649 lines changed

CONTRIBUTING.md

Lines changed: 6 additions & 6 deletions
Most of these hunks only strip trailing whitespace; the substantive change raises the minimum supported Python from 3.5 to 3.9.

````diff
@@ -1,6 +1,6 @@
 # Developer's Guide
 
-Depends on Python 3.5+
+Depends on Python 3.9+
 
 ## Getting started
 
@@ -9,7 +9,7 @@ Depends on Python 3.5+
 2. Clone your copy of the repository.
 
    - **Using ssh**:
-
+
    ```bash
   git clone git@github.com:[your user name]/DeepForest.git
   ```
@@ -83,7 +83,7 @@ $ pytest -v
 
 We use [yapf](https://github.com/google/yapf) for code formatting and style checking.
 
-The easiest way to make sure your code is formatted correctly is to integrate it into your editor.
+The easiest way to make sure your code is formatted correctly is to integrate it into your editor.
 See [EDITOR SUPPORT](https://github.com/google/yapf/blob/main/EDITOR%20SUPPORT.md).
 
 You can also run yapf from the command line to cleanup the style in your changes:
@@ -111,7 +111,7 @@ $ conda build conda_recipe/meta.yaml -c conda-forge -c defaults
 
 Update the Conda recipe after every release.
 
-Clone the [Weecology staged recipes](https://github.com/weecology/staged-recipes).
+Clone the [Weecology staged recipes](https://github.com/weecology/staged-recipes).
 Checkout the deepforest branch, update the `deepforest/meta.yaml` with the new version and the sha256 values. Sha256 values are obtained from the source on [PYPI download files](https://pypi.org/project/deepforest/#files) using the deepforest-{version-number}.tar.gz.
 
 ```jinja
@@ -131,7 +131,7 @@ $ docformatter --in-place --recursive src/deepforest/
 
 ### Update Documentation
 
-The documentation is automatically updated for changes in functions.
+The documentation is automatically updated for changes in functions.
 However, the documentation should be updated after the addition of new functions or modules.
 
 Change to the docs directory and use `sphinx-apidoc` to update the doc's `source`. Exclude the tests and setup.py documentation.
@@ -198,4 +198,4 @@ model.push_to_hub("weecology/deepforest-livestock")
 
 The model will be uploaded to [https://huggingface.co/weecology/[model-name]](https://huggingface.co/weecology/[model-name])
 
-Note: You must have appropriate permissions in the weecology organization to upload models.
+Note: You must have appropriate permissions in the weecology organization to upload models.
````
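The recipe-update instructions above reference sha256 values for the PyPI tarball. PyPI's download page shows these hashes directly; if you would rather verify locally, a minimal sketch (the version in the filename is a placeholder):

```python
# Compute the sha256 of a downloaded sdist, e.g. deepforest-<version>.tar.gz
import hashlib

def sha256sum(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 16), b""):
            h.update(chunk)
    return h.hexdigest()

print(sha256sum("deepforest-1.5.0.tar.gz"))  # hypothetical filename
```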

dev_requirements.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-albumentations>=1.0.0,<2.0.0
+albumentations>=2.0.0
 aiolimiter
 aiohttp
 bump-my-version
```
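Loosening the pin to `albumentations>=2.0.0` is a major-version jump. A quick smoke test that the 2.x API still composes bounding-box transforms (a sketch; the `pascal_voc` box format is an assumption about how DeepForest feeds boxes, not something this diff states):

```python
import albumentations as A

# Minimal 2.x-style pipeline with bounding-box support
transform = A.Compose(
    [A.HorizontalFlip(p=0.5)],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)
```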

docs/source/deepforest.rst

Lines changed: 1 addition & 8 deletions
```diff
@@ -8,6 +8,7 @@ Subpackages
    :maxdepth: 4
 
    deepforest.data
+   deepforest.datasets
 
 Submodules
 ----------
@@ -28,14 +29,6 @@ deepforest.callbacks module
    :undoc-members:
    :show-inheritance:
 
-deepforest.dataset module
--------------------------
-
-.. automodule:: deepforest.dataset
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
 deepforest.evaluate module
 --------------------------
```

docs/user_guide/03_cropmodels.md

Lines changed: 222 additions & 1 deletion
````diff
@@ -14,7 +14,7 @@ Why would you want to apply a model directly to each crop? Why not train a multi
 
 While that approach is certainly valid, there are a few key benefits to using CropModels, especially in common use cases:
 
-- **Flexible Labeling**: Object detection models require that all objects of a particular class be annotated within an image, which can be impossible for detailed category labels. For example, you might have bounding boxes for all trees in an image, but only have species or health labels for a small portion of them based on ground surveys. Training a multi-class object detection model would mean training on only a portion of your available data.
+- **Flexible Labeling**: Object detection models require that all objects of a particular class be annotated within an image, which can be impossible for detailed category labels. For example, you might have bounding boxes for all 'trees' in an image, but only have species or health labels for a small portion of them based on ground surveys. Training a multi-class object detection model would mean training on only a portion of your available data.
 - **Simpler and Extendable**: CropModels decouple detection and classification workflows, allowing separate handling of challenges like class imbalance and incomplete labels, without reducing the quality of the detections. Two-stage object detection models can be finicky with similar classes and often require expertise in managing learning rates.
 - **New Data and Multi-sensor Learning**: In many applications, the data needed for detection and classification may differ. The CropModel concept provides an extendable piece that allows for advanced pipelines.
 
@@ -177,4 +177,225 @@ class CustomCropModel(CropModel):
 
 # Create an instance of the custom CropModel
 model = CustomCropModel()
+```
+
+## Making Predictions Outside of predict_tile
+
+While `predict_tile` provides a convenient way to run predictions on detected objects, you can also use the CropModel directly for classification tasks. This is useful when you have pre-cropped images or want to run classification independently.
+
+### Loading a Trained Model
+
+```python
+from deepforest.model import CropModel
+from pytorch_lightning import Trainer
+from torchvision.datasets import ImageFolder
+import numpy as np
+
+# Load a trained model from checkpoint
+cropmodel = CropModel.load_from_checkpoint("path/to/checkpoint.ckpt")
+
+# The model will automatically load:
+# - The model architecture and weights
+# - The label dictionary mapping class names to indices
+# - The number of classes
+# - Any hyperparameters saved during training
+```
+
+### Making Predictions on a Dataset
+
+```python
+# Create a validation dataset; root_dir holds one subfolder of crops per class
+val_ds = ImageFolder(root=root_dir, transform=cropmodel.get_transform(augment=False))
+
+# Create dataloader
+crop_dataloader = cropmodel.predict_dataloader(val_ds)
+
+# Run prediction
+trainer = Trainer(
+    accelerator="gpu",
+    devices=1,
+    max_epochs=1,
+    logger=False,
+    enable_checkpointing=False
+)
+crop_results = trainer.predict(cropmodel, crop_dataloader)
+
+# Process results using the built-in postprocessing method
+label, score = cropmodel.postprocess_predictions(crop_results)
+
+# Convert numeric labels to class names
+label_names = [cropmodel.numeric_to_label_dict[x] for x in label]
+```
+
+### Making Predictions on Single Images
+
+You can also make predictions on individual images or batches:
+
+```python
+import torch
+from PIL import Image
+
+# Load and preprocess a single image
+image = Image.open("path/to/image.jpg")
+transform = cropmodel.get_transform(augment=False)
+tensor = transform(image).unsqueeze(0)  # Add batch dimension
+
+# Make prediction
+with torch.no_grad():
+    output = cropmodel(tensor)
+    # Convert to numpy for postprocessing
+    output = output.cpu().numpy()
+
+# Use the same postprocessing method
+label, score = cropmodel.postprocess_predictions([output])
+class_name = cropmodel.numeric_to_label_dict[label[0]]
+confidence = score[0]
+```
+
+## Model Architecture and Training
+
+The CropModel uses a ResNet-50 backbone by default, but can be customized with any PyTorch model. The model includes:
+
+- A classification head with the specified number of classes
+- Standard image preprocessing (resize to 224x224, normalization)
+- Data augmentation during training (random horizontal flips)
+- Accuracy and precision metrics for evaluation
+
````
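The preprocessing and augmentation just listed correspond roughly to this torchvision pipeline (a sketch; the ImageNet normalization constants are an assumption, so treat `CropModel.get_transform` in the source as authoritative):

```python
from torchvision import transforms

# Approximation of the CropModel preprocessing described in the diff above
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),           # standard 224x224 crop size
    transforms.RandomHorizontalFlip(p=0.5),  # training-time augmentation only
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # assumed ImageNet stats
                         std=[0.229, 0.224, 0.225]),
])
```

The added section continues: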
````diff
+### Training Process
+
+```python
+# Initialize model
+crop_model = CropModel(num_classes=2)
+
+# Create trainer
+crop_model.create_trainer(
+    max_epochs=10,
+    accelerator="gpu",
+    devices=1
+)
+
+# Load data
+crop_model.load_from_disk(
+    train_dir="path/to/train",
+    val_dir="path/to/val"
+)
+
+# Train
+crop_model.trainer.fit(crop_model)
+
+# Validate
+crop_model.trainer.validate(crop_model)
+
+# Save checkpoint
+crop_model.trainer.save_checkpoint("model.ckpt")
+```
+
+### Evaluation
+
+The model provides several evaluation metrics:
+
+```python
+# Get validation metrics
+metrics = crop_model.trainer.validate(crop_model)
+
+# Get confusion matrix
+images, labels, predictions = crop_model.val_dataset_confusion(return_images=True)
+```
+
+### Confusion Matrix Visualization
+
+You can visualize the confusion matrix in several ways:
+
+```python
+import matplotlib.pyplot as plt
+from torchmetrics.classification import MulticlassConfusionMatrix
+import seaborn as sns
+
+# Method 1: Using torchmetrics (expects tensors)
+metric = MulticlassConfusionMatrix(num_classes=crop_model.num_classes)
+metric.update(preds=torch.tensor(predictions), target=torch.tensor(labels))
+fig, ax = metric.plot()
+plt.title("Confusion Matrix")
+plt.show()
+
+# Method 2: Using seaborn with val_dataset_confusion
+images, labels, predictions = crop_model.val_dataset_confusion(return_images=True)
+confusion_matrix = np.zeros((crop_model.num_classes, crop_model.num_classes))
+for true, pred in zip(labels, predictions):
+    confusion_matrix[true][pred] += 1
+
+# Plot with seaborn
+plt.figure(figsize=(10, 8))
+sns.heatmap(confusion_matrix,
+            annot=True,
+            fmt='g',
+            xticklabels=list(crop_model.label_dict.keys()),
+            yticklabels=list(crop_model.label_dict.keys()))
+plt.title("Confusion Matrix")
+plt.xlabel("Predicted")
+plt.ylabel("True")
+plt.show()
+
+# Get per-class metrics
+from torchmetrics.classification import MulticlassPrecision, MulticlassRecall, MulticlassF1Score
+
+precision = MulticlassPrecision(num_classes=crop_model.num_classes)
+recall = MulticlassRecall(num_classes=crop_model.num_classes)
+f1 = MulticlassF1Score(num_classes=crop_model.num_classes)
+
+precision_score = precision(torch.tensor(predictions), torch.tensor(labels))
+recall_score = recall(torch.tensor(predictions), torch.tensor(labels))
+f1_score = f1(torch.tensor(predictions), torch.tensor(labels))
+
+print(f"Precision: {precision_score:.3f}")
+print(f"Recall: {recall_score:.3f}")
+print(f"F1 Score: {f1_score:.3f}")
+```
+
+This will give you a comprehensive view of your model's performance, including:
+- A visual confusion matrix showing true vs. predicted classes
+- Per-class precision, recall, and F1 scores
+- The ability to identify which classes are most commonly confused with each other
+
+The confusion matrix is particularly useful for:
+- Identifying class imbalance issues
+- Finding classes that are frequently confused
+- Understanding the model's strengths and weaknesses
+- Guiding decisions about data collection and model improvement
+
+## Advanced Usage
+
+### Custom Model Architecture
+
+You can use any PyTorch model as the backbone:
+
+```python
+from torchvision.models import resnet101
+
+# Initialize with custom model
+backbone = resnet101(weights='DEFAULT')
+crop_model = CropModel(
+    num_classes=2,
+    model=backbone
+)
+```
+
````
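One caveat on the snippet just above (an assumption about torchvision defaults, not something the diff states): `resnet101(weights='DEFAULT')` ships with a 1000-class ImageNet head, so with `num_classes=2` the final layer likely needs replacing before the model is handed to `CropModel`:

```python
import torch
from torchvision.models import resnet101

backbone = resnet101(weights="DEFAULT")
# Swap the 1000-class ImageNet head for a 2-class head
backbone.fc = torch.nn.Linear(backbone.fc.in_features, 2)
```

The added section closes with a custom training loop: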
````diff
+### Custom Training Loop
+
+You can subclass CropModel to customize the training process:
+
+```python
+import torch.nn.functional as F
+
+class CustomCropModel(CropModel):
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        outputs = self.forward(x)
+        loss = F.cross_entropy(outputs, y)
+
+        # Add custom metrics; `value` is a placeholder for your own computation
+        self.log("custom_metric", value)
+
+        return loss
 ```
````
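For the in-`predict_tile` counterpart that this new section contrasts itself with, attaching the classifier at detection time looks roughly like this (a sketch; the `crop_model` keyword is an assumption based on the guide's framing, so check the current `predict_tile` signature):

```python
from deepforest import main
from deepforest.model import CropModel

# Detect trees, then classify each detected crop in one call
m = main.deepforest()
m.load_model("weecology/deepforest-tree")
crop_model = CropModel.load_from_checkpoint("path/to/checkpoint.ckpt")

boxes = m.predict_tile(path="tile.tif", patch_size=300, crop_model=crop_model)
```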

docs/user_guide/07_scaling.md

Lines changed: 13 additions & 53 deletions
````diff
@@ -34,68 +34,28 @@ https://lightning.ai/docs/pytorch/latest/clouds/cluster_advanced.html#troublesho
 
 ## Prediction
 
-Often we have a large number of tiles we want to predict. DeepForest uses [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) to scale inference. This gives us access to powerful tools for scaling without any changes to user code. DeepForest automatically detects whether you are running on GPU or CPU. The parallelization strategy is to run each tile on a separate GPU, we cannot parallelize crops from within the same tile across GPUs inside of main.predict_tile(). If you set m.create_trainer(accelerator="gpu", devices=4), and run predict_tile, you will only use 1 GPU per tile. This is because we need access to all crops to create a mosiac of the predictions.
+Often we have a large number of tiles we want to predict. DeepForest uses [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) to scale inference. This gives us access to powerful tools for scaling without any changes to user code. DeepForest automatically detects whether you are running on GPU or CPU.
 
-### Scaling prediction across multiple GPUs
+There are three dataset strategies that *balance CPU memory, GPU memory, and GPU utilization* using batch sizes.
 
-There are a few situations in which it is useful to replicate the DeepForest module across many separate Python processes. This is especially helpful when we have a series of non-interacting tasks, often called 'embarrassingly parallel' processes. In these cases, no DeepForest instance needs to communicate with another instance. Rather than coordinating GPUs with the associated annoyance of overhead and backend errors, we can just launch separate jobs and let them finish on their own. One helpful tool in Python is [Dask](https://www.dask.org/). Dask is a wonderful open-source tool for coordinating large-scale jobs. Dask can be run locally, across multiple machines, and with an arbitrary set of resources.
+```python
+prediction_single = m.predict_tile(path=path, patch_size=300, dataloader_strategy="single")
+```
+The `dataloader_strategy` parameter has three options:
 
-### Example Dask and DeepForest integration using SLURM
+* **single**: Loads the entire image into CPU memory and passes individual windows to the GPU.
 
-Imagine we have a list of images we want to predict using `deepforest.main.predict_tile()`. DeepForest does not allow multi-GPU inference within each tile, as it is too much of a headache to make sure the threads return the correct overlapping window. Instead, we can parallelize across tiles, such that each GPU takes a tile and performs an action. The general structure is to create a Dask client across multiple GPUs, submit each DeepForest `predict_tile()` instance, and monitor the results. In this example, we are using a SLURMCluster, a common job scheduler for large clusters. There are many similar ways to create a Dask client object that will be specific to a particular organization. The following arguments are specific to the University of Florida cluster, but will be largely similar to other SLURM naming conventions. We use the extra Dask package, `dask-jobqueue`, which helps format the call.
+* **batch**: Loads the entire image into GPU memory and creates views of the image as batches. Requires the entire tile to fit into GPU memory. CPU parallelization is possible for loading images.
 
+* **window**: Loads only the desired window of the image from the raster dataset. This is the most memory-efficient option, but it cannot parallelize across windows due to Python's Global Interpreter Lock, so workers must be set to 0.
 
````
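The two strategies without an example above can be called the same way (a sketch with illustrative values; `m` is a loaded `deepforest.main.deepforest()` model as elsewhere in this guide):

```python
# "batch": the whole tile sits in GPU memory; CPU workers can parallelize loading
prediction_batch = m.predict_tile(path=path, patch_size=300, dataloader_strategy="batch")

# "window": reads one window at a time; requires single-process loading
m.config["workers"] = 0
prediction_window = m.predict_tile(path=path, patch_size=300, dataloader_strategy="window")
```

The rest of the hunk swaps the old Dask/SLURM walkthrough for a short note on data loading: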

````diff
-```python
-from dask_jobqueue import SLURMCluster
-from dask.distributed import Client
-
-cluster = SLURMCluster(processes=1,
-                       cores=10,
-                       memory="40 GB",
-                       walltime='24:00:00',
-                       job_extra=extra_args,
-                       extra=['--resources gpu=1'],
-                       nanny=False,
-                       scheduler_options={"dashboard_address": ":8787"},
-                       local_directory="/orange/idtrees-collab/tmp/",
-                       death_timeout=100)
-print(cluster.job_script())
-cluster.scale(10)
-
-dask_client = Client(cluster)
-```
+## Data Loading
 
-This job script gets a single GPUs with "40GB" of memory with 10 cpus. We then ask for 10 instances of this setup.
-Now that we have a dask client, we can send our custom function.
+DeepForest uses PyTorch's DataLoader for efficient data loading. One important parameter for scaling is the number of CPU workers, which controls parallel data loading using multiple CPU processes. This can be set in the config:
 
-```python
-import os
-from deepforest import main
-
-def function_to_parallelize(tile):
-    m = main.deepforest()
-    m.load_model("weecology/deepforest-tree") # sub in the custom logic to load your own models
-    boxes = m.predict_tile(raster_path=tile)
-    # save the predictions using the tile pathname
-    filename = "{}.csv".format(os.path.splitext(os.path.basename(tile))[0])
-    filename = os.path.join(<savedir>,filename)
-    boxes.to_csv(filename)
-
-    return filename
 ```
-
-```python
-tiles = [<list of tiles to predict>]
-futures = []
-for tile in tiles:
-    future = client.submit(function_to_parallelize, tile)
-    futures.append(future)
+m.config["workers"] = 10
 ```
+Setting workers to 0 runs without multiprocessing; workers > 0 runs with multiprocessing. Increase this value slowly, as IO constraints can lead to deadlocks among workers.
 
-We can wait to see the futures as they complete! Dask also has a beautiful visualization tool using bokeh.
 
-```python
-for x in futures:
-    completed_filename = x.result()
-    print(completed_filename)
-```
````