Skip to content

Commit aa14065

Browse files
Fix #7104 - tf.initializers.<random | glorot | he | leCunn>Uniform() ignores seed argument & add tests that replicated the issue, fix wrong serialization name registered for LeCunUniform initializer class (#7108)
Fix #7104 - tf.initializers.<random | glorot | he | leCunn>Uniform() ignores seed argument & add tests that replicated the issue, fix wrong serialization name registered for LeCunUniform initializer class
1 parent bea721d commit aa14065

File tree

2 files changed

+191
-3
lines changed

2 files changed

+191
-3
lines changed

tfjs-layers/src/initializers.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ export class RandomUniform extends Initializer {
128128
}
129129

130130
apply(shape: Shape, dtype?: DataType): Tensor {
131-
return randomUniform(shape, this.minval, this.maxval, dtype);
131+
return randomUniform(shape, this.minval, this.maxval, dtype, this.seed);
132132
}
133133

134134
override getConfig(): serialization.ConfigDict {
@@ -352,7 +352,7 @@ export class VarianceScaling extends Initializer {
352352
return truncatedNormal(shape, 0, stddev, dtype, this.seed);
353353
} else {
354354
const limit = Math.sqrt(3 * scale);
355-
return randomUniform(shape, -limit, limit, dtype);
355+
return randomUniform(shape, -limit, limit, dtype, this.seed);
356356
}
357357
}
358358

@@ -498,7 +498,7 @@ serialization.registerClass(LeCunNormal);
498498

499499
export class LeCunUniform extends VarianceScaling {
500500
/** @nocollapse */
501-
static override className = 'LeCunNormal';
501+
static override className = 'LeCunUniform';
502502

503503
constructor(args?: SeedOnlyInitializerArgs) {
504504
super({

tfjs-layers/src/initializers_test.ts

+188
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,25 @@ describeMathCPU('RandomUniform initializer', () => {
180180
expect(weights.dtype).toEqual('float32');
181181
expectTensorsValuesInRange(weights, 17, 47);
182182
});
183+
184+
it('with configured seed', () => {
185+
186+
const initializerConfig: serialization.ConfigDict = {
187+
className: 'RandomUniform',
188+
config: { seed: 42 }
189+
};
190+
191+
const expectedInitializer = getInitializer(initializerConfig);
192+
const actualInitializer = getInitializer(initializerConfig);
193+
194+
const expected = expectedInitializer.apply(shape, 'float32');
195+
const actual = actualInitializer.apply(shape, 'float32');
196+
197+
expect(actual.shape).toEqual(expected.shape);
198+
expect(actual.dtype).toEqual(expected.dtype);
199+
expectTensorsClose(actual, expected);
200+
});
201+
183202
it('Does not leak', () => {
184203
expectNoLeakedTensors(() => getInitializer('RandomUniform').apply([3]), 1);
185204
});
@@ -214,6 +233,25 @@ describeMathCPU('RandomNormal initializer', () => {
214233
expect(weights.dtype).toEqual('float32');
215234
// TODO(bileschi): Add test to assert the values match expectations.
216235
});
236+
237+
it('with configured seed', () => {
238+
239+
const initializerConfig: serialization.ConfigDict = {
240+
className: 'RandomNormal',
241+
config: { seed: 42 }
242+
};
243+
244+
const expectedInitializer = getInitializer(initializerConfig);
245+
const actualInitializer = getInitializer(initializerConfig);
246+
247+
const expected = expectedInitializer.apply(shape, 'float32');
248+
const actual = actualInitializer.apply(shape, 'float32');
249+
250+
expect(actual.shape).toEqual(expected.shape);
251+
expect(actual.dtype).toEqual(expected.dtype);
252+
expectTensorsClose(actual, expected);
253+
});
254+
217255
it('Does not leak', () => {
218256
expectNoLeakedTensors(() => getInitializer('RandomNormal').apply([3]), 1);
219257
});
@@ -239,6 +277,24 @@ describeMathCPU('HeNormal initializer', () => {
239277
expectTensorsValuesInRange(weights, -2 * stddev, 2 * stddev);
240278
});
241279

280+
it('with configured seed', () => {
281+
282+
const initializerConfig: serialization.ConfigDict = {
283+
className: 'HeNormal',
284+
config: { seed: 42 }
285+
};
286+
287+
const expectedInitializer = getInitializer(initializerConfig);
288+
const actualInitializer = getInitializer(initializerConfig);
289+
290+
const expected = expectedInitializer.apply(shape, 'float32');
291+
const actual = actualInitializer.apply(shape, 'float32');
292+
293+
expect(actual.shape).toEqual(expected.shape);
294+
expect(actual.dtype).toEqual(expected.dtype);
295+
expectTensorsClose(actual, expected);
296+
});
297+
242298
it('Does not leak', () => {
243299
expectNoLeakedTensors(() => getInitializer('HeNormal').apply([3]), 1);
244300
});
@@ -264,6 +320,24 @@ describeMathCPU('HeUniform initializer', () => {
264320
expectTensorsValuesInRange(weights, -bound, bound);
265321
});
266322

323+
it('with configured seed', () => {
324+
325+
const initializerConfig: serialization.ConfigDict = {
326+
className: 'HeUniform',
327+
config: { seed: 42 }
328+
};
329+
330+
const expectedInitializer = getInitializer(initializerConfig);
331+
const actualInitializer = getInitializer(initializerConfig);
332+
333+
const expected = expectedInitializer.apply(shape, 'float32');
334+
const actual = actualInitializer.apply(shape, 'float32');
335+
336+
expect(actual.shape).toEqual(expected.shape);
337+
expect(actual.dtype).toEqual(expected.dtype);
338+
expectTensorsClose(actual, expected);
339+
});
340+
267341
it('Does not leak', () => {
268342
expectNoLeakedTensors(() => getInitializer('heUniform').apply([3]), 1);
269343
});
@@ -289,6 +363,24 @@ describeMathCPU('LecunNormal initializer', () => {
289363
expectTensorsValuesInRange(weights, -2 * stddev, 2 * stddev);
290364
});
291365

366+
it('with configured seed', () => {
367+
368+
const initializerConfig: serialization.ConfigDict = {
369+
className: 'LeCunNormal',
370+
config: { seed: 42 }
371+
};
372+
373+
const expectedInitializer = getInitializer(initializerConfig);
374+
const actualInitializer = getInitializer(initializerConfig);
375+
376+
const expected = expectedInitializer.apply(shape, 'float32');
377+
const actual = actualInitializer.apply(shape, 'float32');
378+
379+
expect(actual.shape).toEqual(expected.shape);
380+
expect(actual.dtype).toEqual(expected.dtype);
381+
expectTensorsClose(actual, expected);
382+
});
383+
292384
it('Does not leak', () => {
293385
expectNoLeakedTensors(() => getInitializer('LeCunNormal').apply([3]), 1);
294386
});
@@ -314,6 +406,24 @@ describeMathCPU('LeCunUniform initializer', () => {
314406
expectTensorsValuesInRange(weights, -bound, bound);
315407
});
316408

409+
it('with configured seed', () => {
410+
411+
const initializerConfig: serialization.ConfigDict = {
412+
className: 'LeCunUniform',
413+
config: { seed: 42 }
414+
};
415+
416+
const expectedInitializer = getInitializer(initializerConfig);
417+
const actualInitializer = getInitializer(initializerConfig);
418+
419+
const expected = expectedInitializer.apply(shape, 'float32');
420+
const actual = actualInitializer.apply(shape, 'float32');
421+
422+
expect(actual.shape).toEqual(expected.shape);
423+
expect(actual.dtype).toEqual(expected.dtype);
424+
expectTensorsClose(actual, expected);
425+
});
426+
317427
it('Does not leak', () => {
318428
expectNoLeakedTensors(() => getInitializer('LeCunUniform').apply([3]), 1);
319429
});
@@ -348,6 +458,25 @@ describeMathCPU('TruncatedNormal initializer', () => {
348458
expect(weights.dtype).toEqual('float32');
349459
expectTensorsValuesInRange(weights, 0.0, 2.0);
350460
});
461+
462+
it('with configured seed', () => {
463+
464+
const initializerConfig: serialization.ConfigDict = {
465+
className: 'TruncatedNormal',
466+
config: { seed: 42 }
467+
};
468+
469+
const expectedInitializer = getInitializer(initializerConfig);
470+
const actualInitializer = getInitializer(initializerConfig);
471+
472+
const expected = expectedInitializer.apply(shape, 'float32');
473+
const actual = actualInitializer.apply(shape, 'float32');
474+
475+
expect(actual.shape).toEqual(expected.shape);
476+
expect(actual.dtype).toEqual(expected.dtype);
477+
expectTensorsClose(actual, expected);
478+
});
479+
351480
it('Does not leak', () => {
352481
expectNoLeakedTensors(
353482
() => getInitializer('TruncatedNormal').apply([3]), 1);
@@ -403,6 +532,25 @@ describeMathCPU('Glorot uniform initializer', () => {
403532
.toBeGreaterThan(-limit);
404533
});
405534
});
535+
536+
it('with configured seed', () => {
537+
538+
const initializerConfig: serialization.ConfigDict = {
539+
className: 'GlorotUniform',
540+
config: { seed: 42 }
541+
};
542+
543+
const expectedInitializer = getInitializer(initializerConfig);
544+
const actualInitializer = getInitializer(initializerConfig);
545+
546+
const expected = expectedInitializer.apply([7, 2], 'float32');
547+
const actual = actualInitializer.apply([7, 2], 'float32');
548+
549+
expect(actual.shape).toEqual(expected.shape);
550+
expect(actual.dtype).toEqual(expected.dtype);
551+
expectTensorsClose(actual, expected);
552+
});
553+
406554
it('Does not leak', () => {
407555
expectNoLeakedTensors(() => getInitializer('GlorotUniform').apply([3]), 1);
408556
});
@@ -429,6 +577,27 @@ describeMathCPU('VarianceScaling initializer', () => {
429577
const newConfig = newInit.getConfig();
430578
expect(newConfig['distribution']).toEqual(baseConfig['distribution']);
431579
});
580+
581+
it(`${distribution} with configured seed`, () => {
582+
583+
const initializerConfig: serialization.ConfigDict = {
584+
className: 'VarianceScaling',
585+
config: {
586+
distribution,
587+
seed: 42
588+
}
589+
};
590+
591+
const expectedInitializer = getInitializer(initializerConfig);
592+
const actualInitializer = getInitializer(initializerConfig);
593+
594+
const expected = expectedInitializer.apply([7, 2], 'float32');
595+
const actual = actualInitializer.apply([7, 2], 'float32');
596+
597+
expect(actual.shape).toEqual(expected.shape);
598+
expect(actual.dtype).toEqual(expected.dtype);
599+
expectTensorsClose(actual, expected);
600+
});
432601
});
433602
});
434603

@@ -485,6 +654,25 @@ describeMathCPU('Glorot normal initializer', () => {
485654
expect(variance2).toBeLessThan(variance1);
486655
});
487656
});
657+
658+
it('with configured seed', () => {
659+
660+
const initializerConfig: serialization.ConfigDict = {
661+
className: 'GlorotNormal',
662+
config: { seed: 42 }
663+
};
664+
665+
const expectedInitializer = getInitializer(initializerConfig);
666+
const actualInitializer = getInitializer(initializerConfig);
667+
668+
const expected = expectedInitializer.apply([7, 2], 'float32');
669+
const actual = actualInitializer.apply([7, 2], 'float32');
670+
671+
expect(actual.shape).toEqual(expected.shape);
672+
expect(actual.dtype).toEqual(expected.dtype);
673+
expectTensorsClose(actual, expected);
674+
});
675+
488676
it('Does not leak', () => {
489677
expectNoLeakedTensors(() => getInitializer('GlorotNormal').apply([3]), 1);
490678
});

0 commit comments

Comments
 (0)