Implement std.parallelism.fold with static if branches instead of with template specializations

This commit is contained in:
Ali Çehreli 2017-12-25 15:36:06 -08:00
parent 673c4e959d
commit 665e4adbce

View file

@ -2812,26 +2812,59 @@ public:
} }
} }
///
template fold(functions...)
{
/** Implements the homonym function (also known as `accumulate`, `compress`, /** Implements the homonym function (also known as `accumulate`, `compress`,
`inject`, or `foldl`) present in various programming languages of `inject`, or `foldl`) present in various programming languages of
functional flavor. functional flavor.
This is functionally equivalent to $(LREF reduce) except the range `fold` is functionally equivalent to $(LREF reduce) except the range
parameter comes first and there is no need to use $(REF_ALTTEXT parameter comes first and there is no need to use $(REF_ALTTEXT
`tuple`,tuple,std,typecons) for multiple seeds. `tuple`,tuple,std,typecons) for multiple seeds.
There may be one or more callable entities (`functions` argument) to
apply.
Params: Params:
functions = One or more functions
args = Just the range to _fold over; or the range and one seed
per function; or the range, one seed per function, and
the work unit size
Returns: Returns:
The accumulated result as a single value for single function and as The accumulated result as a single value for single function and
a tuple of values for multiple functions as a tuple of values for multiple functions
See_Also: See_Also:
Similar to $(REF _fold, std,algorithm,iteration), `fold` is a wrapper around $(LREF reduce). Similar to $(REF _fold, std,algorithm,iteration), `fold` is a wrapper around $(LREF reduce).
*/
template fold(functions...) Example:
---
static int adder(int a, int b)
{ {
return a + b;
}
static int multiplier(int a, int b)
{
return a * b;
}
// Just the range
auto x = taskPool.fold!adder([1, 2, 3, 4]);
assert(x == 10);
// The range and the seeds (0 and 1 below; also note multiple
// functions in this example)
auto y = taskPool.fold!(adder, multiplier)([1, 2, 3, 4], 0, 1);
assert(y[0] == 10);
assert(y[1] == 24);
// The range, the seed (0), and the work unit size (20)
auto z = taskPool.fold!adder([1, 2, 3, 4], 0, 20);
assert(z == 10);
---
*/
auto fold(Args...)(Args args) auto fold(Args...)(Args args)
{ {
import std.string : format; import std.string : format;
@ -2841,105 +2874,85 @@ public:
args.length == 1 + functions.length + 1, // range, seeds, and workUnitSize args.length == 1 + functions.length + 1, // range, seeds, and workUnitSize
format("Invalid number of arguments (%s): Should be an input range, %s optional seed(s)," ~ format("Invalid number of arguments (%s): Should be an input range, %s optional seed(s)," ~
" and an optional work unit size", Args.length, functions.length)); " and an optional work unit size", Args.length, functions.length));
return fold(args[0], args[1..$]);
auto range()
{
return args[0];
} }
/** This is the overload that uses implicit seeds. (See important notes static if (Args.length == 1)
on implicit seeds under $(LREF reduce).)
Params:
range = The range of values to _fold over
*/
auto fold(R)(R range)
{ {
// Just the range
return reduce!functions(range); return reduce!functions(range);
} }
else
///
version(StdUnittest)
unittest
{ {
static int adder(int a, int b) { return a + b; } static if (functions.length == 1)
auto result = taskPool.fold!adder([1, 2, 3, 4]); {
assert(result == 10); auto seeds()
{
return args[1];
}
}
else
{
auto seeds()
{
import std.typecons : tuple;
return tuple(args[1 .. functions.length+1]);
}
} }
/** This is the overload that uses explicit _seeds. (See important notes static if (Args.length == 1 + functions.length)
on explicit _seeds under $(LREF reduce).)
Params:
range = The range of values to _fold over
seeds = One seed per function
*/
auto fold(R, Seeds...)(R range, Seeds seeds)
if (Seeds.length == functions.length)
{
static if (Seeds.length == 1)
{ {
// The range and the seeds
return reduce!functions(seeds, range); return reduce!functions(seeds, range);
} }
else else static if (Args.length == 1 + functions.length + 1)
{
import std.typecons : tuple;
return reduce!functions(tuple(seeds), range);
}
}
///
version(StdUnittest)
unittest
{
static int adder (int a, int b) { return a + b; }
static int multiplier(int a, int b) { return a * b; }
auto result = taskPool.fold!(adder, multiplier)([1, 2, 3, 4], 0, 1);
assert(result[0] == 10);
assert(result[1] == 24);
}
/** This is the overload that uses explicit seeds and work unit size.
(See important notes on explicit seeds and work unit size under
$(LREF reduce).)
Params:
range = The range of values to _fold over
args = One seed per function (`args[0..$-1]`) and the work unit
size (`args[$-1]`)
*/
auto fold(R, Args...)(R range, Args args)
if (Args.length == functions.length + 1)
{ {
// The range, the seeds, and the work unit size
static assert(isIntegral!(Args[$-1]), "Work unit size must be an integral type"); static assert(isIntegral!(Args[$-1]), "Work unit size must be an integral type");
return reduce!functions(seeds, range, args[$-1]);
static if (Args.length == 2)
{
return reduce!functions(args[0], range, args[1]);
} }
else else
{ {
import std.typecons : tuple; static assert(0);
return reduce!functions(tuple(args[0..$-1]), range, args[$-1]); }
}
} }
} }
/// // This test is not included in the documentation because even though these
// examples are for the inner fold() template, with their current location,
// they would appear under the outer one. (We can't move this inside the
// outer fold() template because then dmd runs out of memory possibly due to
// recursive template instantiation, which is surprisingly not caught.)
version(StdUnittest) version(StdUnittest)
unittest @system unittest
{ {
static int adder (int a, int b) { return a + b; } static int adder(int a, int b)
static int multiplier(int a, int b) { return a * b; } {
return a + b;
// Single function produces single result
auto result = taskPool.fold!adder([1, 2, 3, 4], 0, 20);
assert(result == 10);
// Multiple functions produce a tuple of results
auto results = taskPool.fold!(adder, multiplier)([1, 2, 3, 4], 0, 1, 30);
assert(results[0] == 10);
assert(results[1] == 24);
} }
static int multiplier(int a, int b)
{
return a * b;
} }
// Just the range
auto x = taskPool.fold!adder([1, 2, 3, 4]);
assert(x == 10);
// The range and the seeds (0 and 1 below; also note multiple
// functions in this example)
auto y = taskPool.fold!(adder, multiplier)([1, 2, 3, 4], 0, 1);
assert(y[0] == 10);
assert(y[1] == 24);
// The range, the seed (0), and the work unit size (20)
auto z = taskPool.fold!adder([1, 2, 3, 4], 0, 20);
assert(z == 10);
}
/** /**
Gets the index of the current thread relative to this $(D TaskPool). Any Gets the index of the current thread relative to this $(D TaskPool). Any
thread not in this pool will receive an index of 0. The worker threads in thread not in this pool will receive an index of 0. The worker threads in
@ -3473,21 +3486,24 @@ public:
} }
version(StdUnittest) version(StdUnittest)
unittest @system unittest
{ {
import std.algorithm : sum; import std.algorithm.iteration : sum;
import std.range : iota; import std.range : iota;
import std.typecons : tuple;
static int adder(int a, int b)
{
return a + b;
}
enum N = 100; enum N = 100;
auto r = iota(1, N + 1); auto r = iota(1, N + 1);
const expected = r.sum(); const expected = r.sum();
static auto adder(int a, int b) {
return a + b;
}
assert(taskPool.fold!adder(r) == expected); assert(taskPool.fold!adder(r) == expected);
assert(taskPool.fold!adder(r, 0) == expected); assert(taskPool.fold!adder(r, 0) == expected);
assert(taskPool.fold!adder(r, 0, 42) == expected); assert(taskPool.fold!(adder, adder)(r, 0, 0, 42) == tuple(expected, expected));
} }
/** /**