1010using System . Runtime . CompilerServices ;
1111using System . Runtime . Remoting ;
1212using System . Runtime . Serialization ;
13- using System . Runtime . Serialization . Formatters . Binary ;
1413using System . Runtime . Versioning ;
1514using System . Security . Permissions ;
1615using System . Text ;
@@ -241,29 +240,39 @@ private static void InvokeCallback(object eventContextPair)
241240 // END EventContextPair private class.
242241 // ----------------------------------------
243242
244- // ----------------------------------------
245- // Private class for restricting allowed types from deserialization.
246- // ----------------------------------------
247-
248- private class SqlDependencyProcessDispatcherSerializationBinder : SerializationBinder
243+ //-----------------------------------------------
244+ // Private Class to add ObjRef as DataContract
245+ //-----------------------------------------------
246+ [ SecurityPermission ( SecurityAction . Assert , Flags = SecurityPermissionFlag . RemotingConfiguration ) ]
247+ [ DataContract ]
248+ private class SqlClientObjRef
249249 {
250- public override Type BindToType ( string assemblyName , string typeName )
250+ [ DataMember ]
251+ private static ObjRef s_sqlObjRef ;
252+ internal static IRemotingTypeInfo _typeInfo ;
253+
254+ private SqlClientObjRef ( ) { }
255+
256+ public SqlClientObjRef ( SqlDependencyProcessDispatcher dispatcher ) : base ( )
251257 {
252- // Deserializing an unexpected type can inject objects with malicious side effects.
253- // If the type is unexpected, throw an exception to stop deserialization.
254- if ( typeName == nameof ( SqlDependencyProcessDispatcher ) )
255- {
256- return typeof ( SqlDependencyProcessDispatcher ) ;
257- }
258- else
259- {
260- throw new ArgumentException ( "Unexpected type" , nameof ( typeName ) ) ;
261- }
258+ s_sqlObjRef = RemotingServices . Marshal ( dispatcher ) ;
259+ _typeInfo = s_sqlObjRef . TypeInfo ;
260+ }
261+
262+ internal static bool CanCastToSqlDependencyProcessDispatcher ( )
263+ {
264+ return _typeInfo . CanCastTo ( typeof ( SqlDependencyProcessDispatcher ) , s_sqlObjRef ) ;
262265 }
266+
267+ internal ObjRef GetObjRef ( )
268+ {
269+ return s_sqlObjRef ;
270+ }
271+
263272 }
264- // ----------------------------------------
265- // END SqlDependencyProcessDispatcherSerializationBinder private class.
266- // ----------------------------------------
273+ // ------------------------------------------
274+ // End SqlClientObjRef private class.
275+ // -------------------------------------------
267276
268277 // ----------------
269278 // Instance members
@@ -306,10 +315,9 @@ public override Type BindToType(string assemblyName, string typeName)
306315 private static readonly string _typeName = ( typeof ( SqlDependencyProcessDispatcher ) ) . FullName ;
307316
308317 // -----------
309- // BID members
318+ // EventSource members
310319 // -----------
311320
312-
313321 private readonly int _objectID = System . Threading . Interlocked . Increment ( ref _objectTypeCount ) ;
314322 private static int _objectTypeCount ; // EventSource Counter
315323 internal int ObjectID
@@ -336,7 +344,7 @@ public SqlDependency(SqlCommand command) : this(command, null, SQL.SqlDependency
336344 }
337345
338346 /// <include file='..\..\..\..\..\..\..\doc\snippets\Microsoft.Data.SqlClient\SqlDependency.xml' path='docs/members[@name="SqlDependency"]/ctorCommandOptionsTimeout/*' />
339- [ System . Security . Permissions . HostProtectionAttribute ( ExternalThreading = true ) ]
347+ [ HostProtection ( ExternalThreading = true ) ]
340348 public SqlDependency ( SqlCommand command , string options , int timeout )
341349 {
342350 long scopeID = SqlClientEventSource . Log . TryNotificationScopeEnterEvent ( "<sc.SqlDependency|DEP> {0}, options: '{1}', timeout: '{2}'" , ObjectID , options , timeout ) ;
@@ -597,11 +605,13 @@ private static void ObtainProcessDispatcher()
597605 _processDispatcher = dependency . SingletonProcessDispatcher ; // Set to static instance.
598606
599607 // Serialize and set in native.
600- ObjRef objRef = GetObjRef ( _processDispatcher ) ;
601- BinaryFormatter formatter = new BinaryFormatter ( ) ;
602- MemoryStream stream = new MemoryStream ( ) ;
603- GetSerializedObject ( objRef , formatter , stream ) ;
604- SNINativeMethodWrapper . SetData ( stream . GetBuffer ( ) ) ; // Native will be forced to synchronize and not overwrite.
608+ using ( MemoryStream stream = new MemoryStream ( ) )
609+ {
610+ SqlClientObjRef objRef = new SqlClientObjRef ( _processDispatcher ) ;
611+ DataContractSerializer serializer = new DataContractSerializer ( objRef . GetType ( ) ) ;
612+ GetSerializedObject ( objRef , serializer , stream ) ;
613+ SNINativeMethodWrapper . SetData ( stream . ToArray ( ) ) ; // Native will be forced to synchronize and not overwrite.
614+ }
605615 }
606616 else
607617 {
@@ -628,37 +638,39 @@ private static void ObtainProcessDispatcher()
628638#if DEBUG // Possibly expensive, limit to debug.
629639 SqlClientEventSource . Log . TryNotificationTraceEvent ( "<sc.SqlDependency.ObtainProcessDispatcher|DEP> AppDomain.CurrentDomain.FriendlyName: {0}" , AppDomain . CurrentDomain . FriendlyName ) ;
630640#endif
631- BinaryFormatter formatter = new BinaryFormatter ( ) ;
632- MemoryStream stream = new MemoryStream ( nativeStorage ) ;
633- _processDispatcher = GetDeserializedObject ( formatter , stream ) ; // Deserialize and set for appdomain.
634- SqlClientEventSource . Log . TryNotificationTraceEvent ( "<sc.SqlDependency.ObtainProcessDispatcher|DEP> processDispatcher obtained, ID: {0}" , _processDispatcher . ObjectID ) ;
641+ using ( MemoryStream stream = new MemoryStream ( nativeStorage ) )
642+ {
643+ DataContractSerializer serializer = new DataContractSerializer ( typeof ( SqlClientObjRef ) ) ;
644+ if ( SqlClientObjRef . CanCastToSqlDependencyProcessDispatcher ( ) )
645+ {
646+ // Deserialize and set for appdomain.
647+ _processDispatcher = GetDeserializedObject ( serializer , stream ) ;
648+ }
649+ else
650+ {
651+ throw new ArgumentException ( Strings . SqlDependency_UnexpectedValueOnDeserialize ) ;
652+ }
653+ SqlClientEventSource . Log . TryNotificationTraceEvent ( "<sc.SqlDependency.ObtainProcessDispatcher|DEP> processDispatcher obtained, ID: {0}" , _processDispatcher . ObjectID ) ;
654+ }
635655 }
636656 }
637657
638658 // ---------------------------------------------------------
639659 // Static security asserted methods - limit scope of assert.
640660 // ---------------------------------------------------------
641661
642- [ SecurityPermission ( SecurityAction . Assert , Flags = SecurityPermissionFlag . RemotingConfiguration ) ]
643- private static ObjRef GetObjRef ( SqlDependencyProcessDispatcher _processDispatcher )
644- {
645- return RemotingServices . Marshal ( _processDispatcher ) ;
646- }
647-
648662 [ SecurityPermission ( SecurityAction . Assert , Flags = SecurityPermissionFlag . SerializationFormatter ) ]
649- private static void GetSerializedObject ( ObjRef objRef , BinaryFormatter formatter , MemoryStream stream )
663+ private static void GetSerializedObject ( SqlClientObjRef objRef , DataContractSerializer serializer , MemoryStream stream )
650664 {
651- formatter . Serialize ( stream , objRef ) ;
665+ serializer . WriteObject ( stream , objRef ) ;
652666 }
653667
654668 [ SecurityPermission ( SecurityAction . Assert , Flags = SecurityPermissionFlag . SerializationFormatter ) ]
655- private static SqlDependencyProcessDispatcher GetDeserializedObject ( BinaryFormatter formatter , MemoryStream stream )
669+ private static SqlDependencyProcessDispatcher GetDeserializedObject ( DataContractSerializer serializer , MemoryStream stream )
656670 {
657- // Use a custom SerializationBinder to restrict deserialized types to SqlDependencyProcessDispatcher.
658- formatter . Binder = new SqlDependencyProcessDispatcherSerializationBinder ( ) ;
659- object result = formatter . Deserialize ( stream ) ;
660- Debug . Assert ( result . GetType ( ) == typeof ( SqlDependencyProcessDispatcher ) , "Unexpected type stored in native!" ) ;
661- return ( SqlDependencyProcessDispatcher ) result ;
671+ object refResult = serializer . ReadObject ( stream ) ;
672+ var result = RemotingServices . Unmarshal ( ( refResult as SqlClientObjRef ) . GetObjRef ( ) ) ;
673+ return result as SqlDependencyProcessDispatcher ;
662674 }
663675
664676 // -------------------------
@@ -1325,7 +1337,6 @@ private void AddCommandInternal(SqlCommand cmd)
13251337 {
13261338 if ( cmd != null )
13271339 {
1328- // Don't bother with BID if command null.
13291340 long scopeID = SqlClientEventSource . Log . TryNotificationScopeEnterEvent ( "<sc.SqlDependency.AddCommandInternal|DEP> {0}, SqlCommand: {1}" , ObjectID , cmd . ObjectID ) ;
13301341 try
13311342 {
0 commit comments