Skip to content

Commit b90a4ff

Browse files
committed
Transform labels now returns self for method chaining
1 parent d4e18f2 commit b90a4ff

File tree

5 files changed

+15
-17
lines changed

5 files changed

+15
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- Removed Robust Z Score tolerance parameter
1111
- Added slice method to Dataset API
1212
- Loda now performs density estimation on the fly
13+
- Transform labels now returns self for method chaining
1314

1415
- 0.0.12-beta
1516
- Added AdaMax neural network Optimizer

docs/datasets/labeled.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ array(2) {
106106
```
107107

108108
#### Transform
109-
Transform the labels in the dataset using a callback function:
109+
Transform the labels in the dataset using a callback function and return self for method chaining:
110110
```php
111-
public transformLabels(callable $fn) : void
111+
public transformLabels(callable $fn) : self
112112
```
113113

114114
> **Note:** The callback function is given a label as its only argument and should return the transformed label as a continuous or categorical value.
@@ -158,7 +158,7 @@ $filtered = $dataset->filterByLabel(function ($label)) {
158158
```
159159

160160
#### Sorting
161-
Sort the dataset by label:
161+
Sort the dataset by label and return self for method chaining:
162162
```php
163163
public sortByLabel(bool $descending = false) : self
164164
```

src/AnomalyDetectors/KDLOF.php

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,8 @@ public function predict(Dataset $dataset) : array
216216
*/
217217
public function rank(Dataset $dataset) : array
218218
{
219-
if ($this->tree->bare()) {
220-
throw new RuntimeException('The estimator has not'
221-
. ' been trained.');
219+
if ($this->tree->bare() or empty($this->lrds)) {
220+
throw new RuntimeException('The estimator has not been trained.');
222221
}
223222

224223
DatasetIsCompatibleWithEstimator::check($dataset, $this);
@@ -235,11 +234,6 @@ public function rank(Dataset $dataset) : array
235234
*/
236235
protected function localOutlierFactor(array $sample) : float
237236
{
238-
if (empty($this->lrds)) {
239-
throw new RuntimeException('Local reachability distances have'
240-
. ' not been computed, must train estimator first.');
241-
}
242-
243237
[$indices, $distances] = $this->tree->nearest($sample, $this->k);
244238

245239
$lrd = $this->localReachabilityDensity($indices, $distances);

src/Datasets/Labeled.php

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,13 @@ public function labelType() : ?int
192192
}
193193

194194
/**
195-
* Map labels to their new values.
195+
* Map labels to their new valuesa dn return self for method chaining.
196196
*
197197
* @param callable $callback
198198
* @throws \RuntimeException
199+
* @return self
199200
*/
200-
public function transformLabels(callable $callback) : void
201+
public function transformLabels(callable $callback) : self
201202
{
202203
$labels = array_map($callback, $this->labels);
203204

@@ -209,6 +210,8 @@ public function transformLabels(callable $callback) : void
209210
}
210211

211212
$this->labels = $labels;
213+
214+
return $this;
212215
}
213216

214217
/**
@@ -445,7 +448,7 @@ public function filterByLabel(callable $callback) : self
445448
* @param bool $descending
446449
* @return self
447450
*/
448-
public function sortByColumn(int $index, bool $descending = false)
451+
public function sortByColumn(int $index, bool $descending = false) : self
449452
{
450453
$order = $this->column($index);
451454

@@ -463,9 +466,9 @@ public function sortByColumn(int $index, bool $descending = false)
463466
* Sort the dataset in place by its labels.
464467
*
465468
* @param bool $descending
466-
* @return \Rubix\ML\Datasets\Dataset
469+
* @return self
467470
*/
468-
public function sortByLabel(bool $descending = false) : Dataset
471+
public function sortByLabel(bool $descending = false) : self
469472
{
470473
array_multisort(
471474
$this->labels,

src/Datasets/Unlabeled.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ public function filterByColumn(int $index, callable $callback) : self
257257
* @param bool $descending
258258
* @return self
259259
*/
260-
public function sortByColumn(int $index, bool $descending = false)
260+
public function sortByColumn(int $index, bool $descending = false) : self
261261
{
262262
$column = $this->column($index);
263263

0 commit comments

Comments
 (0)