Implement std.parallelism.fold with static if branches instead of with template specializations

This commit is contained in:
Ali Çehreli 2017-12-25 15:36:06 -08:00
parent 673c4e959d
commit 665e4adbce

View file

@ -2812,26 +2812,59 @@ public:
} }
} }
/** Implements the homonym function (also known as `accumulate`, `compress`, ///
`inject`, or `foldl`) present in various programming languages of
functional flavor.
This is functionally equivalent to $(LREF reduce) except the range
parameter comes first and there is no need to use $(REF_ALTTEXT
`tuple`,tuple,std,typecons) for multiple seeds.
Params:
functions = One or more functions
Returns:
The accumulated result as a single value for single function and as
a tuple of values for multiple functions
See_Also:
Similar to $(REF _fold, std,algorithm,iteration), `fold` is a wrapper around $(LREF reduce).
*/
template fold(functions...) template fold(functions...)
{ {
/** Implements the homonym function (also known as `accumulate`, `compress`,
`inject`, or `foldl`) present in various programming languages of
functional flavor.
`fold` is functionally equivalent to $(LREF reduce) except the range
parameter comes first and there is no need to use $(REF_ALTTEXT
`tuple`,tuple,std,typecons) for multiple seeds.
There may be one or more callable entities (`functions` argument) to
apply.
Params:
args = Just the range to _fold over; or the range and one seed
per function; or the range, one seed per function, and
the work unit size
Returns:
The accumulated result as a single value for single function and
as a tuple of values for multiple functions
See_Also:
Similar to $(REF _fold, std,algorithm,iteration), `fold` is a wrapper around $(LREF reduce).
Example:
---
static int adder(int a, int b)
{
return a + b;
}
static int multiplier(int a, int b)
{
return a * b;
}
// Just the range
auto x = taskPool.fold!adder([1, 2, 3, 4]);
assert(x == 10);
// The range and the seeds (0 and 1 below; also note multiple
// functions in this example)
auto y = taskPool.fold!(adder, multiplier)([1, 2, 3, 4], 0, 1);
assert(y[0] == 10);
assert(y[1] == 24);
// The range, the seed (0), and the work unit size (20)
auto z = taskPool.fold!adder([1, 2, 3, 4], 0, 20);
assert(z == 10);
---
*/
auto fold(Args...)(Args args) auto fold(Args...)(Args args)
{ {
import std.string : format; import std.string : format;
@ -2841,105 +2874,85 @@ public:
args.length == 1 + functions.length + 1, // range, seeds, and workUnitSize args.length == 1 + functions.length + 1, // range, seeds, and workUnitSize
format("Invalid number of arguments (%s): Should be an input range, %s optional seed(s)," ~ format("Invalid number of arguments (%s): Should be an input range, %s optional seed(s)," ~
" and an optional work unit size", Args.length, functions.length)); " and an optional work unit size", Args.length, functions.length));
return fold(args[0], args[1..$]);
}
/** This is the overload that uses implicit seeds. (See important notes auto range()
on implicit seeds under $(LREF reduce).)
Params:
range = The range of values to _fold over
*/
auto fold(R)(R range)
{
return reduce!functions(range);
}
///
version(StdUnittest)
unittest
{
static int adder(int a, int b) { return a + b; }
auto result = taskPool.fold!adder([1, 2, 3, 4]);
assert(result == 10);
}
/** This is the overload that uses explicit _seeds. (See important notes
on explicit _seeds under $(LREF reduce).)
Params:
range = The range of values to _fold over
seeds = One seed per function
*/
auto fold(R, Seeds...)(R range, Seeds seeds)
if (Seeds.length == functions.length)
{
static if (Seeds.length == 1)
{ {
return reduce!functions(seeds, range); return args[0];
}
static if (Args.length == 1)
{
// Just the range
return reduce!functions(range);
} }
else else
{ {
import std.typecons : tuple; static if (functions.length == 1)
return reduce!functions(tuple(seeds), range); {
auto seeds()
{
return args[1];
}
}
else
{
auto seeds()
{
import std.typecons : tuple;
return tuple(args[1 .. functions.length+1]);
}
}
static if (Args.length == 1 + functions.length)
{
// The range and the seeds
return reduce!functions(seeds, range);
}
else static if (Args.length == 1 + functions.length + 1)
{
// The range, the seeds, and the work unit size
static assert(isIntegral!(Args[$-1]), "Work unit size must be an integral type");
return reduce!functions(seeds, range, args[$-1]);
}
else
{
static assert(0);
}
} }
} }
///
version(StdUnittest)
unittest
{
static int adder (int a, int b) { return a + b; }
static int multiplier(int a, int b) { return a * b; }
auto result = taskPool.fold!(adder, multiplier)([1, 2, 3, 4], 0, 1);
assert(result[0] == 10);
assert(result[1] == 24);
}
/** This is the overload that uses explicit seeds and work unit size.
(See important notes on explicit seeds and work unit size under
$(LREF reduce).)
Params:
range = The range of values to _fold over
args = One seed per function (`args[0..$-1]`) and the work unit
size (`args[$-1]`)
*/
auto fold(R, Args...)(R range, Args args)
if (Args.length == functions.length + 1)
{
static assert(isIntegral!(Args[$-1]), "Work unit size must be an integral type");
static if (Args.length == 2)
{
return reduce!functions(args[0], range, args[1]);
}
else
{
import std.typecons : tuple;
return reduce!functions(tuple(args[0..$-1]), range, args[$-1]);
}
}
///
version(StdUnittest)
unittest
{
static int adder (int a, int b) { return a + b; }
static int multiplier(int a, int b) { return a * b; }
// Single function produces single result
auto result = taskPool.fold!adder([1, 2, 3, 4], 0, 20);
assert(result == 10);
// Multiple functions produce a tuple of results
auto results = taskPool.fold!(adder, multiplier)([1, 2, 3, 4], 0, 1, 30);
assert(results[0] == 10);
assert(results[1] == 24);
}
} }
// This test is not included in the documentation because even though these
// examples are for the inner fold() template, with their current location,
// they would appear under the outer one. (We can't move this inside the
// outer fold() template because then dmd runs out of memory possibly due to
// recursive template instantiation, which is surprisingly not caught.)
version(StdUnittest)
@system unittest
{
static int adder(int a, int b)
{
return a + b;
}
static int multiplier(int a, int b)
{
return a * b;
}
// Just the range
auto x = taskPool.fold!adder([1, 2, 3, 4]);
assert(x == 10);
// The range and the seeds (0 and 1 below; also note multiple
// functions in this example)
auto y = taskPool.fold!(adder, multiplier)([1, 2, 3, 4], 0, 1);
assert(y[0] == 10);
assert(y[1] == 24);
// The range, the seed (0), and the work unit size (20)
auto z = taskPool.fold!adder([1, 2, 3, 4], 0, 20);
assert(z == 10);
}
/** /**
Gets the index of the current thread relative to this $(D TaskPool). Any Gets the index of the current thread relative to this $(D TaskPool). Any
thread not in this pool will receive an index of 0. The worker threads in thread not in this pool will receive an index of 0. The worker threads in
@ -3473,21 +3486,24 @@ public:
} }
version(StdUnittest) version(StdUnittest)
unittest @system unittest
{ {
import std.algorithm : sum; import std.algorithm.iteration : sum;
import std.range : iota; import std.range : iota;
import std.typecons : tuple;
static int adder(int a, int b)
{
return a + b;
}
enum N = 100; enum N = 100;
auto r = iota(1, N + 1); auto r = iota(1, N + 1);
const expected = r.sum(); const expected = r.sum();
static auto adder(int a, int b) {
return a + b;
}
assert(taskPool.fold!adder(r) == expected); assert(taskPool.fold!adder(r) == expected);
assert(taskPool.fold!adder(r, 0) == expected); assert(taskPool.fold!adder(r, 0) == expected);
assert(taskPool.fold!adder(r, 0, 42) == expected); assert(taskPool.fold!(adder, adder)(r, 0, 0, 42) == tuple(expected, expected));
} }
/** /**