/++ mvd stands for Multiple Virtual Dispatch. It lets you write functions that take any number of arguments of objects and match based on the dynamic type of each of them. --- void foo(Object a, Object b) {} // 1 void foo(MyClass b, Object b) {} // 2 void foo(DerivedClass a, MyClass b) {} // 3 Object a = new MyClass(); Object b = new Object(); mvd!foo(a, b); // will call overload #2 --- The return values must be compatible; [mvd] will return the least specialized static type of the return values (most likely the shared base class type of all return types, or `void` if there isn't one). All non-class/interface types should be compatible among overloads. Otherwise you are liable to get compile errors. (Or it might work, that's up to the compiler's discretion.) +/ module arsd.mvd; import std.traits; /// This exists just to make the documentation of [mvd] nicer looking. template CommonReturnOfOverloads(alias fn) { alias overloads = __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn)); static if (overloads.length == 1) { alias CommonReturnOfOverloads = ReturnType!(overloads[0]); } else { alias CommonReturnOfOverloads = CommonType!(staticMap!(ReturnType, overloads)); } } /// See details on the [arsd.mvd] page. CommonReturnOfOverloads!fn mvd(alias fn, T...)(T args) { return mvdObj!fn(null, args); } CommonReturnOfOverloads!fn mvdObj(alias fn, This, T...)(This this_, T args) { typeof(return) delegate() bestMatch; int bestScore; string argsStr() { string s; foreach(arg; args) { if(s.length) s ~= ", "; static if (is(typeof(arg) == class)) { if (arg is null) { s ~= "null " ~ typeof(arg).stringof; } else { s ~= typeid(arg).name; } } else { s ~= typeof(arg).stringof; } } return s; } ov: foreach(overload; __traits(getOverloads, __traits(parent, fn), __traits(identifier, fn))) { Parameters!overload pargs; int score = 0; foreach(idx, parg; pargs) { alias t = typeof(parg); static if(is(t == interface) || is(t == class)) { t value = cast(t) args[idx]; // HACK: cast to Object* so we can set the value even if it's an immutable class *cast(Object*) &pargs[idx] = cast(Object) value; if(args[idx] !is null && pargs[idx] is null) continue ov; // failed cast, forget it else score += BaseClassesTuple!t.length + 1; } else pargs[idx] = args[idx]; } if(score == bestScore) throw new Exception("ambiguous overload selection with args (" ~ argsStr ~ ")"); if(score > bestScore) { bestMatch = () { static if(is(typeof(return) == void)) __traits(child, this_, overload)(pargs); else return __traits(child, this_, overload)(pargs); }; bestScore = score; } } if(bestMatch is null) throw new Exception("no match existed with args (" ~ argsStr ~ ")"); return bestMatch(); } /// unittest { class MyClass {} class DerivedClass : MyClass {} class OtherClass {} static struct Wrapper { static: // this is just a namespace cuz D doesn't allow overloading inside unittest int foo(Object a, Object b) { return 1; } int foo(MyClass a, Object b) { return 2; } int foo(DerivedClass a, MyClass b) { return 3; } int bar(MyClass a) { return 4; } } with(Wrapper) { assert(mvd!foo(new Object, new Object) == 1); assert(mvd!foo(new MyClass, new DerivedClass) == 2); assert(mvd!foo(new DerivedClass, new DerivedClass) == 3); assert(mvd!foo(new OtherClass, new OtherClass) == 1); assert(mvd!foo(new OtherClass, new MyClass) == 1); assert(mvd!foo(new DerivedClass, new DerivedClass) == 3); assert(mvd!foo(new OtherClass, new MyClass) == 1); //mvd!bar(new OtherClass); } } /// unittest { class MyClass {} class DerivedClass : MyClass {} class OtherClass {} class Wrapper { int x; int foo(Object a, Object b) { return x + 1; } int foo(MyClass a, Object b) { return x + 2; } int foo(DerivedClass a, MyClass b) { return x + 3; } int bar(MyClass a) { return x + 4; } } Wrapper wrapper = new Wrapper; wrapper.x = 20; assert(wrapper.mvdObj!(wrapper.foo)(new Object, new Object) == 21); assert(wrapper.mvdObj!(wrapper.foo)(new MyClass, new DerivedClass) == 22); assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23); assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new OtherClass) == 21); assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21); assert(wrapper.mvdObj!(wrapper.foo)(new DerivedClass, new DerivedClass) == 23); assert(wrapper.mvdObj!(wrapper.foo)(new OtherClass, new MyClass) == 21); //mvd!bar(new OtherClass); } /// unittest { class MyClass {} static bool success = false; static struct Wrapper { static: void foo(MyClass a) { success = true; } } with(Wrapper) { mvd!foo(new MyClass); assert(success); } } /// unittest { immutable class Foo {} immutable class Bar : Foo { int x; this(int x) { this.x = x; } } immutable class Baz : Foo { int x, y; this(int x, int y) { this.x = x; this.y = y; } } static struct Wrapper { static: int foo(Bar b) { return b.x; } int foo(Baz b) { return b.x + b.y; } } with(Wrapper) { Foo x = new Bar(3); Foo y = new Baz(5, 7); assert(mvd!foo(x) == 3); assert(mvd!foo(y) == 12); } }