@@ -10,10 +10,12 @@
 
 use libc;
 use cell::UnsafeCell;
+use sync::atomic::{AtomicUsize, Ordering};
 
 pub struct RWLock {
     inner: UnsafeCell<libc::pthread_rwlock_t>,
     write_locked: UnsafeCell<bool>,
+    num_readers: AtomicUsize,
 }
 
 unsafe impl Send for RWLock {}
@@ -24,6 +26,7 @@ impl RWLock {
         RWLock {
             inner: UnsafeCell::new(libc::PTHREAD_RWLOCK_INITIALIZER),
             write_locked: UnsafeCell::new(false),
+            num_readers: AtomicUsize::new(0),
         }
     }
     #[inline]
@@ -54,23 +57,31 @@ impl RWLock {
             panic!("rwlock read lock would result in deadlock");
         } else {
             debug_assert_eq!(r, 0);
+            self.num_readers.fetch_add(1, Ordering::Relaxed);
         }
     }
     #[inline]
     pub unsafe fn try_read(&self) -> bool {
         let r = libc::pthread_rwlock_tryrdlock(self.inner.get());
-        if r == 0 && *self.write_locked.get() {
-            self.raw_unlock();
-            false
+        if r == 0 {
+            if *self.write_locked.get() {
+                self.raw_unlock();
+                false
+            } else {
+                self.num_readers.fetch_add(1, Ordering::Relaxed);
+                true
+            }
         } else {
-            r == 0
+            false
         }
     }
     #[inline]
     pub unsafe fn write(&self) {
         let r = libc::pthread_rwlock_wrlock(self.inner.get());
-        // see comments above for why we check for EDEADLK and write_locked
-        if r == libc::EDEADLK || *self.write_locked.get() {
+        // See comments above for why we check for EDEADLK and write_locked. We
+        // also need to check that num_readers is 0.
+        if r == libc::EDEADLK || *self.write_locked.get() ||
+           self.num_readers.load(Ordering::Relaxed) != 0 {
             if r == 0 {
                 self.raw_unlock();
             }
@@ -83,12 +94,14 @@ impl RWLock {
     #[inline]
     pub unsafe fn try_write(&self) -> bool {
         let r = libc::pthread_rwlock_trywrlock(self.inner.get());
-        if r == 0 && *self.write_locked.get() {
-            self.raw_unlock();
-            false
-        } else if r == 0 {
-            *self.write_locked.get() = true;
-            true
+        if r == 0 {
+            if *self.write_locked.get() || self.num_readers.load(Ordering::Relaxed) != 0 {
+                self.raw_unlock();
+                false
+            } else {
+                *self.write_locked.get() = true;
+                true
+            }
         } else {
             false
         }
@@ -101,10 +114,12 @@ impl RWLock {
     #[inline]
     pub unsafe fn read_unlock(&self) {
         debug_assert!(!*self.write_locked.get());
+        self.num_readers.fetch_sub(1, Ordering::Relaxed);
         self.raw_unlock();
     }
     #[inline]
     pub unsafe fn write_unlock(&self) {
+        debug_assert_eq!(self.num_readers.load(Ordering::Relaxed), 0);
         debug_assert!(*self.write_locked.get());
         *self.write_locked.get() = false;
         self.raw_unlock();
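To make the failure mode concrete, here is a minimal sketch of the scenario the new num_readers field guards against. It assumes std::sync::RwLock is backed by this pthread-based primitive, as it is on Unix. POSIX leaves the behavior undefined when a thread that already holds a read lock requests the write lock: the call may fail, deadlock, or erroneously succeed, and the num_readers check converts the erroneous-success case into a clean failure (or, in the blocking write() path, a panic like the read path's).

use std::sync::RwLock;

fn main() {
    let lock = RwLock::new(0u32);

    // Hold a read lock, then attempt a write lock on the same thread.
    // POSIX leaves this recursive pattern undefined: on some platforms
    // pthread_rwlock_trywrlock fails as expected, but on others it can
    // erroneously succeed while this thread still holds a read lock.
    let guard = lock.read().unwrap();

    // The num_readers check in try_write() sees the outstanding reader,
    // releases any erroneously acquired write lock, and reports failure
    // instead of handing out a write guard that aliases the read guard.
    assert!(lock.try_write().is_err());

    // Once the reader is gone, the write lock is available again.
    drop(guard);
    assert!(lock.try_write().is_ok());
}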
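A note on the memory ordering: Ordering::Relaxed appears sufficient for num_readers because the deadlock being detected is same-thread recursion, where the thread observes its own earlier fetch_add through program order alone; for counts contributed by other threads, the synchronization performed by the pthread lock operations themselves presumably provides the needed visibility.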