From 2fe3355560605327fd2eb198c6c632c580acbbec Mon Sep 17 00:00:00 2001
From: "Adam D. Ruppe" <destructionator@gmail.com>
Date: Wed, 22 Dec 2021 18:46:29 -0500
Subject: [PATCH] Windows hybrid server support

---
 cgi.d | 514 +++++++++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 416 insertions(+), 98 deletions(-)

diff --git a/cgi.d b/cgi.d
index a855e30..f7e3278 100644
--- a/cgi.d
+++ b/cgi.d
@@ -397,7 +397,7 @@ void cloexec(Socket s) {
 
 version(embedded_httpd_hybrid) {
 	version=embedded_httpd_threads;
-	version(cgi_no_fork) {} else
+	version(cgi_no_fork) {} else version(Posix)
 		version=cgi_use_fork;
 	version=cgi_use_fiber;
 }
@@ -4195,10 +4195,239 @@ class CgiFiber : Fiber {
 	}
 }
 
+version(cgi_use_fiber)
+version(Windows) {
+
+extern(Windows) private {
+
+	import core.sys.windows.mswsock;
+
+	alias GROUP=uint;
+	alias LPWSAPROTOCOL_INFOW = void*;
+	SOCKET WSASocketW(int af, int type, int protocol, LPWSAPROTOCOL_INFOW lpProtocolInfo, GROUP g, DWORD dwFlags);
+	int WSASend(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount, LPDWORD lpNumberOfBytesSent, DWORD dwFlags, LPWSAOVERLAPPED lpOverlapped, LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);
+	int WSARecv(SOCKET s, LPWSABUF lpBuffers, DWORD dwBufferCount, LPDWORD lpNumberOfBytesRecvd, LPDWORD lpFlags, LPWSAOVERLAPPED lpOverlapped, LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);
+
+	struct WSABUF {
+		ULONG len;
+		CHAR  *buf;
+	}
+	alias LPWSABUF = WSABUF*;
+
+	alias WSAOVERLAPPED = OVERLAPPED;
+	alias LPWSAOVERLAPPED = LPOVERLAPPED;
+	/+
+
+	alias LPFN_ACCEPTEX = 
+		BOOL
+		function(
+				SOCKET sListenSocket,
+				SOCKET sAcceptSocket,
+				//_Out_writes_bytes_(dwReceiveDataLength+dwLocalAddressLength+dwRemoteAddressLength) PVOID lpOutputBuffer,
+				void* lpOutputBuffer,
+				WORD dwReceiveDataLength,
+				WORD dwLocalAddressLength,
+				WORD dwRemoteAddressLength,
+				LPDWORD lpdwBytesReceived,
+				LPOVERLAPPED lpOverlapped
+			);
+
+	enum WSAID_ACCEPTEX = GUID([0xb5367df1,0xcbac,0x11cf,[0x95,0xca,0x00,0x80,0x5f,0x48,0xa1,0x92]]);
+	+/
+
+	enum WSAID_GETACCEPTEXSOCKADDRS = GUID(0xb5367df2,0xcbac,0x11cf,[0x95,0xca,0x00,0x80,0x5f,0x48,0xa1,0x92]);
+}
+
+private class PseudoblockingOverlappedSocket : Socket {
+	SOCKET handle;
+
+	CgiFiber fiber;
+
+	this(AddressFamily af, SocketType st) {
+		auto handle = WSASocketW(af, st, 0, null, 0, 1 /*WSA_FLAG_OVERLAPPED*/);
+		if(!handle)
+			throw new Exception("WSASocketW");
+		this.handle = handle;
+
+		iocp = CreateIoCompletionPort(cast(HANDLE) handle, iocp, cast(ULONG_PTR) cast(void*) this, 0);
+
+		if(iocp is null) {
+			writeln(GetLastError());
+			throw new Exception("CreateIoCompletionPort");
+		}
+
+		super(cast(socket_t) handle, af);
+	}
+	this() pure nothrow @trusted { assert(0); }
+
+	override void blocking(bool) {} // meaningless to us, just ignore it.
+
+	protected override Socket accepting() pure nothrow {
+		assert(0);
+	}
+
+	bool addressesParsed;
+	Address la;
+	Address ra;
+
+	private void populateAddresses() {
+		if(addressesParsed)
+			return;
+		addressesParsed = true;
+
+		int lalen, ralen;
+
+		sockaddr_in* la;
+		sockaddr_in* ra;
+
+		lpfnGetAcceptExSockaddrs(
+			scratchBuffer.ptr,
+			0, // same as in the AcceptEx call!
+			sockaddr_in.sizeof + 16,
+			sockaddr_in.sizeof + 16,
+			cast(sockaddr**) &la,
+			&lalen,
+			cast(sockaddr**) &ra,
+			&ralen
+		);
+
+		if(la)
+			this.la = new InternetAddress(*la);
+		if(ra)
+			this.ra = new InternetAddress(*ra);
+
+	}
+
+	override @property @trusted Address localAddress() {
+		populateAddresses();
+		return la;
+	}
+	override @property @trusted Address remoteAddress() {
+		populateAddresses();
+		return ra;
+	}
+
+	PseudoblockingOverlappedSocket accepted;
+
+	__gshared static LPFN_ACCEPTEX lpfnAcceptEx;
+	__gshared static typeof(&GetAcceptExSockaddrs) lpfnGetAcceptExSockaddrs;
+
+	override Socket accept() @trusted {
+		__gshared static LPFN_ACCEPTEX lpfnAcceptEx;
+
+		if(lpfnAcceptEx is null) {
+			DWORD dwBytes;
+			GUID GuidAcceptEx = WSAID_ACCEPTEX;
+
+			auto iResult = WSAIoctl(handle, 0xc8000006 /*SIO_GET_EXTENSION_FUNCTION_POINTER*/,
+					&GuidAcceptEx, GuidAcceptEx.sizeof,
+					&lpfnAcceptEx, lpfnAcceptEx.sizeof,
+					&dwBytes, null, null);
+
+			GuidAcceptEx = WSAID_GETACCEPTEXSOCKADDRS;
+			iResult = WSAIoctl(handle, 0xc8000006 /*SIO_GET_EXTENSION_FUNCTION_POINTER*/,
+					&GuidAcceptEx, GuidAcceptEx.sizeof,
+					&lpfnGetAcceptExSockaddrs, lpfnGetAcceptExSockaddrs.sizeof,
+					&dwBytes, null, null);
+
+		}
+
+		auto pfa = new PseudoblockingOverlappedSocket(AddressFamily.INET, SocketType.STREAM);
+		accepted = pfa;
+
+		SOCKET pendingForAccept = pfa.handle;
+		DWORD ignored;
+
+		auto ret = lpfnAcceptEx(handle,
+			pendingForAccept,
+			// buffer to receive up front
+			pfa.scratchBuffer.ptr,
+			0,
+			// size of local and remote addresses. normally + 16.
+			sockaddr_in.sizeof + 16,
+			sockaddr_in.sizeof + 16,
+			&ignored, // bytes would be given through the iocp instead but im not even requesting the thing
+			&overlapped
+		);
+
+		return pfa;
+	}
+
+	override void connect(Address to) { assert(0); }
+
+	DWORD lastAnswer;
+	ubyte[1024] scratchBuffer;
+	static assert(scratchBuffer.length > sockaddr_in.sizeof * 2 + 32);
+
+	WSABUF[1] buffer;
+	OVERLAPPED overlapped;
+	override ptrdiff_t send(const(void)[] buf, SocketFlags flags) @trusted {
+		overlapped = overlapped.init;
+		buffer[0].len = cast(DWORD) buf.length;
+		buffer[0].buf = cast(CHAR*) buf.ptr;
+		fiber.setPostYield( () {
+			if(!WSASend(handle, buffer.ptr, cast(DWORD) buffer.length, null, 0, &overlapped, null)) {
+				if(GetLastError() != 997) {
+					//throw new Exception("WSASend fail");
+				}
+			}
+		});
+
+		Fiber.yield();
+		return lastAnswer;
+	}
+	override ptrdiff_t receive(void[] buf, SocketFlags flags) @trusted {
+		overlapped = overlapped.init;
+		buffer[0].len = cast(DWORD) buf.length;
+		buffer[0].buf = cast(CHAR*) buf.ptr;
+
+		DWORD flags2 = 0;
+
+		fiber.setPostYield(() {
+			if(!WSARecv(handle, buffer.ptr, cast(DWORD) buffer.length, null, &flags2 /* flags */, &overlapped, null)) {
+				if(GetLastError() != 997) {
+					//writeln("WSARecv ", WSAGetLastError());
+					//throw new Exception("WSARecv fail");
+				}
+			}
+		});
+
+		Fiber.yield();
+		return lastAnswer;
+	}
+
+	// I might go back and implement these for udp things.
+	override ptrdiff_t receiveFrom(void[] buf, SocketFlags flags, ref Address from) @trusted {
+		assert(0);
+	}
+	override ptrdiff_t receiveFrom(void[] buf, SocketFlags flags) @trusted {
+		assert(0);
+	}
+	override ptrdiff_t sendTo(const(void)[] buf, SocketFlags flags, Address to) @trusted {
+		assert(0);
+	}
+	override ptrdiff_t sendTo(const(void)[] buf, SocketFlags flags) @trusted {
+		assert(0);
+	}
+
+	// lol overload sets
+	alias send = typeof(super).send;
+	alias receive = typeof(super).receive;
+	alias sendTo = typeof(super).sendTo;
+	alias receiveFrom = typeof(super).receiveFrom;
+
+}
+}
+
 void doThreadHttpConnection(CustomCgi, alias fun)(Socket connection) {
 	assert(connection !is null);
 	version(cgi_use_fiber) {
 		auto fiber = new CgiFiber(&doThreadHttpConnectionGuts!(CustomCgi, fun));
+
+		version(Windows) {
+			(cast(PseudoblockingOverlappedSocket) connection).fiber = fiber;
+		}
+
 		import core.memory;
 		GC.addRoot(cast(void*) fiber);
 		fiber.connection = connection;
@@ -4558,16 +4787,28 @@ import std.socket;
 
 version(cgi_use_fiber) {
 	import core.thread;
-	import core.sys.linux.epoll;
 
-	__gshared int epfd;
+	version(linux) {
+		import core.sys.linux.epoll;
+
+		int epfd = -1; // thread local because EPOLLEXCLUSIVE works much better this way... weirdly.
+	} else version(Windows) {
+		__gshared HANDLE iocp;
+	} else static assert(0, "The hybrid fiber server is not implemented on your OS.");
 }
 
 
-version(cgi_use_fiber)
-private enum WakeupEvent {
-	Read = EPOLLIN,
-	Write = EPOLLOUT
+version(cgi_use_fiber) {
+	version(linux)
+	private enum WakeupEvent {
+		Read = EPOLLIN,
+		Write = EPOLLOUT
+	}
+	else version(Windows)
+	private enum WakeupEvent {
+		Read, Write
+	}
+	else static assert(0);
 }
 
 version(cgi_use_fiber)
@@ -4576,35 +4817,45 @@ private void registerEventWakeup(bool* registered, Socket source, WakeupEvent e)
 	// static cast since I know what i have in here and don't want to pay for dynamic cast
 	auto f = cast(CgiFiber) cast(void*) Fiber.getThis();
 
-	f.setPostYield = () {
-		if(*registered) {
-			// rearm
-			epoll_event evt;
-			evt.events = e | EPOLLONESHOT;
-			evt.data.ptr = cast(void*) f;
-			if(epoll_ctl(epfd, EPOLL_CTL_MOD, source.handle, &evt) == -1)
-				throw new Exception("epoll_ctl");
-		} else {
-			// initial registration
-			*registered = true ;
-			int fd = source.handle;
-			epoll_event evt;
-			evt.events = e | EPOLLONESHOT;
-			evt.data.ptr = cast(void*) f;
-			if(epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &evt) == -1)
-				throw new Exception("epoll_ctl");
-		}
-	};
+	version(linux) {
+		f.setPostYield = () {
+			if(*registered) {
+				// rearm
+				epoll_event evt;
+				evt.events = e | EPOLLONESHOT;
+				evt.data.ptr = cast(void*) f;
+				if(epoll_ctl(epfd, EPOLL_CTL_MOD, source.handle, &evt) == -1)
+					throw new Exception("epoll_ctl");
+			} else {
+				// initial registration
+				*registered = true ;
+				int fd = source.handle;
+				epoll_event evt;
+				evt.events = e | EPOLLONESHOT;
+				evt.data.ptr = cast(void*) f;
+				if(epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &evt) == -1)
+					throw new Exception("epoll_ctl");
+			}
+		};
 
-	Fiber.yield();
+		Fiber.yield();
 
-	f.setPostYield(null);
+		f.setPostYield(null);
+	} else version(Windows) {
+		Fiber.yield();
+	}
+	else static assert(0);
 }
 
 version(cgi_use_fiber)
 void unregisterSource(Socket s) {
-	epoll_event evt;
-	epoll_ctl(epfd, EPOLL_CTL_DEL, s.handle(), &evt);
+	version(linux) {
+		epoll_event evt;
+		epoll_ctl(epfd, EPOLL_CTL_DEL, s.handle(), &evt);
+	} else version(Windows) {
+		// intentionally blank
+	}
+	else static assert(0);
 }
 
 // it is a class primarily for reference semantics
@@ -4871,20 +5122,10 @@ class ListeningConnectionManager {
 			}
 
 			version(cgi_use_fiber) {
-				import core.sys.linux.epoll;
-				epfd = epoll_create1(EPOLL_CLOEXEC);
-				if(epfd == -1)
-					throw new Exception("epoll_create1 " ~ to!string(errno));
-				scope(exit) {
-					import core.sys.posix.unistd;
-					close(epfd);
-				}
 
-				epoll_event ev;
-				ev.events = EPOLLIN | EPOLLEXCLUSIVE; // EPOLLEXCLUSIVE is only available on kernels since like 2017 but that's prolly good enough.
-				ev.data.fd = listener.handle;
-				if(epoll_ctl(epfd, EPOLL_CTL_ADD, listener.handle, &ev) == -1)
-					throw new Exception("epoll_ctl " ~ to!string(errno));
+				version(Windows) {
+					listener.accept();
+				}
 
 				WorkerThread[] threads = new WorkerThread[](totalCPUs * 1 + 1);
 				foreach(i, ref thread; threads) {
@@ -4919,58 +5160,66 @@ class ListeningConnectionManager {
 					thread = new ConnectionThread(this, handler, cast(int) i);
 					thread.start();
 				}
-			}
 
-			while(!loopBroken && running) {
-				Socket sn;
+				while(!loopBroken && !globalStopFlag) {
+					Socket sn;
 
-				bool crash_check() {
-					bool hasAnyRunning;
-					foreach(thread; threads) {
-						if(!thread.isRunning) {
-							thread.join();
-						} else hasAnyRunning = true;
+					bool crash_check() {
+						bool hasAnyRunning;
+						foreach(thread; threads) {
+							if(!thread.isRunning) {
+								thread.join();
+							} else hasAnyRunning = true;
+						}
+
+						return (!hasAnyRunning);
 					}
 
-					return (!hasAnyRunning);
+
+					void accept_new_connection() {
+						sn = acceptCancelable();
+						if(sn is null) return;
+						cloexec(sn);
+						if(tcp) {
+							// disable Nagle's algorithm to avoid a 40ms delay when we send/recv
+							// on the socket because we do some buffering internally. I think this helps,
+							// certainly does for small requests, and I think it does for larger ones too
+							sn.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, 1);
+
+							sn.setOption(SocketOptionLevel.SOCKET, SocketOption.RCVTIMEO, dur!"seconds"(10));
+						}
+					}
+
+					void existing_connection_new_data() {
+						// wait until a slot opens up
+						//int waited = 0;
+						while(queueLength >= queue.length) {
+							Thread.sleep(1.msecs);
+							//waited ++;
+						}
+						//if(waited) {import std.stdio; writeln(waited);}
+						synchronized(this) {
+							queue[nextIndexBack] = sn;
+							nextIndexBack++;
+							atomicOp!"+="(queueLength, 1);
+						}
+						semaphore.notify();
+					}
+
+
+					accept_new_connection();
+					if(sn !is null)
+						existing_connection_new_data();
+					else if(sn is null && globalStopFlag) {
+						foreach(thread; threads) {
+							semaphore.notify();
+						}
+						Thread.sleep(50.msecs);
+					}
+
+					if(crash_check())
+						break;
 				}
-
-
-				void accept_new_connection() {
-					sn = listener.accept();
-					cloexec(sn);
-					if(tcp) {
-						// disable Nagle's algorithm to avoid a 40ms delay when we send/recv
-						// on the socket because we do some buffering internally. I think this helps,
-						// certainly does for small requests, and I think it does for larger ones too
-						sn.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, 1);
-
-						sn.setOption(SocketOptionLevel.SOCKET, SocketOption.RCVTIMEO, dur!"seconds"(10));
-					}
-				}
-
-				void existing_connection_new_data() {
-					// wait until a slot opens up
-					//int waited = 0;
-					while(queueLength >= queue.length) {
-						Thread.sleep(1.msecs);
-						//waited ++;
-					}
-					//if(waited) {import std.stdio; writeln(waited);}
-					synchronized(this) {
-						queue[nextIndexBack] = sn;
-						nextIndexBack++;
-						atomicOp!"+="(queueLength, 1);
-					}
-					semaphore.notify();
-				}
-
-
-				accept_new_connection();
-				existing_connection_new_data();
-
-				if(crash_check())
-					break;
 			}
 
 			// FIXME: i typically stop this with ctrl+c which never
@@ -5045,7 +5294,14 @@ Socket startListening(string host, ushort port, ref bool tcp, ref void delegate(
 			throw new Exception("abstract unix sockets not supported on this system");
 		}
 	} else {
-		listener = new TcpSocket();
+		version(cgi_use_fiber) {
+			version(Windows)
+				listener = new PseudoblockingOverlappedSocket(AddressFamily.INET, SocketType.STREAM);
+			else
+				listener = new TcpSocket();
+		} else {
+			listener = new TcpSocket();
+		}
 		cloexec(listener);
 		listener.setOption(SocketOptionLevel.SOCKET, SocketOption.REUSEADDR, true);
 		listener.bind(host.length ? parseAddress(host, port) : new InternetAddress(port));
@@ -5213,8 +5469,69 @@ class WorkerThread : Thread {
 		super(&run);
 	}
 
+	version(Windows)
 	void run() {
-		while(lcm.running) {
+		auto timeout = INFINITE;
+		PseudoblockingOverlappedSocket key;
+		OVERLAPPED* overlapped;
+		DWORD bytes;
+		while(!globalStopFlag && GetQueuedCompletionStatus(iocp, &bytes, cast(PULONG_PTR) &key, &overlapped, timeout)) {
+			if(key is null)
+				continue;
+			key.lastAnswer = bytes;
+			if(key.fiber) {
+				key.fiber.proceed();
+			} else {
+				// we have a new connection, issue the first receive on it and issue the next accept
+
+				auto sn = key.accepted;
+
+				key.accept();
+
+				cloexec(sn);
+				if(lcm.tcp) {
+					// disable Nagle's algorithm to avoid a 40ms delay when we send/recv
+					// on the socket because we do some buffering internally. I think this helps,
+					// certainly does for small requests, and I think it does for larger ones too
+					sn.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, 1);
+
+					sn.setOption(SocketOptionLevel.SOCKET, SocketOption.RCVTIMEO, dur!"seconds"(10));
+				}
+
+				dg(sn);
+			}
+		}
+		//SleepEx(INFINITE, TRUE);
+	}
+
+	version(linux)
+	void run() {
+
+		import core.sys.linux.epoll;
+		epfd = epoll_create1(EPOLL_CLOEXEC);
+		if(epfd == -1)
+			throw new Exception("epoll_create1 " ~ to!string(errno));
+		scope(exit) {
+			import core.sys.posix.unistd;
+			close(epfd);
+		}
+
+		{
+			epoll_event ev;
+			ev.events = EPOLLIN;
+			ev.data.fd = cancelfd;
+			epoll_ctl(epfd, EPOLL_CTL_ADD, cancelfd, &ev);
+		}
+
+		epoll_event ev;
+		ev.events = EPOLLIN | EPOLLEXCLUSIVE; // EPOLLEXCLUSIVE is only available on kernels since like 2017 but that's prolly good enough.
+		ev.data.fd = lcm.listener.handle;
+		if(epoll_ctl(epfd, EPOLL_CTL_ADD, lcm.listener.handle, &ev) == -1)
+			throw new Exception("epoll_ctl " ~ to!string(errno));
+
+
+
+		while(!globalStopFlag) {
 			Socket sn;
 
 			epoll_event[64] events;
@@ -5228,18 +5545,19 @@ class WorkerThread : Thread {
 			foreach(idx; 0 .. nfds) {
 				auto flags = events[idx].events;
 
-				if(cast(size_t) events[idx].data.ptr == cast(size_t) lcm.listener.handle) {
+				if(cast(size_t) events[idx].data.ptr == cast(size_t) cancelfd) {
+					globalStopFlag = true;
+					//import std.stdio; writeln("exit heard");
+					break;
+				} else if(cast(size_t) events[idx].data.ptr == cast(size_t) lcm.listener.handle) {
+					//import std.stdio; writeln(myThreadNumber, " woken up ", flags);
 					// this try/catch is because it is set to non-blocking mode
 					// and Phobos' stupid api throws an exception instead of returning
 					// if it would block. Why would it block? because a forked process
 					// might have beat us to it, but the wakeup event thundered our herds.
-					version(cgi_use_fork) {
 						try
-						sn = lcm.listener.accept();
+						sn = lcm.listener.accept(); // don't need to do the acceptCancelable here since the epoll checks it better
 						catch(SocketAcceptException e) { continue; }
-					} else {
-						sn = lcm.listener.accept();
-					}
 
 					cloexec(sn);
 					if(lcm.tcp) {