diff --git a/rpc.d b/rpc.d new file mode 100644 index 0000000..83a9f3d --- /dev/null +++ b/rpc.d @@ -0,0 +1,445 @@ +module arsd.rpc; + +/+ //example usage +interface ExampleNetworkFunctions { + string sayHello(string name); + int add(int a, int b); + S2 structTest(S1); + void die(); +} + +// the server must implement the interface +class ExampleServer : ExampleNetworkFunctions { + override string sayHello(string name) { + return "Hello, " ~ name; + } + + override int add(int a, int b) { + return a+b; + } + + override S2 structTest(S1 a) { + return S2(a.name, a.number); + } + + override void die() { + throw new Exception("death requested"); + } + + mixin NetworkServer!ExampleNetworkFunctions; +} + +struct S1 { + int number; + string name; +} + +struct S2 { + string name; + int number; +} + +import std.stdio; +void main(string[] args) { + if(args.length > 1) { + auto client = makeNetworkClient!ExampleNetworkFunctions("localhost", 5005); + // these work like the interface above, but instead of returning the value, + // they take callbacks for success (where the arg is the retval) + // and failure (the arg is the exception) + client.sayHello("whoa", (a) { writeln(a); }, null); + client.add(1,2, (a) { writeln(a); }, null); + client.add(10,20, (a) { writeln(a); }, null); + client.structTest(S1(20, "cool!"), (a) { writeln(a.name, " -- ", a.number); }, null); + client.die(delegate () { writeln("shouldn't happen"); }, delegate(a) { writeln(a); }); + client.eventLoop(); + } else { + auto server = new ExampleServer(5005); + server.eventLoop(); + } +} ++/ + +mixin template NetworkServer(Interface) { + import std.socket; + private Socket socket; + public this(ushort port) { + socket = new TcpSocket(); + socket.setOption(SocketOptionLevel.SOCKET, SocketOption.REUSEADDR, true); + socket.bind(new InternetAddress(port)); + socket.listen(16); + } + + final public void eventLoop() { + auto check = new SocketSet(); + Socket[] connections; + connections.reserve(16); + ubyte[4096] buffer; + + while(true) { + check.reset(); + check.add(socket); + foreach(connection; connections) { + check.add(connection); + } + + if(Socket.select(check, null, null)) { + if(check.isSet(socket)) { + connections ~= socket.accept(); + } + + foreach(connection; connections) { + if(check.isSet(connection)) { + auto gotNum = connection.receive(buffer); + if(gotNum == 0) { + // connection is closed, we could remove it from the list + } else { + auto got = buffer[0 .. gotNum]; + another: + int length, functionNumber, sequenceNumber; + got = deserializeInto(got, length); + got = deserializeInto(got, functionNumber); + got = deserializeInto(got, sequenceNumber); + + //writeln("got ", sequenceNumber, " calling ", functionNumber); + + auto remaining = got[length .. $]; + got = got[0 .. length]; + import std.conv; + assert(length == got.length, to!string(length) ~ " != " ~ to!string(got.length)); // FIXME: what if it doesn't all come at once? + callByNumber(functionNumber, sequenceNumber, got, connection); + + if(remaining.length) { + got = remaining; + goto another; + } + } + } + } + } + } + } + + final private void callByNumber(int functionNumber, int sequenceNumber, const(ubyte)[] buffer, Socket connection) { + ubyte[4096] sendBuffer; + int length = 12; + // length, sequence, success + serialize(sendBuffer[4 .. 8], sequenceNumber); + string callCode() { + import std.conv; + import std.traits; + string code; + foreach(memIdx, member; __traits(allMembers, Interface)) { + code ~= "\t\tcase " ~ to!string(memIdx + 1) ~ ":\n"; + alias mem = PassThrough!(__traits(getMember, Interface, member)); + // we need to deserialize the arguments, call the function, and send back the response (if there is one) + string argsString; + foreach(i, arg; ParameterTypeTuple!mem) { + if(i) + argsString ~= ", "; + auto istr = to!string(i); + code ~= "\t\t\t" ~ arg.stringof ~ " arg" ~ istr ~ ";\n"; + code ~= "\t\t\tbuffer = deserializeInto(buffer, arg" ~ istr ~ ");\n"; + + argsString ~= "arg" ~ istr; + } + + // the call + static if(is(ReturnType!mem == void)) { + code ~= "\n\t\t\t" ~ member ~ "(" ~ argsString ~ ");\n"; + } else { + // call and return answer + code ~= "\n\t\t\tauto ret = " ~ member ~ "(" ~ argsString ~ ");\n"; + + code ~= "\t\t\tserialize(sendBuffer[8 .. 12], cast(int) 1);\n"; // yes success + code ~= "\t\t\tauto serialized = serialize(sendBuffer[12 .. $], ret);\n"; + code ~= "\t\t\tserialize(sendBuffer[0 .. 4], cast(int) serialized.length);\n"; + code ~= "\t\t\tlength += serialized.length;\n"; + } + code ~= "\t\tbreak;\n"; + } + return code; + } + + try { + switch(functionNumber) { + default: assert(0, "unknown function"); + //pragma(msg, callCode()); + mixin(callCode()); + } + } catch(Throwable t) { + //writeln("thrown: ", t); + serialize(sendBuffer[8 .. 12], cast(int) 0); // no success + + auto place = sendBuffer[12 .. $]; + int l; + auto s = serialize(place, t.msg); + place = place[s.length .. $]; + l += s.length; + s = serialize(place, t.file); + place = place[s.length .. $]; + l += s.length; + s = serialize(place, t.line); + place = place[s.length .. $]; + l += s.length; + + serialize(sendBuffer[0 .. 4], l); + length += l; + } + + if(length != 12) // if there is a response... + connection.send(sendBuffer[0 .. length]); + } +} + +template PassThrough(alias a) { + alias PassThrough = a; +} + +// general FIXME: what if we run out of buffer space? + +// returns the part of the buffer that was actually used +final public ubyte[] serialize(T)(ubyte[] buffer, in T s) { + auto original = buffer; + size_t totalLength = 0; + import std.traits; + static if(isArray!T) { + /* length */ { + auto used = serialize(buffer, cast(int) s.length); + totalLength += used.length; + buffer = buffer[used.length .. $]; + } + foreach(i; s) { + auto used = serialize(buffer, i); + totalLength += used.length; + buffer = buffer[used.length .. $]; + } + } else static if(isPointer!T) { + static assert(0, "no pointers allowed"); + } else static if(!hasIndirections!T) { + // covers int, float, char, etc. most the builtins + import std.string; + assert(buffer.length >= T.sizeof, format("%s won't fit in %s buffer", T.stringof, buffer.length)); + buffer[0 .. T.sizeof] = (cast(ubyte*)&s)[0 .. T.sizeof]; + totalLength += T.sizeof; + buffer = buffer[T.sizeof .. $]; + } else { + // structs, classes, etc. + foreach(i, t; s.tupleof) { + auto used = serialize(buffer, t); + totalLength += used.length; + buffer = buffer[used.length .. $]; + } + } + + return original[0 .. totalLength]; +} + +// returns the remaining part of the buffer +final public inout(ubyte)[] deserializeInto(T)(inout(ubyte)[] buffer, ref T s) { + import std.traits; + + static if(isArray!T) { + size_t length; + buffer = deserializeInto(buffer, length); + s.length = length; + foreach(i; 0 .. length) + buffer = deserializeInto(buffer, s[i]); + } else static if(isPointer!T) { + static assert(0, "no pointers allowed"); + } else static if(!hasIndirections!T) { + // covers int, float, char, etc. most the builtins + (cast(ubyte*)(&s))[0 .. T.sizeof] = buffer[0 .. T.sizeof]; + buffer = buffer[T.sizeof .. $]; + } else { + // structs, classes, etc. + foreach(i, t; s.tupleof) { + buffer = deserializeInto(buffer, s.tupleof[i]); + } + } + + return buffer; +} + +auto makeNetworkClient(Interface)(string serverHost, ushort serverPort) { + static string createClass() { + // this doesn't actually inherit from the interface because + // the return value needs to be handled async + string code = `final class Class /*: ` ~ Interface.stringof ~ `*/ {`; + code ~= "\n\timport std.socket;"; + code ~= "\n\tprivate Socket socket;"; + code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onSuccesses;"; + code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onErrors;"; + code ~= "\n\tprivate uint lastSequenceNumber;"; + code ~= q{ + private this(string host, ushort port) { + this.socket = new TcpSocket(); + this.socket.connect(new InternetAddress(host, port)); + } + + final public void eventLoop() { + ubyte[4096] buffer; + bool open = true; + + do { + auto gotNum = socket.receive(buffer); + if(gotNum == 0) { + open = false; + break; + } + while(gotNum < 9) { + auto g2 = socket.receive(buffer[gotNum .. $]); + if(g2 == 0) { + open = false; + break; + } + gotNum += g2; + } + + auto got = buffer[0 .. gotNum]; + another: + uint length, seq; + uint success; + got = deserializeInto(got, length); + got = deserializeInto(got, seq); + got = deserializeInto(got, success); + auto more = got[length .. $]; + + if(got.length >= length) { + if(success) { + auto s = (seq in onSuccesses); + if(s !is null && *s !is null) + (*s)(got); + } else { + auto s = (seq in onErrors); + if(s !is null && *s !is null) + (*s)(got); + } + } + + if(more.length) { + got = more; + goto another; + } + } while(open); + } + }; + code ~= "\n\tpublic:\n"; + + foreach(memIdx, member; __traits(allMembers, Interface)) { + alias mem = PassThrough!(__traits(getMember, Interface, member)); + code ~= "\t\tfinal void " ~ member ~ "("; + bool hadArgument = false; + import std.traits; + import std.conv; + // arguments + foreach(i, arg; ParameterTypeTuple!mem) { + if(hadArgument) + code ~= ", "; + // FIXME: this is one place the arg can get unknown if we don't have all the imports + code ~= arg.stringof ~ " arg" ~ to!string(i); + hadArgument = true; + } + + if(hadArgument) + code ~= ", "; + + static if(is(ReturnType!mem == void)) + code ~= "void delegate() onSuccess"; + else + code ~= "void delegate("~(ReturnType!mem).stringof~") onSuccess"; + code ~= ", "; + code ~= "void delegate(Throwable) onError"; + code ~= ") {\n"; + code ~= q{ + #line 252 + auto seq = ++lastSequenceNumber; + onSuccesses[seq] = (const(ubyte)[] buffer) { + onSuccesses.remove(seq); + onErrors.remove(seq); + + import std.traits; + + static if(is(ParameterTypeTuple!(typeof(onSuccess)) == void)) { + if(onSuccess !is null) + onSuccess(); + } else { + ParameterTypeTuple!(typeof(onSuccess)) args; + foreach(i, arg; args) + buffer = deserializeInto(buffer, args[i]); + if(onSuccess !is null) + onSuccess(args); + } + }; + onErrors[seq] = (const(ubyte)[] buffer) { + onSuccesses.remove(seq); + onErrors.remove(seq); + auto t = new Throwable(""); + buffer = deserializeInto(buffer, t.msg); + buffer = deserializeInto(buffer, t.file); + buffer = deserializeInto(buffer, t.line); + + if(onError !is null) + onError(t); + }; + + #line 283 + ubyte[4096] bufferBase; + auto buffer = bufferBase[12 .. $]; // leaving room for size, func number, and seq number + ubyte[] serialized; + int used; + }; + // preparing the request + foreach(i, arg; ParameterTypeTuple!mem) { + code ~= "\t\t\tserialized = serialize(buffer, arg" ~ to!string(i) ~ ");\n"; + code ~= "\t\t\tused += serialized.length;\n"; + code ~= "\t\t\tbuffer = buffer[serialized.length .. $];\n"; + } + + code ~= "\t\t\tserialize(bufferBase[0 .. 4], used);\n"; + code ~= "\t\t\tserialize(bufferBase[4 .. 8], " ~ to!string(memIdx + 1) ~ ");\n"; + code ~= "\t\t\tserialize(bufferBase[8 .. 12], seq);\n"; + + // FIXME: what if it doesn't all send at once? + code ~= "\t\t\tsocket.send(bufferBase[0 .. 12 + used]);\n"; + //code ~= `writeln("sending ", bufferBase[0 .. 12 + used]);`; + code ~= "}\n"; + code ~= "\n"; + } + + code ~= `}`; + return code; + } + + //pragma(msg, createClass()); // for debugging help + #line 363 + mixin(createClass()); + #line 365 + + return new Class(serverHost, serverPort); +} + +// the protocol is: +/* + +client connects + ulong interface hash + +handshake complete + +messages: + + uint messageLength + uint sequence number + ushort function number, 0 is reserved for interface check + serialized arguments.... + + + +server responds with answers: + + uint messageLength + uint re: sequence number + ubyte, 1 == success, 0 == error + serialized return value + +*/