Skip to content

Commit 936b448

Browse files
Subject: Add R2Score metric. (#8169) (#8353)
Body: FEATURE Co-authored-by: Matthew Soulanille <[email protected]>
1 parent 0677375 commit 936b448

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

tfjs-layers/src/exports_metrics.ts

+19
Original file line numberDiff line numberDiff line change
@@ -314,3 +314,22 @@ export function MSE(yTrue: Tensor, yPred: Tensor): Tensor {
314314
export function mse(yTrue: Tensor, yPred: Tensor): Tensor {
315315
return losses.meanSquaredError(yTrue, yPred);
316316
}
317+
318+
/**
319+
* Computes R2 score.
320+
*
321+
* ```js
322+
* const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
323+
* const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
324+
* const r2Score = tf.metrics.r2Score(yTrue, yPred);
325+
* r2Score.print();
326+
* ```
327+
* @param yTrue Truth Tensor.
328+
* @param yPred Prediction Tensor.
329+
* @return R2 score Tensor.
330+
*
331+
* @doc {heading: 'Metrics', namespace: 'metrics'}
332+
*/
333+
export function r2Score(yTrue: Tensor, yPred: Tensor): Tensor {
334+
return metrics.r2Score(yTrue, yPred);
335+
}

tfjs-layers/src/metrics.ts

+9-3
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ import {Tensor, tidy} from '@tensorflow/tfjs-core';
1717

1818
import * as K from './backend/tfjs_backend';
1919
import {NotImplementedError, ValueError} from './errors';
20-
import {categoricalCrossentropy as categoricalCrossentropyLoss, cosineProximity, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredError, sparseCategoricalCrossentropy as sparseCategoricalCrossentropyLoss} from './losses';
21-
import {binaryCrossentropy as lossBinaryCrossentropy} from './losses';
22-
import {lossesMap} from './losses';
20+
import {binaryCrossentropy as lossBinaryCrossentropy, categoricalCrossentropy as categoricalCrossentropyLoss, cosineProximity, lossesMap, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredError, sparseCategoricalCrossentropy as sparseCategoricalCrossentropyLoss} from './losses';
2321
import {LossOrMetricFn} from './types';
2422
import * as util from './utils/generic_utils';
2523

@@ -112,6 +110,14 @@ export function sparseTopKCategoricalAccuracy(
112110
throw new NotImplementedError();
113111
}
114112

113+
export function r2Score(yTrue: Tensor, yPred: Tensor): Tensor {
114+
return tidy(() => {
115+
const sumSquaresResiduals = yTrue.sub(yPred).square().sum();
116+
const sumSquares = yTrue.sub(yTrue.mean()).square().sum();
117+
return tfc.scalar(1).sub(sumSquaresResiduals.div(sumSquares));
118+
});
119+
}
120+
115121
// Aliases.
116122
export const mse = meanSquaredError;
117123
export const MSE = meanSquaredError;

tfjs-layers/src/metrics_test.ts

+22-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import {scalar, Tensor, tensor, tensor1d, tensor2d} from '@tensorflow/tfjs-core'
1616

1717
import {setEpsilon} from './backend/common';
1818
import * as tfl from './index';
19-
import {binaryAccuracy, categoricalAccuracy, get, getLossOrMetricName} from './metrics';
19+
import {binaryAccuracy, categoricalAccuracy, get, getLossOrMetricName, r2Score} from './metrics';
2020
import {LossOrMetricFn} from './types';
2121
import {describeMathCPUAndGPU, describeMathCPUAndWebGL2, expectTensorsClose} from './utils/test_utils';
2222

@@ -283,6 +283,27 @@ describeMathCPUAndGPU('recall metric', () => {
283283
});
284284
});
285285

286+
describeMathCPUAndGPU('r2Score', () => {
287+
it('1D', () => {
288+
const yTrue = tensor1d([3, -0.5, 2, 7, 4.2, 8.5, 1.3, 2.8, 6.7, 9.0]);
289+
const yPred = tensor1d([2.5, 0.0, 2.1, 7.8, 4.0, 8.2, 1.4, 2.9, 6.5, 9.1]);
290+
const score = r2Score(yTrue, yPred);
291+
expectTensorsClose(score, scalar(0.985));
292+
});
293+
it('2D', () => {
294+
const yTrue = tensor2d([
295+
[3, 2.5], [-0.5, 3.2], [2, 1.9], [7, 5.1], [4.2, 3.8], [8.5, 7.4],
296+
[1.3, 0.6], [2.8, 2.1], [6.7, 5.3], [9.0, 8.7]
297+
]);
298+
const yPred = tensor2d([
299+
[2.7, 2.3], [0.0, 3.1], [2.1, 1.8], [6.8, 5.0], [4.1, 3.7], [8.4, 7.2],
300+
[1.4, 0.7], [2.9, 2.2], [6.6, 5.2], [9.2, 8.9]
301+
]);
302+
const score = r2Score(yTrue, yPred);
303+
expectTensorsClose(score, scalar(0.995));
304+
});
305+
});
306+
286307
describe('metrics.get', () => {
287308
it('valid name, not alias', () => {
288309
expect(get('binaryAccuracy') === get('categoricalAccuracy')).toEqual(false);

0 commit comments

Comments
 (0)