This commit is contained in:
Adam D. Ruppe 2015-11-22 14:19:19 -05:00
parent 09fbeea89a
commit 048a3b39d7
4 changed files with 138 additions and 126 deletions

View File

@ -3,7 +3,23 @@ module arsd.database;
public import std.variant;
import std.string;
import core.vararg;
/*
Database 2.0 plan, WIP:
// Do I want to do some kind of RAII?
auto database = Database(new MySql("connection info"));
* Prepared statement support
* Queries with separate args whenever we can with consistent interface
* Query returns some typed info when we can.
* ....?
PreparedStatement prepareStatement(string sql);
Might be worth looking at doing the preparations in static ctors
so they are always done once per program...
*/
interface Database {
/// Actually implements the query for the database. The query() method
@ -16,11 +32,7 @@ interface Database {
/// query to start a transaction, only here because sqlite is apparently different in syntax...
void startTransaction();
// FIXME: this would be better as a template, but can't because it is an interface
/// Just executes a query. It supports placeholders for parameters
/// by using ? in the sql string. NOTE: it only accepts string, int, long, and null types.
/// Others will fail runtime asserts.
final ResultSet query(T...)(string sql, T t) {
Variant[] args;
foreach(arg; t) {
@ -33,37 +45,12 @@ interface Database {
}
return queryImpl(sql, args);
}
version(none)
final ResultSet query(string sql, ...) {
Variant[] args;
foreach(arg; _arguments) {
string a;
if(arg == typeid(string) || arg == typeid(immutable(string)) || arg == typeid(const(string)))
a = va_arg!string(_argptr);
else if (arg == typeid(int) || arg == typeid(immutable(int)) || arg == typeid(const(int))) {
auto e = va_arg!int(_argptr);
a = to!string(e);
} else if (arg == typeid(uint) || arg == typeid(immutable(uint)) || arg == typeid(const(uint))) {
auto e = va_arg!uint(_argptr);
a = to!string(e);
} else if (arg == typeid(immutable(char))) {
auto e = va_arg!char(_argptr);
a = to!string(e);
} else if (arg == typeid(long) || arg == typeid(const(long)) || arg == typeid(immutable(long))) {
auto e = va_arg!long(_argptr);
a = to!string(e);
} else if (arg == typeid(ulong) || arg == typeid(const(ulong)) || arg == typeid(immutable(ulong))) {
auto e = va_arg!ulong(_argptr);
a = to!string(e);
} else if (arg == typeid(null)) {
a = null;
} else assert(0, "invalid type " ~ arg.toString() );
args ~= Variant(a);
}
/// Prepared statement api
/*
PreparedStatement prepareStatement(string sql, int numberOfArguments);
return queryImpl(sql, args);
}
*/
}
import std.stdio;
@ -710,49 +697,6 @@ class DataObject {
}
// vararg hack so property assignment works right, even with null
version(none)
string opDispatch(string field, string file = __FILE__, size_t line = __LINE__)(...)
if((field.length < 8 || field[0..8] != "id_from_") && field != "popFront")
{
if(_arguments.length == 0) {
if(field !in fields)
throw new Exception("no such field " ~ field, file, line);
return fields[field];
} else if(_arguments.length == 1) {
auto arg = _arguments[0];
string a;
if(arg == typeid(string) || arg == typeid(immutable(string)) || arg == typeid(const(immutable(char)[]))) {
a = va_arg!(string)(_argptr);
} else if (arg == typeid(int) || arg == typeid(immutable(int)) || arg == typeid(const(int))) {
auto e = va_arg!(int)(_argptr);
a = to!string(e);
} else if (arg == typeid(char) || arg == typeid(immutable(char))) {
auto e = va_arg!(char)(_argptr);
a = to!string(e);
} else if (arg == typeid(uint) || arg == typeid(immutable(uint)) || arg == typeid(const(uint))) {
auto e = va_arg!uint(_argptr);
a = to!string(e);
} else if (arg == typeid(long) || arg == typeid(const(long)) || arg == typeid(immutable(long))) {
auto e = va_arg!(long)(_argptr);
a = to!string(e);
} else if (arg == typeid(null)) {
a = null;
} else assert(0, "invalid type " ~ arg.toString );
auto setTo = a;
setImpl(field, setTo);
return setTo;
} else assert(0, "too many arguments");
assert(0); // should never be reached
}
private void setImpl(string field, string value) {
if(field in fields) {
if(fields[field] != value)

View File

@ -665,7 +665,7 @@ class HttpRequest {
case 2: // reading data
auto can = a + bodyReadingState.contentLengthRemaining;
if(can > data.length)
can = data.length;
can = cast(int) data.length;
//if(bodyReadingState.isGzipped || bodyReadingState.isDeflated)
// responseData.content ~= cast(ubyte[]) uncompress.uncompress(data[a .. can]);

View File

@ -7,6 +7,11 @@ import std.string;
import std.exception;
// remember to CREATE DATABASE name WITH ENCODING 'utf8'
//
// http://www.postgresql.org/docs/8.0/static/libpq-exec.html
// ExecParams, PQPrepare, PQExecPrepared
//
// SQL: `DEALLOCATE name` is how to dealloc a prepared statement.
class PostgreSql : Database {
// dbname = name is probably the most common connection string
@ -23,6 +28,36 @@ class PostgreSql : Database {
PQfinish(conn);
}
/*
Prepared statement support
This will be added to the Database interface eventually in some form,
but first I need to implement it for all my providers.
The common function of those 4 will be what I put in the interface.
*/
ResultSet executePreparedStatement(T...)(string name, T args) {
char*[args.length] argsStrings;
foreach(idx, arg; args) {
// FIXME: optimize to remove allocations here
static if(!is(typeof(arg) == typeof(null)))
argsStrings[idx] = toStringz(to!string(arg));
// else make it null
}
auto res = PQexecPrepared(conn, toStringz(name), argsStrings.length, argStrings.ptr, 0, null, 0);
int ress = PQresultStatus(res);
if(ress != PGRES_TUPLES_OK
&& ress != PGRES_COMMAND_OK)
throw new DatabaseException(error());
return new PostgresResult(res);
}
override void startTransaction() {
query("START TRANSACTION");
}
@ -183,6 +218,10 @@ extern(C) {
PGresult* PQexec(PGconn*, const char*);
void PQclear(PGresult*);
PGresult* PQprepare(PGconn*, const char* stmtName, const char* query, int nParams, const void* paramTypes);
PGresult* PQexecPrepared(PGconn*, const char* stmtName, int nParams, const char** paramValues, const int* paramLengths, const int* paramFormats, int resultFormat);
int PQresultStatus(PGresult*); // FIXME check return value
int PQnfields(PGresult*); // number of fields in a result

View File

@ -21,69 +21,98 @@ public import std.socket;
// see also:
// http://msdn.microsoft.com/en-us/library/aa380536%28v=vs.85%29.aspx
import deimos.openssl.ssl;
// import deimos.openssl.ssl;
static this() {
SSL_library_init();
OpenSSL_add_all_algorithms();
SSL_load_error_strings();
}
version=use_openssl;
pragma(lib, "crypto");
pragma(lib, "ssl");
version(use_openssl) {
alias SslClientSocket = OpenSslSocket;
class OpenSslSocket : Socket {
private SSL* ssl;
private SSL_CTX* ctx;
private void initSsl() {
ctx = SSL_CTX_new(SSLv3_client_method());
assert(ctx !is null);
extern(C) {
int SSL_library_init();
void OpenSSL_add_all_ciphers();
void OpenSSL_add_all_digests();
void SSL_load_error_strings();
ssl = SSL_new(ctx);
SSL_set_fd(ssl, this.handle);
struct SSL {}
struct SSL_CTX {}
struct SSL_METHOD {}
SSL_CTX* SSL_CTX_new(const SSL_METHOD* method);
SSL* SSL_new(SSL_CTX*);
int SSL_set_fd(SSL*, int);
int SSL_connect(SSL*);
int SSL_write(SSL*, const void*, int);
int SSL_read(SSL*, void*, int);
void SSL_free(SSL*);
void SSL_CTX_free(SSL_CTX*);
SSL_METHOD* SSLv3_client_method();
}
@trusted
override void connect(Address to) {
super.connect(to);
if(SSL_connect(ssl) == -1)
throw new Exception("ssl connect");
}
@trusted
override ptrdiff_t send(const(void)[] buf, SocketFlags flags) {
return SSL_write(ssl, buf.ptr, cast(uint) buf.length);
}
override ptrdiff_t send(const(void)[] buf) {
return send(buf, SocketFlags.NONE);
}
@trusted
override ptrdiff_t receive(void[] buf, SocketFlags flags) {
return SSL_read(ssl, buf.ptr, cast(int)buf.length);
}
override ptrdiff_t receive(void[] buf) {
return receive(buf, SocketFlags.NONE);
shared static this() {
SSL_library_init();
OpenSSL_add_all_ciphers();
OpenSSL_add_all_digests();
SSL_load_error_strings();
}
this(AddressFamily af) {
super(af, SocketType.STREAM);
initSsl();
}
pragma(lib, "crypto");
pragma(lib, "ssl");
this(socket_t sock, AddressFamily af) {
super(sock, af);
initSsl();
}
class OpenSslSocket : Socket {
private SSL* ssl;
private SSL_CTX* ctx;
private void initSsl() {
ctx = SSL_CTX_new(SSLv3_client_method());
assert(ctx !is null);
~this() {
SSL_free(ssl);
SSL_CTX_free(ctx);
ssl = SSL_new(ctx);
SSL_set_fd(ssl, this.handle);
}
@trusted
override void connect(Address to) {
super.connect(to);
if(SSL_connect(ssl) == -1)
throw new Exception("ssl connect");
}
@trusted
override ptrdiff_t send(const(void)[] buf, SocketFlags flags) {
return SSL_write(ssl, buf.ptr, cast(uint) buf.length);
}
override ptrdiff_t send(const(void)[] buf) {
return send(buf, SocketFlags.NONE);
}
@trusted
override ptrdiff_t receive(void[] buf, SocketFlags flags) {
return SSL_read(ssl, buf.ptr, cast(int)buf.length);
}
override ptrdiff_t receive(void[] buf) {
return receive(buf, SocketFlags.NONE);
}
this(AddressFamily af, SocketType type = SocketType.STREAM) {
super(af, type);
initSsl();
}
this(socket_t sock, AddressFamily af) {
super(sock, af);
initSsl();
}
~this() {
SSL_free(ssl);
SSL_CTX_free(ctx);
}
}
}
version(ssl_test)
void main() {
auto sock = new OpenSslSocket(AddressFamily.INET);
auto sock = new SslClientSocket(AddressFamily.INET);
sock.connect(new InternetAddress("localhost", 443));
sock.send("GET / HTTP/1.0\r\n\r\n");
import std.stdio;