diff --git a/basis/SOCKET.sig b/basis/SOCKET.sig index 906057fc9..fb6c145ff 100644 --- a/basis/SOCKET.sig +++ b/basis/SOCKET.sig @@ -47,9 +47,10 @@ signature SOCKET = sig val setRCVBUF : ('af, 'st) sock * int -> unit val getTYPE : ('af, 'st) sock -> SOCK.sock_type val getERROR : ('af, 'st) sock -> bool -(* + val getPeerName : ('af, 'st) sock -> 'af sock_addr val getSockName : ('af, 'st) sock -> 'af sock_addr +(* val getNREAD : ('af, 'st) sock -> int val getATMARK : ('af, active stream) sock -> bool *) diff --git a/basis/Socket.sml b/basis/Socket.sml index 37475f21f..62a051dd3 100644 --- a/basis/Socket.sml +++ b/basis/Socket.sml @@ -2,6 +2,8 @@ local fun not_impl s = raise Fail ("not implemented: " ^ s) + fun isNull s = prim("__is_null",s : string) : bool + fun getCtx () : foreignptr = prim("__get_ctx",()) (* error utilities *) @@ -178,6 +180,33 @@ local val getNREAD : ('af, 'st) sock -> int val getATMARK : ('af, active stream) sock -> bool *) + + fun getPeerName (s:('af, 'st) sock) : 'af sock_addr = + case #af s of + Inet_af => + let val (addr,port) = prim("sml_getpeername_inet", #fd s) + in maybe_failure "Socket.Ctl.getPeerName" port + ; Inet_sa {addr=addr,port=port} + end + | Unix_af => + let val name = prim("sml_getpeername_unix", #fd s) + in if isNull name then failure "Socket.Ctl.getPeerName" + else Unix_sa {name=name} + end + + fun getSockName (s:('af, 'st) sock) : 'af sock_addr = + case #af s of + Inet_af => + let val (addr,port) = prim("sml_getsockname_inet", #fd s) + in maybe_failure "Socket.Ctl.getSockName" port + ; Inet_sa {addr=addr,port=port} + end + | Unix_af => + let val name = prim("sml_getsockname_unix", #fd s) + in if isNull name then failure "Socket.Ctl.getSockName" + else Unix_sa {name=name} + end + end type sock_desc = int diff --git a/src/Runtime/Socket.c b/src/Runtime/Socket.c index a263f79d1..3a6c2b224 100644 --- a/src/Runtime/Socket.c +++ b/src/Runtime/Socket.c @@ -139,6 +139,123 @@ REG_POLY_FUN_HDR(sml_sock_accept_unix, return vPair; } +uintptr_t +sml_getsockname_inet(uintptr_t vPair, + size_t sock) +{ + // return type is "addr * port" + // vPair points to allocated return pair + + sml_debug("[sml_getsockname_inet"); + + struct sockaddr_in addr; + socklen_t len = sizeof(addr); + + // initialise allocated memory + mkTagPairML(vPair); + first(vPair) = convertIntToML(0); // initialise + second(vPair) = convertIntToML(0); + int ret = getsockname(convertIntToC(sock), + (struct sockaddr *) &addr, + &len); + + if (ret < 0 || len > sizeof(addr)) { + sml_debug("]*\n"); + second(vPair) = convertIntToML(-1); + return vPair; + } + first(vPair) = convertIntToML(ntohl(addr.sin_addr.s_addr)); + second(vPair) = convertIntToML(ntohs(addr.sin_port)); + sml_debug("]\n"); + return vPair; +} + +String +REG_POLY_FUN_HDR(sml_getsockname_unix, + Region rString, + size_t sock) +{ + // rString points to a string region + + sml_debug("[sml_getsockname_unix"); + + struct sockaddr_un addr; + socklen_t len = sizeof(addr); + + // initialise allocated memory + memset(&addr, '\0', sizeof(addr)); // zero structure out + int ret = getsockname(convertIntToC(sock), + (struct sockaddr *) &addr, + &len); + + if (ret < 0 || len > sizeof(addr)) { + sml_debug("]*\n"); + return NULL; + } + String s = REG_POLY_CALL(convertStringToML, rString, addr.sun_path); + sml_debug("]\n"); + return s; +} + +uintptr_t +sml_getpeername_inet(uintptr_t vPair, + size_t sock) +{ + // return type is "addr * port" + // vPair points to allocated return pair + + sml_debug("[sml_getpeername_inet"); + + struct sockaddr_in addr; + socklen_t len = sizeof(addr); + + // initialise allocated memory + mkTagPairML(vPair); + first(vPair) = convertIntToML(0); // initialise + second(vPair) = convertIntToML(0); + int ret = getpeername(convertIntToC(sock), + (struct sockaddr *) &addr, + &len); + + if (ret < 0 || len > sizeof(addr)) { + sml_debug("]*\n"); + second(vPair) = convertIntToML(-1); + return vPair; + } + first(vPair) = convertIntToML(ntohl(addr.sin_addr.s_addr)); + second(vPair) = convertIntToML(ntohs(addr.sin_port)); + sml_debug("]\n"); + return vPair; +} + +String +REG_POLY_FUN_HDR(sml_getpeername_unix, + Region rString, + size_t sock) +{ + // rString points to a string region + + sml_debug("[sml_getpeername_unix"); + + struct sockaddr_un addr; + socklen_t len = sizeof(addr); + + // initialise allocated memory + memset(&addr, '\0', sizeof(addr)); // zero structure out + int ret = getpeername(convertIntToC(sock), + (struct sockaddr *) &addr, + &len); + + if (ret < 0 || len > sizeof(addr)) { + sml_debug("]*\n"); + return NULL; + } + String s = REG_POLY_CALL(convertStringToML, rString, addr.sun_path); + sml_debug("]\n"); + return s; +} + + // returns -1 on error size_t sml_sock_listen(size_t sock, size_t i) diff --git a/test/server.sml b/test/server.sml index 0d9937e8e..3fc64c26c 100644 --- a/test/server.sml +++ b/test/server.sml @@ -1,11 +1,18 @@ fun sendHello sock = - let + let val bind_addr = Socket.Ctl.getSockName sock + val bind_pair = INetSock.fromAddr bind_addr + val peer_addr = Socket.Ctl.getPeerName sock + val peer_pair = INetSock.fromAddr peer_addr + fun pr (inaddr,port) = NetHostDB.toString inaddr ^ ":" ^ Int.toString port val t = Time.now() val date = Date.fromTimeLocal t val date_str = Date.toString date - val msg = "Hello world! " ^ - "The date is " ^ date_str ^ "..." + val msg = "Hello world! \n" ^ + "The date is " ^ date_str ^ "... \n" ^ + "Bound address is " ^ pr bind_pair ^ "... \n" ^ + "Peer address is " ^ pr peer_pair ^ "... " + val res = "HTTP/1.1 200 OK\r\nContent-Length: " ^ Int.toString (size msg) ^ "\r\n\r\n" ^ msg ^ "\r\n\r\n" val slc = Word8VectorSlice.full (Byte.stringToBytes res)