Skip to content

Commit

Permalink
Add read timeout option
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEvilRoot committed Mar 8, 2021
1 parent d4efa94 commit d52b9b8
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
13 changes: 11 additions & 2 deletions src/main/kotlin/com/theevilroot/asyncsocket/CoroutineSocket.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
import java.util.concurrent.TimeUnit
import kotlin.coroutines.Continuation
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
import kotlin.coroutines.suspendCoroutine

open class CoroutineSocket(private val socket: AsynchronousSocketChannel) {
open class CoroutineSocket(
private val socket: AsynchronousSocketChannel,
private val readTimeout: Pair<Long, TimeUnit>?
) {

var isConnected: Boolean = false
private set
Expand All @@ -25,7 +29,12 @@ open class CoroutineSocket(private val socket: AsynchronousSocketChannel) {

open suspend fun read(buffer: ByteBuffer): Int {
return suspendCoroutine {
socket.read(buffer, it, ContinuationHandler<Int>())
if (readTimeout != null) {
socket.read(buffer, readTimeout.first,
readTimeout.second, it, ContinuationHandler<Int>())
} else {
socket.read(buffer, it, ContinuationHandler<Int>())
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import java.net.InetAddress
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousSocketChannel
import java.util.concurrent.TimeUnit

class Socks4CoroutineSocket(
val socksIsa: InetSocketAddress,
channel: AsynchronousSocketChannel,
val userId: String
) : CoroutineSocket(channel) {
val userId: String,
readTimeout: Pair<Long, TimeUnit>? = null
) : CoroutineSocket(channel, readTimeout) {

lateinit var remoteIsa: InetSocketAddress

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import java.net.InetAddress
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousSocketChannel
import java.util.concurrent.TimeUnit

class SocksCoroutineSocket(
val socksIsa: InetSocketAddress,
channel: AsynchronousSocketChannel,
val credentials: Pair<String, String>? = null
) : CoroutineSocket(channel) {
val credentials: Pair<String, String>? = null,
readTimeout: Pair<Long, TimeUnit>? = null
) : CoroutineSocket(channel, readTimeout) {

enum class Method { NO_AUTH, USER_PASS }

Expand Down
4 changes: 2 additions & 2 deletions src/test/kotlin/Core.kt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ suspend fun runSocksHttpRequestTest(channel: AsynchronousSocketChannel) {

suspend fun runRawHttpRequestTest(channel: AsynchronousSocketChannel) {
println("opening raw socket...")
val raw = CoroutineSocket(channel)
val raw = CoroutineSocket(channel, null)
raw.connect(InetSocketAddress("api.ipify.org", 80))
ByteBuffer.wrap(req).let { raw.write(it) }

Expand All @@ -71,7 +71,7 @@ suspend fun runRawHttpRequestTest(channel: AsynchronousSocketChannel) {
}

suspend fun runCoroutineSocketTest(channel: AsynchronousSocketChannel, dispatcher: CoroutineDispatcher) = withContext(dispatcher) {
val socket = CoroutineSocket(channel)
val socket = CoroutineSocket(channel, null)

launch {
socket.connect(InetSocketAddress("localhost", 9999))
Expand Down

0 comments on commit d52b9b8

Please sign in to comment.