From 2e37fa70c3f565c6bbe04cf5735b2a8a565fd787 Mon Sep 17 00:00:00 2001
From: "Adam D. Ruppe" <destructionator@gmail.com>
Date: Thu, 14 Apr 2022 08:41:24 -0400
Subject: [PATCH] simplify adding dynamic load functions

---
 http2.d | 294 +++++++++++++++++---------------------------------------
 1 file changed, 87 insertions(+), 207 deletions(-)

diff --git a/http2.d b/http2.d
index 10982dc..45fdf1d 100644
--- a/http2.d
+++ b/http2.d
@@ -3031,14 +3031,7 @@ void main() {
 version(use_openssl) {
 	alias SslClientSocket = OpenSslSocket;
 
-	// macros in the original C
-	SSL_METHOD* SSLv23_client_method() {
-		if(ossllib.SSLv23_client_method)
-			return ossllib.SSLv23_client_method();
-		else
-			return ossllib.TLS_client_method();
-	}
-
+	// CRL = Certificate Revocation List
 	static immutable string[] sslErrorCodes = [
 		"OK (code 0)",
 		"Unspecified SSL/TLS error (code 1)",
@@ -3086,6 +3079,20 @@ version(use_openssl) {
 	enum SSL_VERIFY_NONE = 0;
 	enum SSL_VERIFY_PEER = 1;
 
+	// copy it into the buf[0 .. size] and return actual length you read.
+	// rwflag == 0 when reading, 1 when writing.
+	extern(C) alias pem_password_cb = int function(char* buffer, int bufferSize, int rwflag, void* userPointer);
+	extern(C) alias print_errors_cb = int function(const char*, size_t, void*);
+	extern(C) alias client_cert_cb = int function(SSL *ssl, X509 **x509, EVP_PKEY **pkey);
+	extern(C) alias keylog_cb = void function(SSL*, char*);
+
+	struct X509;
+	struct X509_STORE;
+	struct EVP_PKEY;
+	struct X509_VERIFY_PARAM;
+
+	import core.stdc.config;
+
 	struct ossllib {
 		__gshared static extern(C) {
 			/* these are only on older openssl versions { */
@@ -3129,22 +3136,10 @@ version(use_openssl) {
 			SSL_CTX_set_client_CA_list
 			+/
 
-
 			// client cert things
 			void function (SSL_CTX *ctx, int function(SSL *ssl, X509 **x509, EVP_PKEY **pkey)) SSL_CTX_set_client_cert_cb;
 		}
 	}
-	// copy it into the buf[0 .. size] and return actual length you read.
-	// rwflag == 0 when reading, 1 when writing.
-	extern(C)
-	alias pem_password_cb = int function(char* buffer, int bufferSize, int rwflag, void* userPointer);
-
-	struct X509;
-	struct X509_STORE;
-	struct EVP_PKEY;
-	struct X509_VERIFY_PARAM;
-
-	import core.stdc.config;
 
 	struct eallib {
 		__gshared static extern(C) {
@@ -3173,161 +3168,51 @@ version(use_openssl) {
 		}
 	}
 
-	extern(C)
-	alias print_errors_cb = int function(const char*, size_t, void*);
+	struct OpenSSL {
+		static:
 
-	int SSL_CTX_set_default_verify_paths(SSL_CTX* a) {
-		if(ossllib.SSL_CTX_set_default_verify_paths)
-			return ossllib.SSL_CTX_set_default_verify_paths(a);
-		else throw new Exception("SSL_CTX_set_default_verify_paths not loaded");
-	}
+		template opDispatch(string name) {
+			auto opDispatch(T...)(T t) {
+				static if(__traits(hasMember, ossllib, name)) {
+					auto ptr = __traits(getMember, ossllib, name);
+				} else static if(__traits(hasMember, eallib, name)) {
+					auto ptr = __traits(getMember, eallib, name);
+				} else static assert(0);
 
-	c_long SSL_get_verify_result(const SSL* ssl) {
-		if(ossllib.SSL_get_verify_result)
-			return ossllib.SSL_get_verify_result(ssl);
-		else throw new Exception("SSL_get_verify_result not loaded");
-	}
+				if(ptr is null)
+					throw new Exception(name ~ " not loaded");
+				return ptr(t);
+			}
+		}
 
-	X509_VERIFY_PARAM* SSL_get0_param(const SSL* ssl) {
-		if(ossllib.SSL_get0_param)
-			return ossllib.SSL_get0_param(ssl);
-		else throw new Exception("SSL_get0_param not loaded");
-	}
+		// macros in the original C
+		SSL_METHOD* SSLv23_client_method() {
+			if(ossllib.SSLv23_client_method)
+				return ossllib.SSLv23_client_method();
+			else
+				return ossllib.TLS_client_method();
+		}
 
-	X509_STORE* SSL_CTX_get_cert_store(SSL_CTX* a) {
-		if(ossllib.SSL_CTX_get_cert_store)
-			return ossllib.SSL_CTX_get_cert_store(a);
-		else throw new Exception("SSL_CTX_get_cert_store not loaded");
-	}
+		void SSL_set_tlsext_host_name(SSL* a, const char* b) {
+			if(ossllib.SSL_ctrl)
+				return ossllib.SSL_ctrl(a, 55 /*SSL_CTRL_SET_TLSEXT_HOSTNAME*/, 0 /*TLSEXT_NAMETYPE_host_name*/, cast(void*) b);
+			else throw new Exception("SSL_set_tlsext_host_name not loaded");
+		}
 
-	SSL_CTX* SSL_CTX_new(const SSL_METHOD* a) {
-		if(ossllib.SSL_CTX_new)
-			return ossllib.SSL_CTX_new(a);
-		else throw new Exception("SSL_CTX_new not loaded");
-	}
-	SSL* SSL_new(SSL_CTX* a) {
-		if(ossllib.SSL_new)
-			return ossllib.SSL_new(a);
-		else throw new Exception("SSL_new not loaded");
-	}
-	int SSL_set_fd(SSL* a, int b) {
-		if(ossllib.SSL_set_fd)
-			return ossllib.SSL_set_fd(a, b);
-		else throw new Exception("SSL_set_fd not loaded");
-	}
+		// special case
+		@trusted nothrow @nogc int SSL_shutdown(SSL* a) {
+			if(ossllib.SSL_shutdown)
+				return ossllib.SSL_shutdown(a);
+			assert(0);
+		}
 
-	extern(C)
-	alias client_cert_cb = int function(SSL *ssl, X509 **x509, EVP_PKEY **pkey);
+		void SSL_CTX_keylog_cb_func(SSL_CTX* ctx, keylog_cb func) {
+			// this isn't in openssl 1.0 and is non-essential, so it is allowed to fail.
+			if(ossllib.SSL_CTX_set_keylog_callback)
+				ossllib.SSL_CTX_set_keylog_callback(ctx, func);
+			//else throw new Exception("SSL_CTX_keylog_cb_func not loaded");
+		}
 
-	void SSL_CTX_set_client_cert_cb(SSL_CTX *ctx, client_cert_cb cb) {
-		if(ossllib.SSL_CTX_set_client_cert_cb)
-			return ossllib.SSL_CTX_set_client_cert_cb(ctx, cb);
-		else throw new Exception("SSL_CTX_set_client_cert_cb not loaded");
-	}
-
-	X509* d2i_X509(X509** a, const(ubyte*)* pp, c_long length) {
-		if(eallib.d2i_X509)
-			return eallib.d2i_X509(a, pp, length);
-		else throw new Exception("d2i_X509 not loaded");
-	}
-
-	X509* PEM_read_X509(FILE *fp, X509 **x, pem_password_cb *cb, void *u) {
-		if(eallib.PEM_read_X509)
-			return eallib.PEM_read_X509(fp, x, cb, u);
-		else throw new Exception("PEM_read_X509 not loaded");
-	}
-	EVP_PKEY* PEM_read_PrivateKey(FILE *fp, EVP_PKEY **x, pem_password_cb *cb, void *u) {
-		if(eallib.PEM_read_PrivateKey)
-			return eallib.PEM_read_PrivateKey(fp, x, cb, u);
-		else throw new Exception("PEM_read_PrivateKey not loaded");
-	}
-
-	EVP_PKEY* d2i_PrivateKey_fp(FILE *fp, EVP_PKEY **a) {
-		if(eallib.d2i_PrivateKey_fp)
-			return eallib.d2i_PrivateKey_fp(fp, a);
-		else throw new Exception("d2i_PrivateKey_fp not loaded");
-	}
-	X509* d2i_X509_fp(FILE *fp, X509 **x) {
-		if(eallib.d2i_X509_fp)
-			return eallib.d2i_X509_fp(fp, x);
-		else throw new Exception("d2i_X509_fp not loaded");
-	}
-
-	int SSL_connect(SSL* a) {
-		if(ossllib.SSL_connect)
-			return ossllib.SSL_connect(a);
-		else throw new Exception("SSL_connect not loaded");
-	}
-	int SSL_write(SSL* a, const void* b, int c) {
-		if(ossllib.SSL_write)
-			return ossllib.SSL_write(a, b, c);
-		else throw new Exception("SSL_write not loaded");
-	}
-	int SSL_read(SSL* a, void* b, int c) {
-		if(ossllib.SSL_read)
-			return ossllib.SSL_read(a, b, c);
-		else throw new Exception("SSL_read not loaded");
-	}
-	@trusted nothrow @nogc int SSL_shutdown(SSL* a) {
-		if(ossllib.SSL_shutdown)
-			return ossllib.SSL_shutdown(a);
-		assert(0);
-	}
-	void SSL_free(SSL* a) {
-		if(ossllib.SSL_free)
-			return ossllib.SSL_free(a);
-		else throw new Exception("SSL_free not loaded");
-	}
-	void SSL_CTX_free(SSL_CTX* a) {
-		if(ossllib.SSL_CTX_free)
-			return ossllib.SSL_CTX_free(a);
-		else throw new Exception("SSL_CTX_free not loaded");
-	}
-
-	int SSL_pending(const SSL* a) {
-		if(ossllib.SSL_pending)
-			return ossllib.SSL_pending(a);
-		else throw new Exception("SSL_pending not loaded");
-	}
-	void SSL_set_verify(SSL* a, int b, void* c) {
-		if(ossllib.SSL_set_verify)
-			return ossllib.SSL_set_verify(a, b, c);
-		else throw new Exception("SSL_set_verify not loaded");
-	}
-	void SSL_set_tlsext_host_name(SSL* a, const char* b) {
-		if(ossllib.SSL_ctrl)
-			return ossllib.SSL_ctrl(a, 55 /*SSL_CTRL_SET_TLSEXT_HOSTNAME*/, 0 /*TLSEXT_NAMETYPE_host_name*/, cast(void*) b);
-		else throw new Exception("SSL_set_tlsext_host_name not loaded");
-	}
-	int X509_VERIFY_PARAM_set1_host(X509_VERIFY_PARAM* a, const char* b, size_t l) {
-		if(eallib.X509_VERIFY_PARAM_set1_host)
-			return eallib.X509_VERIFY_PARAM_set1_host(a, b, l);
-		else throw new Exception("X509_VERIFY_PARAM_set1_host not loaded");
-	}
-	SSL_METHOD* SSLv3_client_method() {
-		if(ossllib.SSLv3_client_method)
-			return ossllib.SSLv3_client_method();
-		else throw new Exception("SSLv3_client_method not loaded");
-	}
-	SSL_METHOD* TLS_client_method() {
-		if(ossllib.TLS_client_method)
-			return ossllib.TLS_client_method();
-		else throw new Exception("TLS_client_method not loaded");
-	}
-	void ERR_print_errors_cb(print_errors_cb cb, void* u) {
-		if(eallib.ERR_print_errors_cb)
-			return eallib.ERR_print_errors_cb(cb, u);
-		else throw new Exception("ERR_print_errors_cb not loaded");
-	}
-	void X509_free(X509* x) {
-		if(eallib.X509_free)
-			return eallib.X509_free(x);
-		else throw new Exception("X509_free not loaded");
-	}
-	int X509_STORE_add_cert(X509_STORE* s, X509* x) {
-		if(eallib.X509_STORE_add_cert)
-			return eallib.X509_STORE_add_cert(s, x);
-		else throw new Exception("X509_STORE_add_cert not loaded");
 	}
 
 	extern(C)
@@ -3339,15 +3224,6 @@ version(use_openssl) {
 		return 0;
 	}
 
-	extern(C)
-	void SSL_CTX_keylog_cb_func(SSL_CTX* ctx, void function(SSL*, char*) func)
-	{
-		// this isn't in openssl 1.0 and is non-essential, so it is allowed to fail.
-		if(ossllib.SSL_CTX_set_keylog_callback)
-			ossllib.SSL_CTX_set_keylog_callback(ctx, func);
-		//else throw new Exception("SSL_CTX_keylog_cb_func not loaded");
-	}
-
 
 	private __gshared void* ossllib_handle;
 	version(Windows)
@@ -3475,7 +3351,7 @@ version(use_openssl) {
 		string logfile = environment.get("SSLKEYLOGFILE");
 		if (logfile !is null)
 		{
-			auto f = std.stdio.File("/tmp/keyfile", "a+");
+			auto f = std.stdio.File(logfile, "a+");
 			f.writeln(fromStringz(line));
 			f.close();
 		}
@@ -3485,31 +3361,31 @@ version(use_openssl) {
 		private SSL* ssl;
 		private SSL_CTX* ctx;
 		private void initSsl(bool verifyPeer, string hostname) {
-			ctx = SSL_CTX_new(SSLv23_client_method());
+			ctx = OpenSSL.SSL_CTX_new(OpenSSL.SSLv23_client_method());
 			assert(ctx !is null);
 
-			SSL_CTX_set_default_verify_paths(ctx);
+			OpenSSL.SSL_CTX_set_default_verify_paths(ctx);
 			version(Windows)
 				loadCertificatesFromRegistry(ctx);
 
-			debug SSL_CTX_keylog_cb_func(ctx, &write_to_file);
-			ssl = SSL_new(ctx);
+			debug OpenSSL.SSL_CTX_keylog_cb_func(ctx, &write_to_file);
+			ssl = OpenSSL.SSL_new(ctx);
 
 			if(hostname.length) {
-				SSL_set_tlsext_host_name(ssl, toStringz(hostname));
+				OpenSSL.SSL_set_tlsext_host_name(ssl, toStringz(hostname));
 				if(verifyPeer)
-					X509_VERIFY_PARAM_set1_host(SSL_get0_param(ssl), hostname.ptr, hostname.length);
+					OpenSSL.X509_VERIFY_PARAM_set1_host(OpenSSL.SSL_get0_param(ssl), hostname.ptr, hostname.length);
 			}
 
 			if(verifyPeer)
-				SSL_set_verify(ssl, SSL_VERIFY_PEER, null);
+				OpenSSL.SSL_set_verify(ssl, SSL_VERIFY_PEER, null);
 			else
-				SSL_set_verify(ssl, SSL_VERIFY_NONE, null);
+				OpenSSL.SSL_set_verify(ssl, SSL_VERIFY_NONE, null);
 
-			SSL_set_fd(ssl, cast(int) this.handle); // on win64 it is necessary to truncate, but the value is never large anyway see http://openssl.6102.n7.nabble.com/Sockets-windows-64-bit-td36169.html
+			OpenSSL.SSL_set_fd(ssl, cast(int) this.handle); // on win64 it is necessary to truncate, but the value is never large anyway see http://openssl.6102.n7.nabble.com/Sockets-windows-64-bit-td36169.html
 
 
-			SSL_CTX_set_client_cert_cb(ctx, &cb);
+			OpenSSL.SSL_CTX_set_client_cert_cb(ctx, &cb);
 		}
 
 		extern(C)
@@ -3534,12 +3410,12 @@ version(use_openssl) {
 						else
 							goto case der;
 					case pem:
-						*x509 = PEM_read_X509(fpCert, null, null, null);
-						*pkey = PEM_read_PrivateKey(fpKey, null, null, null);
+						*x509 = OpenSSL.PEM_read_X509(fpCert, null, null, null);
+						*pkey = OpenSSL.PEM_read_PrivateKey(fpKey, null, null, null);
 					break;
 					case der:
-						*x509 = d2i_X509_fp(fpCert, null);
-						*pkey = d2i_PrivateKey_fp(fpKey, null);
+						*x509 = OpenSSL.d2i_X509_fp(fpCert, null);
+						*pkey = OpenSSL.d2i_PrivateKey_fp(fpKey, null);
 					break;
 				}
 
@@ -3550,7 +3426,7 @@ version(use_openssl) {
 		}
 
 		bool dataPending() {
-			return SSL_pending(ssl) > 0;
+			return OpenSSL.SSL_pending(ssl) > 0;
 		}
 
 		@trusted
@@ -3561,11 +3437,11 @@ version(use_openssl) {
 
 		@trusted
 		void do_ssl_connect() {
-			if(SSL_connect(ssl) == -1) {
+			if(OpenSSL.SSL_connect(ssl) == -1) {
 				string str;
-				ERR_print_errors_cb(&collectSslErrors, &str);
+				OpenSSL.ERR_print_errors_cb(&collectSslErrors, &str);
 				int i;
-				auto err = SSL_get_verify_result(ssl);
+				auto err = OpenSSL.SSL_get_verify_result(ssl);
 				//printf("wtf\n");
 				//scanf("%d\n", i);
 				throw new Exception("Secure connect failed: " ~ getOpenSslErrorCode(err));
@@ -3576,10 +3452,10 @@ version(use_openssl) {
 		override ptrdiff_t send(scope const(void)[] buf, SocketFlags flags) {
 		//import std.stdio;writeln(cast(string) buf);
 			debug(arsd_http2_verbose) writeln("ssl writing ", buf.length);
-			auto retval = SSL_write(ssl, buf.ptr, cast(uint) buf.length);
+			auto retval = OpenSSL.SSL_write(ssl, buf.ptr, cast(uint) buf.length);
 			if(retval == -1) {
 				string str;
-				ERR_print_errors_cb(&collectSslErrors, &str);
+				OpenSSL.ERR_print_errors_cb(&collectSslErrors, &str);
 				int i;
 				//printf("wtf\n");
 				//scanf("%d\n", i);
@@ -3595,11 +3471,11 @@ version(use_openssl) {
 		override ptrdiff_t receive(scope void[] buf, SocketFlags flags) {
 
 			debug(arsd_http2_verbose) writeln("ssl_read before");
-			auto retval = SSL_read(ssl, buf.ptr, cast(int)buf.length);
+			auto retval = OpenSSL.SSL_read(ssl, buf.ptr, cast(int)buf.length);
 			debug(arsd_http2_verbose) writeln("ssl_read after");
 			if(retval == -1) {
 				string str;
-				ERR_print_errors_cb(&collectSslErrors, &str);
+				OpenSSL.ERR_print_errors_cb(&collectSslErrors, &str);
 				int i;
 				//printf("wtf\n");
 				//scanf("%d\n", i);
@@ -3617,7 +3493,7 @@ version(use_openssl) {
 		}
 
 		override void close() {
-			if(ssl) SSL_shutdown(ssl);
+			if(ssl) OpenSSL.SSL_shutdown(ssl);
 			super.close();
 		}
 
@@ -3629,8 +3505,8 @@ version(use_openssl) {
 		void freeSsl() {
 			if(ssl is null)
 				return;
-			SSL_free(ssl);
-			SSL_CTX_free(ctx);
+			OpenSSL.SSL_free(ssl);
+			OpenSSL.SSL_CTX_free(ctx);
 			ssl = null;
 		}
 
@@ -5136,7 +5012,7 @@ version(Windows) {
 		scope(exit)
 			CertCloseStore(store, 0);
 
-		X509_STORE* ssl_store = SSL_CTX_get_cert_store(ctx);
+		X509_STORE* ssl_store = OpenSSL.SSL_CTX_get_cert_store(ctx);
 		PCCERT_CONTEXT c;
 		while((c = CertEnumCertificatesInStore(store, c)) !is null) {
 			FILETIME na = c.pCertInfo.NotAfter;
@@ -5157,10 +5033,14 @@ version(Windows) {
 			}
 
 			const(ubyte)* thing = c.pbCertEncoded;
-			auto x509 = d2i_X509(null, &thing, c.cbCertEncoded);
+			auto x509 = OpenSSL.d2i_X509(null, &thing, c.cbCertEncoded);
 			if (x509) {
-				auto success = X509_STORE_add_cert(ssl_store, x509);
-				X509_free(x509);
+				auto success = OpenSSL.X509_STORE_add_cert(ssl_store, x509);
+				//if(!success)
+					//writeln("FAILED HERE");
+				OpenSSL.X509_free(x509);
+			} else {
+				//writeln("FAILED");
 			}
 		}