@@ -180,6 +180,25 @@ describeMathCPU('RandomUniform initializer', () => {
180180 expect ( weights . dtype ) . toEqual ( 'float32' ) ;
181181 expectTensorsValuesInRange ( weights , 17 , 47 ) ;
182182 } ) ;
183+
184+ it ( 'with configured seed' , ( ) => {
185+
186+ const initializerConfig : serialization . ConfigDict = {
187+ className : 'RandomUniform' ,
188+ config : { seed : 42 }
189+ } ;
190+
191+ const expectedInitializer = getInitializer ( initializerConfig ) ;
192+ const actualInitializer = getInitializer ( initializerConfig ) ;
193+
194+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
195+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
196+
197+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
198+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
199+ expectTensorsClose ( actual , expected ) ;
200+ } ) ;
201+
183202 it ( 'Does not leak' , ( ) => {
184203 expectNoLeakedTensors ( ( ) => getInitializer ( 'RandomUniform' ) . apply ( [ 3 ] ) , 1 ) ;
185204 } ) ;
@@ -214,6 +233,25 @@ describeMathCPU('RandomNormal initializer', () => {
214233 expect ( weights . dtype ) . toEqual ( 'float32' ) ;
215234 // TODO(bileschi): Add test to assert the values match expectations.
216235 } ) ;
236+
237+ it ( 'with configured seed' , ( ) => {
238+
239+ const initializerConfig : serialization . ConfigDict = {
240+ className : 'RandomNormal' ,
241+ config : { seed : 42 }
242+ } ;
243+
244+ const expectedInitializer = getInitializer ( initializerConfig ) ;
245+ const actualInitializer = getInitializer ( initializerConfig ) ;
246+
247+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
248+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
249+
250+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
251+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
252+ expectTensorsClose ( actual , expected ) ;
253+ } ) ;
254+
217255 it ( 'Does not leak' , ( ) => {
218256 expectNoLeakedTensors ( ( ) => getInitializer ( 'RandomNormal' ) . apply ( [ 3 ] ) , 1 ) ;
219257 } ) ;
@@ -239,6 +277,24 @@ describeMathCPU('HeNormal initializer', () => {
239277 expectTensorsValuesInRange ( weights , - 2 * stddev , 2 * stddev ) ;
240278 } ) ;
241279
280+ it ( 'with configured seed' , ( ) => {
281+
282+ const initializerConfig : serialization . ConfigDict = {
283+ className : 'HeNormal' ,
284+ config : { seed : 42 }
285+ } ;
286+
287+ const expectedInitializer = getInitializer ( initializerConfig ) ;
288+ const actualInitializer = getInitializer ( initializerConfig ) ;
289+
290+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
291+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
292+
293+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
294+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
295+ expectTensorsClose ( actual , expected ) ;
296+ } ) ;
297+
242298 it ( 'Does not leak' , ( ) => {
243299 expectNoLeakedTensors ( ( ) => getInitializer ( 'HeNormal' ) . apply ( [ 3 ] ) , 1 ) ;
244300 } ) ;
@@ -264,6 +320,24 @@ describeMathCPU('HeUniform initializer', () => {
264320 expectTensorsValuesInRange ( weights , - bound , bound ) ;
265321 } ) ;
266322
323+ it ( 'with configured seed' , ( ) => {
324+
325+ const initializerConfig : serialization . ConfigDict = {
326+ className : 'HeUniform' ,
327+ config : { seed : 42 }
328+ } ;
329+
330+ const expectedInitializer = getInitializer ( initializerConfig ) ;
331+ const actualInitializer = getInitializer ( initializerConfig ) ;
332+
333+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
334+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
335+
336+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
337+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
338+ expectTensorsClose ( actual , expected ) ;
339+ } ) ;
340+
267341 it ( 'Does not leak' , ( ) => {
268342 expectNoLeakedTensors ( ( ) => getInitializer ( 'heUniform' ) . apply ( [ 3 ] ) , 1 ) ;
269343 } ) ;
@@ -289,6 +363,24 @@ describeMathCPU('LecunNormal initializer', () => {
289363 expectTensorsValuesInRange ( weights , - 2 * stddev , 2 * stddev ) ;
290364 } ) ;
291365
366+ it ( 'with configured seed' , ( ) => {
367+
368+ const initializerConfig : serialization . ConfigDict = {
369+ className : 'LeCunNormal' ,
370+ config : { seed : 42 }
371+ } ;
372+
373+ const expectedInitializer = getInitializer ( initializerConfig ) ;
374+ const actualInitializer = getInitializer ( initializerConfig ) ;
375+
376+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
377+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
378+
379+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
380+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
381+ expectTensorsClose ( actual , expected ) ;
382+ } ) ;
383+
292384 it ( 'Does not leak' , ( ) => {
293385 expectNoLeakedTensors ( ( ) => getInitializer ( 'LeCunNormal' ) . apply ( [ 3 ] ) , 1 ) ;
294386 } ) ;
@@ -314,6 +406,24 @@ describeMathCPU('LeCunUniform initializer', () => {
314406 expectTensorsValuesInRange ( weights , - bound , bound ) ;
315407 } ) ;
316408
409+ it ( 'with configured seed' , ( ) => {
410+
411+ const initializerConfig : serialization . ConfigDict = {
412+ className : 'LeCunUniform' ,
413+ config : { seed : 42 }
414+ } ;
415+
416+ const expectedInitializer = getInitializer ( initializerConfig ) ;
417+ const actualInitializer = getInitializer ( initializerConfig ) ;
418+
419+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
420+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
421+
422+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
423+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
424+ expectTensorsClose ( actual , expected ) ;
425+ } ) ;
426+
317427 it ( 'Does not leak' , ( ) => {
318428 expectNoLeakedTensors ( ( ) => getInitializer ( 'LeCunUniform' ) . apply ( [ 3 ] ) , 1 ) ;
319429 } ) ;
@@ -348,6 +458,25 @@ describeMathCPU('TruncatedNormal initializer', () => {
348458 expect ( weights . dtype ) . toEqual ( 'float32' ) ;
349459 expectTensorsValuesInRange ( weights , 0.0 , 2.0 ) ;
350460 } ) ;
461+
462+ it ( 'with configured seed' , ( ) => {
463+
464+ const initializerConfig : serialization . ConfigDict = {
465+ className : 'TruncatedNormal' ,
466+ config : { seed : 42 }
467+ } ;
468+
469+ const expectedInitializer = getInitializer ( initializerConfig ) ;
470+ const actualInitializer = getInitializer ( initializerConfig ) ;
471+
472+ const expected = expectedInitializer . apply ( shape , 'float32' ) ;
473+ const actual = actualInitializer . apply ( shape , 'float32' ) ;
474+
475+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
476+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
477+ expectTensorsClose ( actual , expected ) ;
478+ } ) ;
479+
351480 it ( 'Does not leak' , ( ) => {
352481 expectNoLeakedTensors (
353482 ( ) => getInitializer ( 'TruncatedNormal' ) . apply ( [ 3 ] ) , 1 ) ;
@@ -403,6 +532,25 @@ describeMathCPU('Glorot uniform initializer', () => {
403532 . toBeGreaterThan ( - limit ) ;
404533 } ) ;
405534 } ) ;
535+
536+ it ( 'with configured seed' , ( ) => {
537+
538+ const initializerConfig : serialization . ConfigDict = {
539+ className : 'GlorotUniform' ,
540+ config : { seed : 42 }
541+ } ;
542+
543+ const expectedInitializer = getInitializer ( initializerConfig ) ;
544+ const actualInitializer = getInitializer ( initializerConfig ) ;
545+
546+ const expected = expectedInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
547+ const actual = actualInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
548+
549+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
550+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
551+ expectTensorsClose ( actual , expected ) ;
552+ } ) ;
553+
406554 it ( 'Does not leak' , ( ) => {
407555 expectNoLeakedTensors ( ( ) => getInitializer ( 'GlorotUniform' ) . apply ( [ 3 ] ) , 1 ) ;
408556 } ) ;
@@ -429,6 +577,27 @@ describeMathCPU('VarianceScaling initializer', () => {
429577 const newConfig = newInit . getConfig ( ) ;
430578 expect ( newConfig [ 'distribution' ] ) . toEqual ( baseConfig [ 'distribution' ] ) ;
431579 } ) ;
580+
581+ it ( `${ distribution } with configured seed` , ( ) => {
582+
583+ const initializerConfig : serialization . ConfigDict = {
584+ className : 'VarianceScaling' ,
585+ config : {
586+ distribution,
587+ seed : 42
588+ }
589+ } ;
590+
591+ const expectedInitializer = getInitializer ( initializerConfig ) ;
592+ const actualInitializer = getInitializer ( initializerConfig ) ;
593+
594+ const expected = expectedInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
595+ const actual = actualInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
596+
597+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
598+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
599+ expectTensorsClose ( actual , expected ) ;
600+ } ) ;
432601 } ) ;
433602} ) ;
434603
@@ -485,6 +654,25 @@ describeMathCPU('Glorot normal initializer', () => {
485654 expect ( variance2 ) . toBeLessThan ( variance1 ) ;
486655 } ) ;
487656 } ) ;
657+
658+ it ( 'with configured seed' , ( ) => {
659+
660+ const initializerConfig : serialization . ConfigDict = {
661+ className : 'GlorotNormal' ,
662+ config : { seed : 42 }
663+ } ;
664+
665+ const expectedInitializer = getInitializer ( initializerConfig ) ;
666+ const actualInitializer = getInitializer ( initializerConfig ) ;
667+
668+ const expected = expectedInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
669+ const actual = actualInitializer . apply ( [ 7 , 2 ] , 'float32' ) ;
670+
671+ expect ( actual . shape ) . toEqual ( expected . shape ) ;
672+ expect ( actual . dtype ) . toEqual ( expected . dtype ) ;
673+ expectTensorsClose ( actual , expected ) ;
674+ } ) ;
675+
488676 it ( 'Does not leak' , ( ) => {
489677 expectNoLeakedTensors ( ( ) => getInitializer ( 'GlorotNormal' ) . apply ( [ 3 ] ) , 1 ) ;
490678 } ) ;
0 commit comments