@@ -0,0 +1,40 @@
<?php

declare(strict_types=1);

namespace Rubix\ML\NeuralNet\CostFunctions\LeastSquares\Base\Contracts;

use NDArray;
use Stringable;

/**
* Cost Function
*
* @category Machine Learning
* @package Rubix/ML
* @author Samuel Akopyan <[email protected]>
*/
interface CostFunction extends Stringable
{
/**
* Compute the loss score.
*
* @internal
*
* @param NDArray $output
* @param NDArray $target
* @return float
*/
public function compute(NDArray $output, NDArray $target) : float;

/**
* Calculate the gradient of the cost function with respect to the output.
*
* @internal
*
* @param NDArray $output
* @param NDArray $target
* @return NDArray
*/
public function differentiate(NDArray $output, NDArray $target) : NDArray;
}
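For context on how this contract is consumed: compute() yields a scalar for monitoring while differentiate() yields the signal fed into backpropagation. Below is a minimal sketch of the two methods working together; the descendOnce helper and $learningRate are illustrative assumptions (not part of this PR), and NumPower::multiply is assumed to accept a scalar factor.

```php
<?php

use NDArray;
use NumPower;
use Rubix\ML\NeuralNet\CostFunctions\LeastSquares\Base\Contracts\CostFunction;

// Hypothetical helper: one naive gradient step applied directly to a raw
// output matrix, purely to show how the two methods cooperate. In a real
// network the gradient would be propagated back to the layer weights instead.
function descendOnce(CostFunction $costFn, NDArray $output, NDArray $target, float $learningRate) : NDArray
{
    // Scalar loss, useful for logging and early stopping.
    printf("loss: %.6f\n", $costFn->compute($output, $target));

    // ∂L/∂output, the signal that backpropagation distributes through the network.
    $gradient = $costFn->differentiate($output, $target);

    // Move the output against the gradient.
    return NumPower::subtract($output, NumPower::multiply($gradient, $learningRate));
}
```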
@@ -0,0 +1,12 @@
<?php

declare(strict_types=1);

namespace Rubix\ML\NeuralNet\CostFunctions\LeastSquares\Base\Contracts;

interface RegressionLoss extends CostFunction
{
//
}
83 changes: 83 additions & 0 deletions src/NeuralNet/CostFunctions/LeastSquares/LeastSquares.php
@@ -0,0 +1,83 @@
<?php

declare(strict_types=1);

namespace Rubix\ML\NeuralNet\CostFunctions\LeastSquares;

use InvalidArgumentException;
use NumPower;
use NDArray;
use Rubix\ML\NeuralNet\CostFunctions\LeastSquares\Base\Contracts\RegressionLoss;

/**
* Least Squares
*
* Least Squares, or *quadratic* loss, measures the mean squared error between
* the target output and the actual output of a network.
*
* @category Machine Learning
* @package Rubix/ML
* @author Samuel Akopyan <[email protected]>
*/
class LeastSquares implements RegressionLoss
{
/**
* Compute the loss score.
*
* L(y, ŷ) = Σ(y - ŷ)^2 / n, where ŷ is the network output and y is the target
*
* @internal
*
* @param NDArray $output The output of the network
* @param NDArray $target The target values
* @return float
*/
public function compute(NDArray $output, NDArray $target) : float
{
if ($output->shape() !== $target->shape()) {
throw new InvalidArgumentException('Output and target must have the same shape.');
}

// Compute difference: output - target
$diff = NumPower::subtract($output, $target);

// Square the difference: diff^2
$squared = NumPower::pow($diff, 2);

// Compute mean of all elements
return NumPower::mean($squared);
}

/**
* Calculate the gradient of the cost function with respect to the output.
*
* ∂L/∂ŷ = ŷ - y (the constant factor 2/n is omitted)
*
* @internal
*
* @param NDArray $output The output of the network
* @param NDArray $target The target values
* @return NDArray
*/
public function differentiate(NDArray $output, NDArray $target) : NDArray
{
if ($output->shape() !== $target->shape()) {
throw new InvalidArgumentException('Output and target must have the same shape.');
}

// Gradient is simply: output - target
return NumPower::subtract($output, $target);
}

/**
* Return the string representation of the object.
*
* @internal
*
* @return string
*/
public function __toString() : string
{
return 'Least Squares';
}
}
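A quick usage sketch of the class above (the values mirror one of the test cases below; it assumes the NumPower extension is loaded and the class is autoloaded via Composer):

```php
<?php

use NumPower;
use Rubix\ML\NeuralNet\CostFunctions\LeastSquares\LeastSquares;

$costFn = new LeastSquares();

// Single-sample batch: the network predicted 1000.0, the ground truth is 1.0.
$output = NumPower::array([[1000.0]]);
$target = NumPower::array([[1.0]]);

$loss = $costFn->compute($output, $target);           // (1000 - 1)^2 / 1 = 998001.0
$gradient = $costFn->differentiate($output, $target); // [[1000 - 1]] = [[999.0]]
```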
206 changes: 206 additions & 0 deletions tests/NeuralNet/CostFunctions/LeastSquares/LeastSquaresTest.php
@@ -0,0 +1,206 @@
<?php

declare(strict_types=1);

namespace Rubix\ML\Tests\NeuralNet\CostFunctions\LeastSquares;

use InvalidArgumentException;
use PHPUnit\Framework\Attributes\CoversClass;
use PHPUnit\Framework\Attributes\DataProvider;
use PHPUnit\Framework\Attributes\Group;
use PHPUnit\Framework\Attributes\Test;
use PHPUnit\Framework\Attributes\TestDox;
use NumPower;
use NDArray;
use Rubix\ML\NeuralNet\CostFunctions\LeastSquares\LeastSquares;
use PHPUnit\Framework\TestCase;
use Generator;

#[Group('CostFunctions')]
#[CoversClass(LeastSquares::class)]
class LeastSquaresTest extends TestCase
{
/**
* @var LeastSquares
*/
protected LeastSquares $costFn;

/**
* @return Generator<array>
*/
public static function computeProvider() : Generator
{
yield [
NumPower::array([]),
NumPower::array([]),
NAN,
];

yield [
NumPower::array([
[0.99],
]),
NumPower::array([
[1.0],
]),
0.0001000,
];

yield [
NumPower::array([
[1000.0],
]),
NumPower::array([
[1.0],
]),
998001.0,
];

yield [
NumPower::array([
[33.98],
[20.0],
[4.6],
[44.2],
[38.5],
]),
NumPower::array([
[36.0],
[22.0],
[18.0],
[41.5],
[38.0],
]),
39.0360794,
];
}

/**
* @return Generator<array>
*/
public static function differentiateProvider() : Generator
{
yield [
NumPower::array([
[0.99],
]),
NumPower::array([
[1.0],
]),
[
[-0.0099999],
],
];

yield [
NumPower::array([
[1000.0],
]),
NumPower::array([
[1.0],
]),
[
[999.0],
],
];

yield [
NumPower::array([
[33.98],
[20.0],
[4.6],
[44.2],
[38.5],
]),
NumPower::array([
[36.0],
[22.0],
[18.0],
[41.5],
[38.0],
]),
[
[-2.0200004],
[-2.0],
[-13.3999996],
[2.7000007],
[0.5],
],
];
}

protected function setUp() : void
{
$this->costFn = new LeastSquares();
}

#[Test]
#[TestDox('Can be cast to a string')]
public function testToString() : void
{
static::assertEquals('Least Squares', (string) $this->costFn);
}

#[Test]
#[TestDox('Throws exception when output and target shapes do not match in compute')]
public function testComputeThrowsExceptionOnShapeMismatch() : void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Output and target must have the same shape.');

$output = NumPower::array([[1.0, 2.0, 3.0]]);
$target = NumPower::array([[1.0, 2.0]]);

$this->costFn->compute($output, $target);
}

#[Test]
#[TestDox('Throws exception when output and target shapes do not match in differentiate')]
public function testDifferentiateThrowsExceptionOnShapeMismatch() : void
{
$this->expectException(InvalidArgumentException::class);
$this->expectExceptionMessage('Output and target must have the same shape.');

$output = NumPower::array([[1.0, 2.0, 3.0]]);
$target = NumPower::array([[1.0, 2.0]]);

$this->costFn->differentiate($output, $target);
}

/**
* @param NDArray $output
* @param NDArray $target
* @param float $expected
*/
#[Test]
#[TestDox('Compute loss score')]
#[DataProvider('computeProvider')]
public function testCompute(NDArray $output, NDArray $target, float $expected) : void
{
$loss = $this->costFn->compute($output, $target);

if (is_nan($expected)) {
self::assertNan($loss);
} else {
self::assertEqualsWithDelta($expected, $loss, 1e-7);
}
}

/**
* @param NDArray $output
* @param NDArray $target
* @param list<list<float>> $expected
*/
#[Test]
#[TestDox('Calculate gradient of cost function')]
#[DataProvider('differentiateProvider')]
public function testDifferentiate(NDArray $output, NDArray $target, array $expected) : void
{
$gradient = $this->costFn->differentiate($output, $target);

// Convert NDArray to PHP array for comparison
$gradientArray = $gradient->toArray();

self::assertEqualsWithDelta($expected, $gradientArray, 1e-7);
}
}
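The suite can be run on its own via PHPUnit's group filter, e.g. `vendor/bin/phpunit --group CostFunctions` (assuming the usual Composer binary path); the `#[Group('CostFunctions')]` attribute above is what that filter matches.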