-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,34 @@ | ||
# fruit-classification-pytorch | ||
Fruit classification using Kaggle Dataset [Fruit-360](https://www.kaggle.com/moltean/fruits) in pytorch | ||
# Fruit Classifier using Pytorch | ||
Fruit classification using Kaggle Dataset [Fruit-360](https://www.kaggle.com/moltean/fruits) in pytorch. | ||
This repository contains some code on : | ||
a) Creation of custom dataset using pytorch. Look at fruit.py to understand how the custom dataset can be prepared from a set of training and test images. | ||
b) Creation of a Network in pytorch which is simplier to create and try out any changes to it. | ||
c) Easy to train and test. | ||
|
||
## How to run the test | ||
Make sure to download the directory in a folder. | ||
# Training and Validation of Fruit-360 dataset. | ||
|
||
First load the datasets into npy files. | ||
## Step 1 | ||
The same concept applies to all different kinds of datasets. | ||
Firstly, load all the images that are downloaded from the above link and convert them into npy files. | ||
Advantage of using npy files is to use only 4 files named train_data.npy, train_labels.npy and validation_data.npy , validation_labels.npy | ||
rather than using thousands of files for pre-processing. | ||
|
||
To convert your training and validation dataset into npy files use the below script. | ||
|
||
``` | ||
python load_dataset.py --dataset-dir <Dataset Path> | ||
``` | ||
|
||
This creates train_data.npy, train_labels.npy, validation_data.npy, validation_labels.npy | ||
|
||
Using this, train and test the fruit dataset using train.py | ||
## Step 2 | ||
Use the dataset files that are created above, train the fruit classifier and evaluate the model. | ||
|
||
``` | ||
python train.py --data-dir <npy files folder> | ||
python train.py --data-dir <npy files folder> [--epochs <default:10>] | ||
``` | ||
|
||
This generates a log that trains the network for each epoch and finally do inference against the validation dataset spits out the validation accuracy. | ||
|
||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
lcskrishna
Author
Owner
|
||
|
||
|
I get this error :RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at /Users/distiller/project/conda/conda-bld/pytorch-nightly_1553836411291/work/aten/src/THNN/generic/ClassNLLCriterion.c:92