// arsd/rpc.d
//
// 557 lines
// 14 KiB
// D

/// I never finished this. The idea is to use compile-time reflection to make calling into another process feel as simple as calling methods on in-process objects. It will come eventually, but no promises.
module arsd.rpc;
/*
FIXME:
1) integrate with arsd.eventloop
2) make it easy to use with other processes: pipe to a process and talk to it that way, perhaps with shared memory too?
3) extend the serialization capabilities
4) support @Throws!(List, Of, Exceptions)
5) proxy classes over RPC too
6) redirect stdin/out/err as well
*/
///+ //example usage
/// Example RPC interface. ExampleServer implements it and makeNetworkClient
/// generates a client proxy from it.
/// NOTE: each member's 1-based position in __traits(allMembers, ...) is its
/// function number on the wire, so reordering or inserting members changes
/// the protocol between already-built clients and servers.
interface ExampleNetworkFunctions {
string sayHello(string name);
int add(int a, int b);
S2 structTest(S1);
// ExampleServer's implementation always throws; the example uses it to show remote error reporting
void die();
}
// the server must implement the interface
/// Example server: implements ExampleNetworkFunctions and mixes in NetworkServer,
/// which supplies the `this(ushort port)` constructor and `eventLoop()`.
class ExampleServer : ExampleNetworkFunctions {
	/// Greets `name`.
	override string sayHello(string name) {
		immutable prefix = "Hello, ";
		return prefix ~ name;
	}

	/// Sum of the two arguments.
	override int add(int a, int b) {
		int sum = a;
		sum += b;
		return sum;
	}

	/// Copies the fields of `a` into an S2 (which declares them in the opposite order).
	override S2 structTest(S1 a) {
		S2 converted;
		converted.name = a.name;
		converted.number = a.number;
		return converted;
	}

	/// Always fails, so clients can see how a remote exception is reported.
	override void die() {
		throw new Exception("death requested");
	}

	mixin NetworkServer!ExampleNetworkFunctions;
}
/// Argument type of `structTest` in the example.
/// Field order matters: it is the serialization order (via .tupleof) and the
/// positional constructor order used by `S1(20, "cool!")` in main.
struct S1 {
int number;
string name;
}
/// Return type of `structTest` in the example: the same data as S1 with the
/// fields declared in the opposite order. Field order is the serialization order.
struct S2 {
string name;
int number;
}
import std.stdio;
/++
	Demo entry point. With no extra command-line arguments it runs
	ExampleServer on port 5005. With any argument it connects to
	localhost:5005 as an async client, queues a few calls and then processes
	the replies in the client's eventLoop.
+/
void main(string[] args) {
	enum ushort demoPort = 5005;

	if(args.length <= 1) {
		// server mode
		auto server = new ExampleServer(demoPort);
		server.eventLoop();
		return;
	}

	// client mode: the generated methods mirror the interface, but instead of
	// returning the value they take an onSuccess callback (whose argument is the
	// return value) and an onError callback (whose argument is the exception)
	auto client = makeNetworkClient!ExampleNetworkFunctions("localhost", demoPort);
	client.sayHello("whoa", (a) { writeln(a); }, null);
	client.add(1, 2, (a) { writeln(a); }, null);
	client.add(10, 20, (a) { writeln(a); }, null);
	client.structTest(S1(20, "cool!"), (a) { writeln(a.name, " -- ", a.number); }, null);
	client.die(delegate () { writeln("shouldn't happen"); }, delegate(a) { writeln(a); });
	client.eventLoop();
	/*
	synchronous variant:
	auto client = makeNetworkClient!(ExampleNetworkFunctions, false)("localhost", 5005);
	writeln(client.sayHello("whoa"));
	writeln(client.add(1, 2));
	client.die();
	writeln(client.add(1, 2));
	*/
}
//+/
/++
	Mixin that turns the enclosing class into an RPC server for `Interface`.

	The class must implement `Interface` itself. This mixin adds a constructor
	taking the TCP port to listen on, `eventLoop()` which accepts clients and
	serves their requests forever, and the private dispatch code.

	Request frame (ints in native byte order): argument-byte count, function
	number (1-based position in `__traits(allMembers, Interface)`), sequence
	number, then the serialized arguments.
	Reply frame: payload-byte count, the request's sequence number, success
	flag (1 = returned normally, 0 = threw), then either the serialized return
	value (nothing for void functions) or the exception's msg, file and line.
+/
mixin template NetworkServer(Interface) {
	import std.socket;
	private Socket socket;

	/// Creates the listening TCP socket bound to `port`, with SO_REUSEADDR set and a backlog of 16.
	public this(ushort port) {
		socket = new TcpSocket();
		socket.setOption(SocketOptionLevel.SOCKET, SocketOption.REUSEADDR, true);
		socket.bind(new InternetAddress(port));
		socket.listen(16);
	}

	/++
		Serves clients forever. It waits in `Socket.select` for new connections
		or incoming data and dispatches every complete request to the matching
		member function. Connections that are closed by the peer, fail, or send
		corrupt framing are closed and dropped from the watch list. Previously
		they stayed in the list and made select spin.
	+/
	final public void eventLoop() {
		auto check = new SocketSet();
		Socket[] connections;
		connections.reserve(16);
		// pending[i] holds bytes from connections[i] that do not form a whole request yet
		ubyte[][] pending;
		ubyte[4096] buffer;
		while(true) {
			check.reset();
			check.add(socket);
			foreach(connection; connections)
				check.add(connection);

			// no timeout is given, so a non-positive result means the wait was
			// interrupted; skip this round rather than reading from every socket
			if(Socket.select(check, null, null) <= 0)
				continue;

			if(check.isSet(socket)) {
				connections ~= socket.accept();
				pending.length = connections.length;
			}

			// serve readable connections, compacting both lists in place so that
			// finished connections are forgotten
			size_t kept = 0;
			foreach(i, connection; connections) {
				if(!check.isSet(connection) || serveChunk(connection, buffer[], pending[i])) {
					connections[kept] = connection;
					pending[kept] = pending[i];
					kept++;
				} else {
					connection.close();
				}
			}
			connections.length = kept;
			pending.length = kept;
		}
	}

	/++
		Reads once from `connection`, appends the bytes to `pending` and
		dispatches every complete request now held there. A partial request
		stays in `pending` until more data arrives.

		Returns: false when the connection should be dropped: the peer closed
		it, `receive` failed, or a header announces a negative length or one
		larger than a single receive buffer (the client in this module builds
		requests in a 4096-byte buffer).
	+/
	final private bool serveChunk(Socket connection, ubyte[] buffer, ref ubyte[] pending) {
		auto gotNum = connection.receive(buffer);
		if(gotNum <= 0)
			return false; // 0: peer closed the connection, negative: receive error

		pending ~= buffer[0 .. gotNum];

		// request header: int length, int function number, int sequence number
		enum size_t headerSize = 3 * int.sizeof;
		while(pending.length >= headerSize) {
			int length, functionNumber, sequenceNumber;
			auto rest = deserializeInto(pending, length);
			rest = deserializeInto(rest, functionNumber);
			rest = deserializeInto(rest, sequenceNumber);
			if(length < 0 || cast(size_t) length > buffer.length - headerSize)
				return false; // corrupt framing; the rest of the stream cannot be trusted
			if(rest.length < cast(size_t) length)
				break; // the rest of this request has not arrived yet
			//writeln("got ", sequenceNumber, " calling ", functionNumber);
			callByNumber(functionNumber, sequenceNumber, rest[0 .. length], connection);
			pending = rest[length .. $];
		}
		return true;
	}

	/++
		Runs the interface member numbered `functionNumber` (1-based, in
		`__traits(allMembers, Interface)` order). The arguments are deserialized
		from `buffer`. The reply frame for `sequenceNumber` is sent back on
		`connection`.

		Every request gets exactly one reply. Void functions send an empty
		success reply, so synchronous clients do not block forever and async
		onSuccess callbacks fire. Anything thrown while decoding the arguments
		or running the call is reported as a failure reply carrying msg, file
		and line. That includes an unknown function number.
	+/
	final private void callByNumber(int functionNumber, int sequenceNumber, const(ubyte)[] buffer, Socket connection) {
		ubyte[4096] sendBuffer;
		int length = 12;
		// reply header: [0 .. 4] payload length, [4 .. 8] sequence number, [8 .. 12] success flag
		serialize(sendBuffer[4 .. 8], sequenceNumber);
		string callCode() {
			import std.conv;
			import std.traits;
			string code;
			foreach(memIdx, member; __traits(allMembers, Interface)) {
				code ~= "\t\tcase " ~ to!string(memIdx + 1) ~ ":\n";
				alias mem = PassThrough!(__traits(getMember, Interface, member));
				// we need to deserialize the arguments, call the function, and send back the response
				string argsString;
				foreach(i, arg; ParameterTypeTuple!mem) {
					if(i)
						argsString ~= ", ";
					auto istr = to!string(i);
					code ~= "\t\t\t" ~ arg.stringof ~ " arg" ~ istr ~ ";\n";
					code ~= "\t\t\tbuffer = deserializeInto(buffer, arg" ~ istr ~ ");\n";
					argsString ~= "arg" ~ istr;
				}
				// the call
				static if(is(ReturnType!mem == void)) {
					code ~= "\n\t\t\t" ~ member ~ "(" ~ argsString ~ ");\n";
					// acknowledge with an empty success reply
					code ~= "\t\t\tserialize(sendBuffer[8 .. 12], cast(int) 1);\n"; // yes success
					code ~= "\t\t\tserialize(sendBuffer[0 .. 4], cast(int) 0);\n"; // no payload
				} else {
					// call and return answer
					code ~= "\n\t\t\tauto ret = " ~ member ~ "(" ~ argsString ~ ");\n";
					code ~= "\t\t\tserialize(sendBuffer[8 .. 12], cast(int) 1);\n"; // yes success
					code ~= "\t\t\tauto serialized = serialize(sendBuffer[12 .. $], ret);\n";
					code ~= "\t\t\tserialize(sendBuffer[0 .. 4], cast(int) serialized.length);\n";
					code ~= "\t\t\tlength += serialized.length;\n";
				}
				code ~= "\t\tbreak;\n";
			}
			return code;
		}
		try {
			switch(functionNumber) {
				// an Exception rather than assert(0): a bad number arriving from the
				// network is reported back instead of halting the server in -release builds
				default: throw new Exception("unknown function");
				//pragma(msg, callCode());
				mixin(callCode());
			}
		} catch(Throwable t) {
			//writeln("thrown: ", t);
			serialize(sendBuffer[8 .. 12], cast(int) 0); // no success
			auto place = sendBuffer[12 .. $];
			int l;
			auto s = serialize(place, t.msg);
			place = place[s.length .. $];
			l += s.length;
			s = serialize(place, t.file);
			place = place[s.length .. $];
			l += s.length;
			s = serialize(place, t.line);
			place = place[s.length .. $];
			l += s.length;
			serialize(sendBuffer[0 .. 4], l);
			length += l;
		}
		// FIXME: what if it doesn't all send at once?
		connection.send(sendBuffer[0 .. length]);
	}
}
/// Identity alias template: lets an arbitrary symbol (such as the result of
/// `__traits(getMember, ...)`) be bound with `alias x = PassThrough!(...)`.
alias PassThrough(alias a) = a;
/++
	Writes `s` into the front of `buffer` and returns the slice of `buffer`
	that was filled.

	Encoding:
	- Arrays (dynamic or static) are an `int` element count followed by each
	  element.
	- Types without indirections (ints, floats, chars, plain structs, ...) are
	  copied byte-for-byte in native order.
	- Other aggregates are their fields in declaration order.
	- Pointers are rejected at compile time.

	General FIXME: what if we run out of buffer space? Only the fixed-size
	branch checks, and it uses an assert.
+/
final public ubyte[] serialize(T)(ubyte[] buffer, in T s) {
	import std.traits;
	size_t pos = 0; // bytes of `buffer` written so far
	static if(isArray!T) {
		pos += serialize(buffer[pos .. $], cast(int) s.length).length;
		foreach(element; s)
			pos += serialize(buffer[pos .. $], element).length;
	} else static if(isPointer!T) {
		static assert(0, "no pointers allowed");
	} else static if(hasIndirections!T) {
		// structs, classes, etc.: one field after another
		foreach(field; s.tupleof)
			pos += serialize(buffer[pos .. $], field).length;
	} else {
		// covers int, float, char, etc. most the builtins
		import std.string;
		assert(buffer.length >= T.sizeof, format("%s won't fit in %s buffer", T.stringof, buffer.length));
		buffer[0 .. T.sizeof] = (cast(const(ubyte)*) &s)[0 .. T.sizeof];
		pos = T.sizeof;
	}
	return buffer[0 .. pos];
}
/++
	Reads a value of type T from the front of `buffer` into `s`. This is the
	inverse of `serialize`:
	- Types without indirections are copied byte-for-byte (native order).
	- Arrays are an `int` element count followed by the elements.
	- Other aggregates are their fields in declaration order.

	Returns: the remaining part of the buffer, after the consumed bytes.
+/
final public inout(ubyte)[] deserializeInto(T)(inout(ubyte)[] buffer, ref T s) {
	import std.traits;
	static if(isArray!T) {
		// serialize writes the element count as an int. Reading a size_t here
		// consumed 8 bytes on 64-bit targets and misparsed every array and string.
		int length;
		buffer = deserializeInto(buffer, length);
		assert(length >= 0, "negative array length in serialized data");
		static if(isDynamicArray!T)
			s.length = length;
		else
			assert(cast(size_t) length == s.length, "static array length mismatch in serialized data");
		foreach(i; 0 .. s.length)
			buffer = deserializeInto(buffer, s[i]);
	} else static if(isPointer!T) {
		static assert(0, "no pointers allowed");
	} else static if(!hasIndirections!T) {
		// covers int, float, char, etc. most the builtins
		(cast(ubyte*)(&s))[0 .. T.sizeof] = buffer[0 .. T.sizeof];
		buffer = buffer[T.sizeof .. $];
	} else {
		// structs, classes, etc.
		foreach(i, t; s.tupleof) {
			buffer = deserializeInto(buffer, s.tupleof[i]);
		}
	}
	return buffer;
}
/++
	Mixin that generates an RPC client proxy for `Interface` inside the
	enclosing class. `makeNetworkClient` is the usual way to get one.

	The generated members are:
	- the socket;
	- a private constructor `this(string host, ushort port)` that connects
	  over TCP;
	- one `final` method per member of `Interface`. Each one serializes its
	  arguments and sends a request frame: int argument-byte count, int
	  function number (1-based position in `__traits(allMembers, Interface)`),
	  int sequence number, then the arguments.

	When `useAsync` is true (the default), each generated method returns
	`void` and takes two extra trailing parameters:
	- `onSuccess`, called with the return value if there is one;
	- `onError`, called with a Throwable rebuilt from the server's msg, file
	  and line.
	Replies are only delivered while `eventLoop()` runs.

	When `useAsync` is false, each method blocks until its reply arrives, then
	returns the value or throws the rebuilt Throwable.
+/
mixin template NetworkClient(Interface, bool useAsync = true) {
	private static string createClass() {
		// this doesn't actually inherit from the interface because
		// the return value needs to be handled async
		string code;// = `final class Class /*: ` ~ Interface.stringof ~ `*/ {`;
		code ~= "\n\timport std.socket;";
		code ~= "\n\tprivate Socket socket;";
		if(useAsync) {
			code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onSuccesses;";
			code ~= "\n\tprivate void delegate(const(ubyte)[] buffer)[uint] onErrors;";
		}
		code ~= "\n\tprivate uint lastSequenceNumber;";
		code ~= q{
			private this(string host, ushort port) {
				this.socket = new TcpSocket();
				this.socket.connect(new InternetAddress(host, port));
			}
		};
		if(useAsync)
			code ~= q{
				final public void eventLoop() {
					ubyte[4096] buffer;
					// bytes received but not yet forming a complete reply
					ubyte[] pending;
					while(true) {
						auto gotNum = socket.receive(buffer);
						// 0 means the server closed the connection, negative means a socket error
						if(gotNum <= 0)
							break;
						pending ~= buffer[0 .. gotNum];
						// dispatch every complete reply: 12-byte header (length, seq, success), then payload
						while(pending.length >= 12) {
							uint length, seq;
							uint success;
							auto rest = deserializeInto(pending, length);
							rest = deserializeInto(rest, seq);
							rest = deserializeInto(rest, success);
							if(rest.length < length)
								break; // the rest of this reply has not arrived yet
							auto payload = rest[0 .. length];
							pending = rest[length .. $];
							auto handler = success ? (seq in onSuccesses) : (seq in onErrors);
							if(handler !is null && *handler !is null)
								(*handler)(payload);
						}
					}
				}
			};
		code ~= "\n\tpublic:\n";
		foreach(memIdx, member; __traits(allMembers, Interface)) {
			import std.traits;
			alias mem = PassThrough!(__traits(getMember, Interface, member));
			string type;
			if(useAsync)
				type = "void";
			else {
				static if(is(ReturnType!mem == void))
					type = "void";
				else
					type = (ReturnType!mem).stringof;
			}
			code ~= "\t\tfinal "~type~" " ~ member ~ "(";
			bool hadArgument = false;
			import std.conv;
			// arguments
			foreach(i, arg; ParameterTypeTuple!mem) {
				if(hadArgument)
					code ~= ", ";
				// FIXME: this is one place the arg can get unknown if we don't have all the imports
				code ~= arg.stringof ~ " arg" ~ to!string(i);
				hadArgument = true;
			}
			if(useAsync) {
				if(hadArgument)
					code ~= ", ";
				static if(is(ReturnType!mem == void))
					code ~= "void delegate() onSuccess";
				else
					code ~= "void delegate("~(ReturnType!mem).stringof~") onSuccess";
				code ~= ", ";
				code ~= "void delegate(Throwable) onError";
			}
			code ~= ") {\n";
			code ~= "auto seq = ++lastSequenceNumber;";
			if(useAsync)
				code ~= q{
#line 252
					onSuccesses[seq] = (const(ubyte)[] buffer) {
						onSuccesses.remove(seq);
						onErrors.remove(seq);
						import std.traits;
						static if(is(ParameterTypeTuple!(typeof(onSuccess)) == void)) {
							if(onSuccess !is null)
								onSuccess();
						} else {
							ParameterTypeTuple!(typeof(onSuccess)) args;
							foreach(i, arg; args)
								buffer = deserializeInto(buffer, args[i]);
							if(onSuccess !is null)
								onSuccess(args);
						}
					};
					onErrors[seq] = (const(ubyte)[] buffer) {
						onSuccesses.remove(seq);
						onErrors.remove(seq);
						auto t = new Throwable("");
						buffer = deserializeInto(buffer, t.msg);
						buffer = deserializeInto(buffer, t.file);
						buffer = deserializeInto(buffer, t.line);
						if(onError !is null)
							onError(t);
					};
				};
			code ~= q{
#line 283
				ubyte[4096] bufferBase;
				auto buffer = bufferBase[12 .. $]; // leaving room for size, func number, and seq number
				ubyte[] serialized;
				int used;
			};
			// preparing the request
			foreach(i, arg; ParameterTypeTuple!mem) {
				code ~= "\t\t\tserialized = serialize(buffer, arg" ~ to!string(i) ~ ");\n";
				code ~= "\t\t\tused += serialized.length;\n";
				code ~= "\t\t\tbuffer = buffer[serialized.length .. $];\n";
			}
			code ~= "\t\t\tserialize(bufferBase[0 .. 4], used);\n";
			code ~= "\t\t\tserialize(bufferBase[4 .. 8], " ~ to!string(memIdx + 1) ~ ");\n";
			code ~= "\t\t\tserialize(bufferBase[8 .. 12], seq);\n";
			// FIXME: what if it doesn't all send at once?
			code ~= "\t\t\tsocket.send(bufferBase[0 .. 12 + used]);\n";
			//code ~= `writeln("sending ", bufferBase[0 .. 12 + used]);`;
			if(!useAsync)
				code ~= q{
					ubyte[4096] dbuffer;
					static if(is(typeof(return) == void)) {
					} else
						typeof(return) returned;
					// collect the 12-byte reply header (length, seq, success), then the whole payload it announces
					size_t gotNum = 0;
					size_t wanted = 12;
					bool haveHeader = false;
					while(gotNum < wanted) {
						auto g2 = socket.receive(dbuffer[gotNum .. $]);
						if(g2 <= 0)
							throw new Exception("connection closed");
						gotNum += g2;
						if(!haveHeader && gotNum >= 12) {
							uint announced;
							deserializeInto(dbuffer[0 .. 4], announced);
							if(announced > dbuffer.length - 12)
								throw new Exception("reply too large");
							wanted = 12 + announced;
							haveHeader = true;
						}
					}
					const(ubyte)[] got = dbuffer[0 .. wanted];
					uint length, replySeq;
					uint success;
					got = deserializeInto(got, length);
					got = deserializeInto(got, replySeq);
					got = deserializeInto(got, success);
					if(success) {
						static if(!is(typeof(return) == void))
							got = deserializeInto(got, returned);
					} else {
						auto t = new Throwable("");
						got = deserializeInto(got, t.msg);
						got = deserializeInto(got, t.file);
						got = deserializeInto(got, t.line);
						throw t;
					}
					static if(is(typeof(return) == void)) {
					} else
						return returned;
				};
			code ~= "}\n";
			code ~= "\n";
		}
		//code ~= `}`;
		return code;
	}
	//pragma(msg, createClass()); // for debugging help
	mixin(createClass());
}
/++
	Connects to an RPC server at `host`:`port` and returns a client proxy for
	`Interface`. The proxy's methods are generated by the NetworkClient mixin.
	With `useAsync` (the default) the methods take onSuccess/onError callbacks
	that are run by the proxy's `eventLoop()`. Otherwise each call blocks and
	returns its result.
+/
auto makeNetworkClient(Interface, bool useAsync = true)(string host, ushort port) {
	// a class local to this template, so each (Interface, useAsync) instantiation gets its own proxy type
	class Thing {
		mixin NetworkClient!(Interface, useAsync);
	}
	auto client = new Thing(host, port);
	return client;
}
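// the protocol is:
/*
	Intended design (not implemented yet):
		client connects
		ulong interface hash
		handshake complete

	What the code currently sends. There is no handshake. Every integer is
	4 bytes in native byte order. Strings and arrays are an int count followed
	by the elements.

	requests:
		int messageLength (bytes of serialized arguments after the header)
		int function number: 1-based position of the member in
		    __traits(allMembers, Interface); 0 is reserved for the interface check
		int sequence number
		serialized arguments....

	server responds with answers:
		int messageLength (bytes of payload after the header)
		int re: sequence number
		int success flag: 1 == success, 0 == error
		on success: the serialized return value
		on error: the exception's msg and file (strings) and line (size_t)

	(The original design put the sequence number before a ushort function
	number and used a ubyte success flag.)
*/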