@@ -180,6 +180,25 @@ describeMathCPU('RandomUniform initializer', () => {
180
180
expect ( weights . dtype ) . toEqual ( 'float32' ) ;
181
181
expectTensorsValuesInRange ( weights , 17 , 47 ) ;
182
182
} ) ;
183
+
184
+ it ( 'with configured seed' , ( ) => {
185
+
186
+ const initializerConfig : serialization . ConfigDict = {
187
+ className : 'RandomUniform' ,
188
+ config : { seed : 42 }
189
+ } ;
190
+
191
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
192
+ const actualInitializer = getInitializer ( initializerConfig ) ;
193
+
194
+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
195
+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
196
+
197
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
198
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
199
+ expectTensorsClose ( actual , expected ) ;
200
+ } ) ;
201
+
183
202
it ( 'Does not leak' , ( ) => {
184
203
expectNoLeakedTensors ( ( ) => getInitializer ( 'RandomUniform' ) . apply ( [ 3 ] ) , 1 ) ;
185
204
} ) ;
@@ -214,6 +233,25 @@ describeMathCPU('RandomNormal initializer', () => {
214
233
expect ( weights . dtype ) . toEqual ( 'float32' ) ;
215
234
// TODO(bileschi): Add test to assert the values match expectations.
216
235
} ) ;
236
+
237
+ it ( 'with configured seed' , ( ) => {
238
+
239
+ const initializerConfig : serialization . ConfigDict = {
240
+ className : 'RandomNormal' ,
241
+ config : { seed : 42 }
242
+ } ;
243
+
244
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
245
+ const actualInitializer = getInitializer ( initializerConfig ) ;
246
+
247
+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
248
+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
249
+
250
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
251
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
252
+ expectTensorsClose ( actual , expected ) ;
253
+ } ) ;
254
+
217
255
it ( 'Does not leak' , ( ) => {
218
256
expectNoLeakedTensors ( ( ) => getInitializer ( 'RandomNormal' ) . apply ( [ 3 ] ) , 1 ) ;
219
257
} ) ;
@@ -239,6 +277,24 @@ describeMathCPU('HeNormal initializer', () => {
239
277
expectTensorsValuesInRange ( weights , - 2 * stddev , 2 * stddev ) ;
240
278
} ) ;
241
279
280
+ it ( 'with configured seed' , ( ) => {
281
+
282
+ const initializerConfig : serialization . ConfigDict = {
283
+ className : 'HeNormal' ,
284
+ config : { seed : 42 }
285
+ } ;
286
+
287
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
288
+ const actualInitializer = getInitializer ( initializerConfig ) ;
289
+
290
+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
291
+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
292
+
293
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
294
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
295
+ expectTensorsClose ( actual , expected ) ;
296
+ } ) ;
297
+
242
298
it ( 'Does not leak' , ( ) => {
243
299
expectNoLeakedTensors ( ( ) => getInitializer ( 'HeNormal' ) . apply ( [ 3 ] ) , 1 ) ;
244
300
} ) ;
@@ -264,6 +320,24 @@ describeMathCPU('HeUniform initializer', () => {
264
320
expectTensorsValuesInRange ( weights , - bound , bound ) ;
265
321
} ) ;
266
322
323
+ it ( 'with configured seed' , ( ) => {
324
+
325
+ const initializerConfig : serialization . ConfigDict = {
326
+ className : 'HeUniform' ,
327
+ config : { seed : 42 }
328
+ } ;
329
+
330
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
331
+ const actualInitializer = getInitializer ( initializerConfig ) ;
332
+
333
+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
334
+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
335
+
336
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
337
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
338
+ expectTensorsClose ( actual , expected ) ;
339
+ } ) ;
340
+
267
341
it ( 'Does not leak' , ( ) => {
268
342
expectNoLeakedTensors ( ( ) => getInitializer ( 'heUniform' ) . apply ( [ 3 ] ) , 1 ) ;
269
343
} ) ;
@@ -289,6 +363,24 @@ describeMathCPU('LecunNormal initializer', () => {
289
363
expectTensorsValuesInRange ( weights , - 2 * stddev , 2 * stddev ) ;
290
364
} ) ;
291
365
366
+ it ( 'with configured seed' , ( ) => {
367
+
368
+ const initializerConfig : serialization . ConfigDict = {
369
+ className : 'LeCunNormal' ,
370
+ config : { seed : 42 }
371
+ } ;
372
+
373
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
374
+ const actualInitializer = getInitializer ( initializerConfig ) ;
375
+
376
+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
377
+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
378
+
379
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
380
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
381
+ expectTensorsClose ( actual , expected ) ;
382
+ } ) ;
383
+
292
384
it ( 'Does not leak' , ( ) => {
293
385
expectNoLeakedTensors ( ( ) => getInitializer ( 'LeCunNormal' ) . apply ( [ 3 ] ) , 1 ) ;
294
386
} ) ;
@@ -314,6 +406,24 @@ describeMathCPU('LeCunUniform initializer', () => {
314
406
expectTensorsValuesInRange ( weights , - bound , bound ) ;
315
407
} ) ;
316
408
409
+ it ( 'with configured seed' , ( ) => {
410
+
411
+ const initializerConfig : serialization . ConfigDict = {
412
+ className : 'LeCunUniform' ,
413
+ config : { seed : 42 }
414
+ } ;
415
+
416
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
417
+ const actualInitializer = getInitializer ( initializerConfig ) ;
418
+
419
+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
420
+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
421
+
422
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
423
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
424
+ expectTensorsClose ( actual , expected ) ;
425
+ } ) ;
426
+
317
427
it ( 'Does not leak' , ( ) => {
318
428
expectNoLeakedTensors ( ( ) => getInitializer ( 'LeCunUniform' ) . apply ( [ 3 ] ) , 1 ) ;
319
429
} ) ;
@@ -348,6 +458,25 @@ describeMathCPU('TruncatedNormal initializer', () => {
348
458
expect ( weights . dtype ) . toEqual ( 'float32' ) ;
349
459
expectTensorsValuesInRange ( weights , 0.0 , 2.0 ) ;
350
460
} ) ;
461
+
462
+ it ( 'with configured seed' , ( ) => {
463
+
464
+ const initializerConfig : serialization . ConfigDict = {
465
+ className : 'TruncatedNormal' ,
466
+ config : { seed : 42 }
467
+ } ;
468
+
469
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
470
+ const actualInitializer = getInitializer ( initializerConfig ) ;
471
+
472
+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
473
+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
474
+
475
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
476
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
477
+ expectTensorsClose ( actual , expected ) ;
478
+ } ) ;
479
+
351
480
it ( 'Does not leak' , ( ) => {
352
481
expectNoLeakedTensors (
353
482
( ) => getInitializer ( 'TruncatedNormal' ) . apply ( [ 3 ] ) , 1 ) ;
@@ -403,6 +532,25 @@ describeMathCPU('Glorot uniform initializer', () => {
403
532
. toBeGreaterThan ( - limit ) ;
404
533
} ) ;
405
534
} ) ;
535
+
536
+ it ( 'with configured seed' , ( ) => {
537
+
538
+ const initializerConfig : serialization . ConfigDict = {
539
+ className : 'GlorotUniform' ,
540
+ config : { seed : 42 }
541
+ } ;
542
+
543
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
544
+ const actualInitializer = getInitializer ( initializerConfig ) ;
545
+
546
+ const expected = expectedInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
547
+ const actual = actualInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
548
+
549
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
550
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
551
+ expectTensorsClose ( actual , expected ) ;
552
+ } ) ;
553
+
406
554
it ( 'Does not leak' , ( ) => {
407
555
expectNoLeakedTensors ( ( ) => getInitializer ( 'GlorotUniform' ) . apply ( [ 3 ] ) , 1 ) ;
408
556
} ) ;
@@ -429,6 +577,27 @@ describeMathCPU('VarianceScaling initializer', () => {
429
577
const newConfig = newInit . getConfig ( ) ;
430
578
expect ( newConfig [ 'distribution' ] ) . toEqual ( baseConfig [ 'distribution' ] ) ;
431
579
} ) ;
580
+
581
+ it ( `${ distribution } with configured seed` , ( ) => {
582
+
583
+ const initializerConfig : serialization . ConfigDict = {
584
+ className : 'VarianceScaling' ,
585
+ config : {
586
+ distribution,
587
+ seed : 42
588
+ }
589
+ } ;
590
+
591
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
592
+ const actualInitializer = getInitializer ( initializerConfig ) ;
593
+
594
+ const expected = expectedInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
595
+ const actual = actualInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
596
+
597
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
598
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
599
+ expectTensorsClose ( actual , expected ) ;
600
+ } ) ;
432
601
} ) ;
433
602
} ) ;
434
603
@@ -485,6 +654,25 @@ describeMathCPU('Glorot normal initializer', () => {
485
654
expect ( variance2 ) . toBeLessThan ( variance1 ) ;
486
655
} ) ;
487
656
} ) ;
657
+
658
+ it ( 'with configured seed' , ( ) => {
659
+
660
+ const initializerConfig : serialization . ConfigDict = {
661
+ className : 'GlorotNormal' ,
662
+ config : { seed : 42 }
663
+ } ;
664
+
665
+ const expectedInitializer = getInitializer ( initializerConfig ) ;
666
+ const actualInitializer = getInitializer ( initializerConfig ) ;
667
+
668
+ const expected = expectedInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
669
+ const actual = actualInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
670
+
671
+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
672
+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
673
+ expectTensorsClose ( actual , expected ) ;
674
+ } ) ;
675
+
488
676
it ( 'Does not leak' , ( ) => {
489
677
expectNoLeakedTensors ( ( ) => getInitializer ( 'GlorotNormal' ) . apply ( [ 3 ] ) , 1 ) ;
490
678
} ) ;
0 commit comments