@@ -5,11 +5,13 @@ import { sum } from "$lib/utils/sum";
55import  { 
66	embeddingEndpoints , 
77	embeddingEndpointSchema , 
8- 	type  EmbeddingEndpoint , 
98}  from  "$lib/server/embeddingEndpoints/embeddingEndpoints" ; 
109import  {  embeddingEndpointTransformersJS  }  from  "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints" ; 
1110
1211import  JSON5  from  "json5" ; 
12+ import  type  {  EmbeddingModel  }  from  "$lib/types/EmbeddingModel" ; 
13+ import  {  collections  }  from  "./database" ; 
14+ import  {  ObjectId  }  from  "mongodb" ; 
1315
1416const  modelConfig  =  z . object ( { 
1517	/** Used as an identifier in DB */ 
@@ -42,67 +44,77 @@ const rawEmbeddingModelJSON =
4244
4345const  embeddingModelsRaw  =  z . array ( modelConfig ) . parse ( JSON5 . parse ( rawEmbeddingModelJSON ) ) ; 
4446
/**
 * Embedding models declared in the parsed configuration, expanded into
 * `EmbeddingModel` records. Every record is given a newly generated `_id`
 * and fresh `createdAt` / `updatedAt` timestamps when this module is evaluated.
 */
const embeddingModels = embeddingModelsRaw.map((rawEmbeddingModel): EmbeddingModel => {
	const {
		name,
		description,
		websiteUrl,
		modelUrl,
		chunkCharLength,
		maxBatchSize,
		preQuery,
		prePassage,
		endpoints,
	} = rawEmbeddingModel;

	return {
		name,
		description,
		websiteUrl,
		modelUrl,
		chunkCharLength,
		maxBatchSize,
		preQuery,
		prePassage,
		_id: new ObjectId(),
		createdAt: new Date(),
		updatedAt: new Date(),
		endpoints,
	};
});
4965
/**
 * Selects an embedding endpoint for the given model.
 *
 * When the model declares no endpoints, the transformers.js endpoint is used
 * with weight 1. Otherwise one of the declared endpoints is drawn at random,
 * with probability proportional to its `weight`, and built through the
 * matching factory in `embeddingEndpoints`.
 *
 * @param embeddingModel - model whose `endpoints` list is consulted
 * @returns the endpoint produced by the selected factory
 * @throws Error if the drawn endpoint has an unrecognised `type`, or if no
 *   endpoint could be selected (for example when the list is empty)
 */
export const getEmbeddingEndpoint = async (embeddingModel: EmbeddingModel) => {
	if (!embeddingModel.endpoints) {
		return embeddingEndpointTransformersJS({
			type: "transformersjs",
			weight: 1,
			model: embeddingModel,
		});
	}

	const totalWeight = sum(embeddingModel.endpoints.map((e) => e.weight));

	// Weighted draw: walk the endpoints, subtracting each weight from the
	// random offset until it falls inside the current endpoint's share.
	let random = Math.random() * totalWeight;

	for (const endpoint of embeddingModel.endpoints) {
		if (random < endpoint.weight) {
			const args = { ...endpoint, model: embeddingModel };

			switch (args.type) {
				case "tei":
					return embeddingEndpoints.tei(args);
				case "transformersjs":
					return embeddingEndpoints.transformersjs(args);
				case "openai":
					return embeddingEndpoints.openai(args);
				case "hfapi":
					return embeddingEndpoints.hfapi(args);
				default:
					// Report the offending type itself; interpolating the whole
					// args object would only print "[object Object]".
					throw new Error(`Unknown endpoint type: ${endpoint.type}`);
			}
		}

		random -= endpoint.weight;
	}

	throw new Error(`Failed to select embedding endpoint`);
};
99103
/**
 * Resolves the default embedding model, i.e. the first entry of
 * `embeddingModels`.
 *
 * The `embeddingModels` collection is queried for a document with the same
 * `_id`; when none is found, the in-memory entry is returned instead.
 *
 * @returns the stored document if present, otherwise the in-memory model
 * @throws Error when `embeddingModels` has no first entry
 */
export const getDefaultEmbeddingModel = async (): Promise<EmbeddingModel> => {
	const [firstModel] = embeddingModels;

	if (!firstModel) {
		throw new Error(`Failed to find default embedding endpoint`);
	}

	const storedModel = await collections.embeddingModels.findOne({ _id: firstModel._id });

	return storedModel ?? firstModel;
};
107115
108- export  type  EmbeddingBackendModel  =  typeof  defaultEmbeddingModel ; 
/**
 * Resets the `embeddingModels` collection to the in-memory model list:
 * all existing documents are deleted, then the current models are inserted.
 * This mimics the previous behavior of creating embedding models from
 * scratch during server start.
 */
export async function pupulateEmbeddingModel() {
	const { embeddingModels: embeddingModelCollection } = collections;

	await embeddingModelCollection.deleteMany({});
	await embeddingModelCollection.insertMany(embeddingModels);
}
0 commit comments