Skip to content

Commit

Permalink
Add generic test for classify method of concrete openml providers
Browse files Browse the repository at this point in the history
The method classify was returning opposite value in the feedzai/openml-java project.
In order to improve UTs, the method AbstractProviderModelBaseTest#classifyIndexOfMaxScoresValue() was added in order to test the method #classify for all the providers.
  • Loading branch information
shengwangsw committed Oct 9, 2019
1 parent af0681f commit 02febe6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ public interface ClassificationMLModel extends MachineLearningModel {
* @param instance The {@link Instance} to be classified.
* @return The index of the class nominal value according to the {@link DatasetSchema}
* provided during training of the model.
*
* @deprecated The idea is to classify the biggest value of the class probabilities distribution obtained from #getClassDistribution(),
* We no longer need this because we can just obtain the biggest value from the class probabilities distribution itself.
*
*/
@Deprecated
int classify(Instance instance);

}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,28 @@ public abstract class AbstractProviderModelBaseTest<M extends ClassificationMLMo
* TESTS *
* * * * * * * * * * * * * * * * * * * * * * * * * * */


/**
* Verifies that the {@link ClassificationMLModel#classify(Instance)} " returns the index of the greatest value in
* the class probability distribution produced by the calling
* {@link ClassificationMLModel#getClassDistribution(Instance)} on the model
*
* @see ClassificationMLModel
*/
protected void canGetClassDistributionMaxValueIndex(final M model, final Instance instance){

final double[] scores = model.getClassDistribution(instance);

final int classificationIndex = model.classify(instance);

final double maxScore = Arrays.stream(scores).max().getAsDouble();

assertThat(Arrays.asList(ArrayUtils.toObject(scores)).indexOf(maxScore))
.as("The index of maximum value")
.isEqualTo(classificationIndex);
}


/**
* Checks that is possible to get a {@link MachineLearningProvider} given a valid provider and algorithm.
*/
Expand Down

0 comments on commit 02febe6

Please sign in to comment.