77from  .artifact  import  Artifact 
88from  .dict_utils  import  dict_get 
99from  .image_operators  import  ImageDataString 
10- from  .operator  import  InstanceOperatorValidator 
10+ from  .operator  import  InstanceOperator ,  InstanceOperatorValidator 
1111from  .settings_utils  import  get_constants , get_settings 
1212from  .type_utils  import  isoftype 
1313from  .types  import  Image 
@@ -87,6 +87,18 @@ def loads_instance(batch):
8787    return  batch 
8888
8989
90+ class  SerializeInstancesBeforeDump (InstanceOperator ):
91+ 
92+    def  process (
93+         self , instance : Dict [str , Any ], stream_name : Optional [str ] =  None 
94+     ) ->  Dict [str , Any ]:
95+         if  settings .task_data_as_text :
96+             instance ["task_data" ] =  json .dumps (instance ["task_data" ])
97+ 
98+         if  not  isinstance (instance ["source" ], str ):
99+             instance ["source" ] =  json .dumps (instance ["source" ])
100+         return  instance 
101+ 
90102class  FinalizeDataset (InstanceOperatorValidator ):
91103    group_by : List [List [str ]]
92104    remove_unnecessary_fields : bool  =  True 
@@ -126,13 +138,6 @@ def _get_instance_task_data(
126138            task_data  =  {** task_data , ** instance ["reference_fields" ]}
127139        return  task_data 
128140
129-     def  serialize_instance_fields (self , instance , task_data ):
130-         if  settings .task_data_as_text :
131-             instance ["task_data" ] =  json .dumps (task_data )
132- 
133-         if  not  isinstance (instance ["source" ], str ):
134-             instance ["source" ] =  json .dumps (instance ["source" ])
135-         return  instance 
136141
137142    def  process (
138143        self , instance : Dict [str , Any ], stream_name : Optional [str ] =  None 
@@ -157,7 +162,7 @@ def process(
157162                for  instance  in  instance .pop (constants .demos_field )
158163            ]
159164
160-         instance  =  self . serialize_instance_fields ( instance ,  task_data ) 
165+         instance [ "task_data" ]  =  task_data 
161166
162167        if  self .remove_unnecessary_fields :
163168            keys_to_delete  =  []
@@ -202,7 +207,8 @@ def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
202207            instance , dict 
203208        ), f"Instance should be a dict, got { type (instance )}  
204209        schema  =  get_schema (stream_name )
210+ 
205211        assert  all (
206212            key  in  instance  for  key  in  schema 
207-         ), f"Instance should have the following keys: { schema } { instance }  
208-         schema .encode_example (instance )
213+         ), f"Instance should have the following keys: { schema . keys () } { instance . keys () }  
214+         #  schema.encode_example(instance)
0 commit comments