mirror of https://github.com/adamdruppe/arsd.git
215 lines
5.2 KiB
D
215 lines
5.2 KiB
D
/++
|
|
mvd stands for Multiple Virtual Dispatch. It lets you
|
|
write functions that take any number of arguments of
|
|
objects and match based on the dynamic type of each
|
|
of them.
|
|
|
|
---
|
|
void foo(Object a, Object b) {} // 1
|
|
void foo(MyClass a, Object b) {} // 2
|
|
void foo(DerivedClass a, MyClass b) {} // 3
|
|
|
|
Object a = new MyClass();
|
|
Object b = new Object();
|
|
|
|
mvd!foo(a, b); // will call overload #2
|
|
---
|
|
|
|
The return values must be compatible; [mvd] will return
|
|
the least specialized static type of the return values
|
|
(most likely the shared base class type of all return types,
|
|
or `void` if there isn't one).
|
|
|
|
All non-class/interface types should be compatible among overloads.
|
|
Otherwise you are liable to get compile errors. (Or it might work,
|
|
that's up to the compiler's discretion.)
|
|
+/
|
|
module arsd.mvd;
|
|
|
|
import std.traits;
|
|
|
|
/// This exists just to make the documentation of [mvd] nicer looking.
|
|
template CommonReturnOfOverloads(alias fn) {
|
|
alias overloads = __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn));
|
|
static if (overloads.length == 1) {
|
|
alias CommonReturnOfOverloads = ReturnType!(overloads[0]);
|
|
}
|
|
else {
|
|
alias CommonReturnOfOverloads = CommonType!(staticMap!(ReturnType, overloads));
|
|
}
|
|
}
|
|
|
|
/// See details on the [arsd.mvd] page.
|
|
CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) {
|
|
return mvdObj!fn(null, args);
|
|
}
|
|
|
|
/++
	Picks one overload of `fn` based on the runtime types of `args` and
	calls it through `this_`.

	Every overload is scored: each class or interface parameter that the
	corresponding argument can be dynamically cast to (or whose argument is
	`null`) adds `BaseClassesTuple!ParamType.length + 1` to the score, so more
	derived parameter types score higher. An overload is skipped as soon as
	one non-null argument fails the cast. Parameters that are neither classes
	nor interfaces are copied as-is and do not affect the score.

	The overload with the highest score is called. Ties are only an error if
	they are still unresolved once every overload has been scored; a tie
	between two overloads is fine when a later overload beats both.

	Params:
		this_ = receiver handed to `__traits(child, ...)` when invoking the
		        chosen overload; [mvd] passes `null` here.
		args = arguments to dispatch on and forward to the chosen overload.

	Returns:
		The chosen overload's result, typed as `CommonReturnOfOverloads!fn`.

	Throws:
		`Exception` if no overload accepts the arguments, or if the highest
		score is shared by more than one overload.
+/
CommonReturnOfOverloads!fn mvdObj(alias fn, This, T...)(This this_, T args) {
	typeof(return) delegate() bestMatch;
	int bestScore;
	// true while some other overload has the same score as bestMatch
	bool ambiguous;

	// Describes the arguments for error messages: dynamic type name for
	// non-null class instances, static type otherwise.
	string argsStr() {
		string s;
		foreach(arg; args) {
			if(s.length)
				s ~= ", ";
			static if (is(typeof(arg) == class)) {
				if (arg is null) {
					s ~= "null " ~ typeof(arg).stringof;
				} else {
					s ~= typeid(arg).name;
				}
			} else {
				s ~= typeof(arg).stringof;
			}
		}
		return s;
	}

	ov: foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))) {
		Parameters!overload pargs;
		int score = 0;
		foreach(idx, parg; pargs) {
			alias t = typeof(parg);
			static if(is(t == interface) || is(t == class)) {
				t value = cast(t) args[idx];
				// HACK: cast to Object* so we can set the value even if it's an immutable class
				// NOTE(review): for interface-typed parameters this writes an Object
				// reference into an interface slot - confirm that is valid.
				*cast(Object*) &pargs[idx] = cast(Object) value;
				if(args[idx] !is null && pargs[idx] is null)
					continue ov; // failed cast, forget it
				else
					score += BaseClassesTuple!t.length + 1;
			} else
				pargs[idx] = args[idx];
		}

		if(bestMatch !is null && score == bestScore) {
			// Only remember the tie; a later overload may still outrank both.
			ambiguous = true;
		} else if(bestMatch is null || score > bestScore) {
			// First viable overload (even with score 0) or a strictly better one.
			bestMatch = () {
				static if(is(typeof(return) == void))
					__traits(child, this_, overload)(pargs);
				else
					return __traits(child, this_, overload)(pargs);
			};
			bestScore = score;
			ambiguous = false;
		}
	}

	if(bestMatch is null)
		throw new Exception("no match existed with args (" ~ argsStr ~ ")");

	if(ambiguous)
		throw new Exception("ambiguous overload selection with args (" ~ argsStr ~ ")");

	return bestMatch();
}
|
|
|
|
///
|
|
unittest {
	// Small hierarchy: DerivedClass -> MyClass -> Object, plus an unrelated OtherClass.
	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	// A struct with static members serves as a namespace, because functions
	// cannot be overloaded directly inside a unittest block.
	static struct Wrapper {
		static:
		int foo(Object a, Object b) { return 1; }
		int foo(MyClass a, Object b) { return 2; }
		int foo(DerivedClass a, MyClass b) { return 3; }

		int bar(MyClass a) { return 4; }
	}

	// Only the most general overload accepts two plain Objects.
	assert(mvd!(Wrapper.foo)(new Object, new Object) == 1);
	// First argument is exactly MyClass, so overload 3 is out and 2 wins.
	assert(mvd!(Wrapper.foo)(new MyClass, new DerivedClass) == 2);
	// Both arguments satisfy the most specific overload.
	assert(mvd!(Wrapper.foo)(new DerivedClass, new DerivedClass) == 3);
	// OtherClass is unrelated to MyClass, leaving only overload 1.
	assert(mvd!(Wrapper.foo)(new OtherClass, new OtherClass) == 1);
	assert(mvd!(Wrapper.foo)(new OtherClass, new MyClass) == 1);
	assert(mvd!(Wrapper.foo)(new DerivedClass, new DerivedClass) == 3);
	assert(mvd!(Wrapper.foo)(new OtherClass, new MyClass) == 1);

	//mvd!(Wrapper.bar)(new OtherClass);
}
|
|
|
|
///
|
|
unittest {
	// Same hierarchy as the static example.
	class MyClass {}
	class DerivedClass : MyClass {}
	class OtherClass {}

	// The overloads are instance methods here; each result includes the
	// field x, which shows the call went through the given object.
	class Wrapper {
		int x;

		int foo(Object a, Object b) { return x + 1; }
		int foo(MyClass a, Object b) { return x + 2; }
		int foo(DerivedClass a, MyClass b) { return x + 3; }

		int bar(MyClass a) { return x + 4; }
	}

	auto target = new Wrapper;
	target.x = 20;

	assert(mvdObj!(target.foo)(target, new Object, new Object) == 21);
	assert(mvdObj!(target.foo)(target, new MyClass, new DerivedClass) == 22);
	assert(mvdObj!(target.foo)(target, new DerivedClass, new DerivedClass) == 23);
	assert(mvdObj!(target.foo)(target, new OtherClass, new OtherClass) == 21);
	assert(mvdObj!(target.foo)(target, new OtherClass, new MyClass) == 21);
	assert(mvdObj!(target.foo)(target, new DerivedClass, new DerivedClass) == 23);
	assert(mvdObj!(target.foo)(target, new OtherClass, new MyClass) == 21);

	//mvd!bar(new OtherClass);
}
|
|
|
|
///
|
|
unittest {
	// Dispatching to an overload that returns void must compile and run.
	class MyClass {}

	static bool wasCalled = false;

	static struct Wrapper {
		static:
		void foo(MyClass a) { wasCalled = true; }
	}

	mvd!(Wrapper.foo)(new MyClass);
	assert(wasCalled);
}
|
|
|
|
///
|
|
unittest {
	// Dispatch also has to work when the classes involved are immutable.
	immutable class Foo {}

	immutable class Bar : Foo {
		int x;

		this(int initialX) {
			this.x = initialX;
		}
	}

	immutable class Baz : Foo {
		int x, y;

		this(int initialX, int initialY) {
			this.x = initialX;
			this.y = initialY;
		}
	}

	static struct Wrapper {
		static:

		int foo(Bar b) { return b.x; }
		int foo(Baz b) { return b.x + b.y; }
	}

	// Both values are statically typed as the base class Foo; the overload
	// is chosen from their dynamic types.
	Foo barAsFoo = new Bar(3);
	Foo bazAsFoo = new Baz(5, 7);

	assert(mvd!(Wrapper.foo)(barAsFoo) == 3);
	assert(mvd!(Wrapper.foo)(bazAsFoo) == 12);
}
|