topN rewrite

This commit is contained in:
Andrei Alexandrescu 2016-09-23 11:49:39 -04:00
parent 336f5c47d2
commit f19c92a1bf

View file

@ -640,8 +640,8 @@ if (isRandomAccessRange!Range && hasLength!Range && hasSlicing!Range)
version(unittest)
{
import std.algorithm.searching;
assert(r[0 .. lo].all!(x => x <= p));
assert(r[hi + 1 .. $].all!(x => x >= p));
assert(r[0 .. lo].all!(x => !lt(p, x)));
assert(r[hi + 1 .. $].all!(x => !lt(x, p)));
}
do ++lo; while (lt(r[lo], p));
r[hi] = r[lo];
@ -2934,52 +2934,120 @@ auto topN(alias less = "a < b",
{
static assert(ss == SwapStrategy.unstable,
"Stable topN not yet implemented");
if (nth >= r.length) return r[0 .. r.length];
auto ret = r[0 .. nth];
if (false)
{
// Workaround for https://issues.dlang.org/show_bug.cgi?id=16528
// Safety checks: enumerate all potentially unsafe generic primitives
// then use a @trusted implementation.
binaryFun!less(r[0], r[0]);
import std.algorithm.mutation : swapAt;
r.swapAt(0, 0);
}
bool useSampling = true;
topNImpl!(binaryFun!less)(r, nth, useSampling);
return ret;
}
private @trusted
void topNImpl(alias less, R)(R r, size_t n, ref bool useSampling)
{
for (;;)
{
assert(nth < r.length);
import std.algorithm.mutation : swap;
import std.algorithm.searching : minPos;
if (nth == 0)
import std.algorithm.mutation : swapAt;
assert(n < r.length);
size_t pivot = void;
// Decide strategy for partitioning
if (n == 0)
{
// Special-case "min"
swap(r.front, r.minPos!less.front);
break;
pivot = 0;
foreach (i; 1 .. r.length)
if (less(r[i], r[pivot])) pivot = i;
r.swapAt(n, pivot);
return;
}
if (nth + 1 == r.length)
if (n + 1 == r.length)
{
// Special-case "max"
swap(r.back, r.minPos!((a, b) => binaryFun!less(b, a)).front);
break;
pivot = 0;
foreach (i; 1 .. r.length)
if (less(r[pivot], r[i])) pivot = i;
r.swapAt(n, pivot);
return;
}
auto pivot = r.getPivot!less;
assert(!binaryFun!less(r[pivot], r[pivot]));
swap(r[pivot], r.back);
auto right = r.partition!(a => binaryFun!less(a, r.back), ss);
assert(right.length >= 1);
pivot = r.length - right.length;
if (pivot > nth)
if (r.length <= 12)
{
pivot = pivotPartition!less(r, r.length / 2);
}
else if (n * 16 <= (r.length - 1) * 7)
{
pivot = topNPartitionOffMedian!(less, No.leanRight)
(r, n, useSampling);
// Quality check
if (useSampling)
{
if (pivot < n)
{
if (pivot * 4 < r.length)
{
useSampling = false;
}
}
else if ((r.length - pivot) * 8 < r.length * 3)
{
useSampling = false;
}
}
}
else if (n * 16 >= (r.length - 1) * 9)
{
pivot = topNPartitionOffMedian!(less, Yes.leanRight)
(r, n, useSampling);
// Quality check
if (useSampling)
{
if (pivot < n)
{
if (pivot * 8 < r.length * 3)
{
useSampling = false;
}
}
else if ((r.length - pivot) * 4 < r.length)
{
useSampling = false;
}
}
}
else
{
pivot = topNPartition!less(r, n, useSampling);
// Quality check
if (useSampling &&
(pivot * 9 < r.length * 2 || pivot * 9 > r.length * 7))
{
// Failed - abort sampling going forward
useSampling = false;
}
}
assert(pivot != size_t.max);
// See how the pivot fares
if (pivot == n)
{
return;
}
if (pivot > n)
{
// We don't care to swap the pivot back, won't be visited anymore
assert(pivot < r.length);
r = r[0 .. pivot];
continue;
}
// Swap the pivot to where it should be
swap(right.front, r.back);
if (pivot == nth)
else
{
// Found Waldo!
break;
n -= pivot + 1;
r = r[pivot + 1 .. $];
}
++pivot; // skip the pivot
r = r[pivot .. r.length];
nth -= pivot;
}
return ret;
}
///
@ -2993,6 +3061,223 @@ auto topN(alias less = "a < b",
assert(v[n] == 9);
}
/*
Picks a pivot for `r` from a window of `r.length / 9` elements positioned
around the middle of the range, then partitions the whole range around it.

Params:
lp = "less than" predicate
r = range to partition; must hold at least 9 elements
n = index being searched for by the caller; must be less than `r.length`
useSampling = when `false`, an extra `p3` pass is run over the window widened
              by one ninth on each side before the window itself is processed

Returns: the pivot position reported by `expandPartition`
*/
private size_t topNPartition(alias lp, R)(R r, size_t n, bool useSampling)
{
    assert(r.length >= 9 && n < r.length);
    immutable ninth = r.length / 9;

    // Place the window r[lo .. hi] (length == ninth) so that its upper median
    // r[lo .. hi][$ / 2] sits exactly on the upper median r[$ / 2] of the
    // whole range. This helps median searches on already sorted input.
    immutable lo = r.length / 2 - ninth / 2;
    immutable hi = lo + ninth;
    immutable tail = r.length - hi;

    // The parts left and right of the window differ by at most one element
    // (the subtraction is unsigned, so only one of the two can be small).
    assert(lo - tail <= 1 || tail - lo <= 1);
    assert(lo >= ninth * 4);
    assert(tail >= ninth * 4);

    // Groups of three over the wide window (skipped when sampling), then
    // groups of three again over the window proper.
    if (!useSampling)
        p3!lp(r, lo - ninth, hi + ninth);
    p3!lp(r, lo, hi);

    // Scale n from the index space of r (0 .. r.length - 1) down to the index
    // space of the window (0 .. ninth - 1), select that element inside the
    // window, then grow the window's partition to cover all of r.
    immutable target = (n * (ninth - 1)) / (r.length - 1);
    topNImpl!lp(r[lo .. hi], target, useSampling);
    return expandPartition!lp(r, lo, lo + target, hi);
}
/*
Runs `medianOf!less` over index triples of `r`: for every `i` in `lo .. hi`
the triple is `(i - ln, i, i + ln)`, where `ln == hi - lo`.

Params:
less = "less than" predicate
r = range whose elements are rearranged by `medianOf`
lo = first middle index; every middle index must be at least `hi - lo`
hi = one past the last middle index; must be less than `r.length`

NOTE(review): presumably `medianOf` leaves the median of each triple at the
middle index - confirm against its definition, which is outside this block.
*/
private void p3(alias less, Range)(Range r, size_t lo, immutable size_t hi)
{
    assert(lo <= hi && hi < r.length);
    immutable stride = hi - lo;
    foreach (mid; lo .. hi)
    {
        // Both neighbours one stride away must be inside the range.
        assert(mid >= stride);
        assert(mid + stride < r.length);
        medianOf!less(r, mid - stride, mid, mid + stride);
    }
}
/*
Runs the four-index form of `medianOf!(less, f)` for every `i` in `lo .. hi`,
using indices spaced `ln == hi - lo` apart.

With `Yes.leanRight` the indices are `(i - 2*ln, i - ln, i, i + ln)`;
otherwise they are `(i - ln, i, i + ln, i + 2*ln)`.

Params:
less = "less than" predicate
f = selects which of the two index layouts above is used; also forwarded to
    `medianOf`
r = range whose elements are rearranged by `medianOf`
lo = first value of `i`
hi = one past the last value of `i`; must be less than `r.length`
*/
private void p4(alias less, Flag!"leanRight" f, Range)
    (Range r, size_t lo, immutable size_t hi)
{
    assert(lo <= hi && hi < r.length);
    immutable stride = hi - lo;
    immutable doubleStride = stride * 2;
    foreach (at; lo .. hi)
    {
        assert(at >= stride);
        assert(at + stride < r.length);
        static if (f == Yes.leanRight)
        {
            // Two samples to the left of `at`, one to the right.
            medianOf!(less, f)(r, at - doubleStride, at - stride, at, at + stride);
        }
        else
        {
            // One sample to the left of `at`, two to the right.
            medianOf!(less, f)(r, at - stride, at, at + stride, at + doubleStride);
        }
    }
}
/*
Picks a pivot for `r` from a window of `r.length / 4 / 3` elements located
inside one quarter of the range, then partitions the whole range around it.

Params:
lp = "less than" predicate
f = `Yes.leanRight` works on the quarter starting at `2 * (r.length / 4)`,
    `No.leanRight` on the quarter starting at `r.length / 4`
r = range to partition; must hold at least 12 elements
n = index being searched for by the caller; must be less than `r.length`
useSampling = when `false`, a `p4` pass over the chosen quarter is run first

Returns: the pivot position reported by `expandPartition`
*/
@trusted private size_t topNPartitionOffMedian(alias lp, Flag!"leanRight" f, R)
    (R r, size_t n, bool useSampling)
{
    assert(r.length >= 12);
    assert(n < r.length);

    immutable quarter = r.length / 4;
    // Number of whole quarters to the left of the one we work on.
    enum size_t skippedQuarters = (f == Yes.leanRight) ? 2 : 1;
    immutable leftLimit = skippedQuarters * quarter;

    // Groups of four across the chosen quarter (skipped when sampling) ...
    if (!useSampling)
        p4!(lp, f)(r, leftLimit, leftLimit + quarter);

    // ... then groups of three across the middle third of that quarter.
    immutable twelfth = quarter / 3;
    immutable lo = leftLimit + twelfth;
    immutable hi = lo + twelfth;
    p3!lp(r, lo, hi);

    // Scale n from the index space of r (0 .. r.length - 1) down to the index
    // space of the window (0 .. twelfth - 1), select that element inside the
    // window, then grow the window's partition to cover all of r.
    immutable target = (n * (twelfth - 1)) / (r.length - 1);
    topNImpl!lp(r[lo .. hi], target, useSampling);
    return expandPartition!lp(r, lo, target + lo, hi);
}
/*
Grows an existing partition of the slice r[lo .. hi] around r[pivot] into a
partition of the whole range r, rearranging elements in place.

Params:
lp = "less than" predicate used for every comparison
r = range to partition
lo = index such that no element of r[lo .. pivot + 1] is greater than r[pivot]
pivot = index of the element to partition around; lo <= pivot < hi
hi = index such that no element of r[pivot + 1 .. hi] is less than r[pivot];
hi <= r.length
Returns: new position of pivot. The out contract asserts that on exit no
element of r[0 .. pivot + 1] is greater than r[pivot] and no element of
r[pivot + 1 .. $] is less than it.
*/
private
size_t expandPartition(alias lp, R)(R r, size_t lo, size_t pivot, size_t hi)
in
{
import std.algorithm.searching : all;
assert(lo <= pivot);
assert(pivot < hi);
assert(hi <= r.length);
// r[lo .. hi] must already be partitioned around r[pivot].
assert(r[lo .. pivot + 1].all!(x => !lp(r[pivot], x)));
assert(r[pivot + 1 .. hi].all!(x => !lp(x, r[pivot])));
}
out
{
import std.algorithm.searching : all;
// The whole range must be partitioned around r[pivot].
assert(r[0 .. pivot + 1].all!(x => !lp(r[pivot], x)));
assert(r[pivot + 1 .. $].all!(x => !lp(x, r[pivot])));
}
body
{
import std.algorithm.mutation : swapAt;
import std.algorithm.searching : all;
// We work with closed intervals!
// From here on hi is the index of the last element of the pre-partitioned
// slice rather than one past it.
--hi;
// Phase 1: scan inwards from both ends of r. left stops at an element that is
// not less than the pivot, rite at one that is not greater; the two are then
// swapped. The phase ends as soon as either cursor reaches the boundary of
// the pre-partitioned slice (left == lo or rite == hi).
size_t left = 0, rite = r.length - 1;
loop: for (;; ++left, --rite)
{
for (;; ++left)
{
if (left == lo) break loop;
if (!lp(r[left], r[pivot])) break;
}
for (;; --rite)
{
if (rite == hi) break loop;
if (!lp(r[pivot], r[rite])) break;
}
r.swapAt(left, rite);
}
// The slice is still partitioned, and everything the cursors have passed is
// on the correct side of the pivot.
assert(r[lo .. pivot + 1].all!(x => !lp(r[pivot], x)));
assert(r[pivot + 1 .. hi + 1].all!(x => !lp(x, r[pivot])));
assert(r[0 .. left].all!(x => !lp(r[pivot], x)));
assert(r[rite + 1 .. $].all!(x => !lp(x, r[pivot])));
// The pivot value stays at index oldPivot until the final swap at done;
// from here on pivot tracks the position where that value will end up.
immutable oldPivot = pivot;
if (left < lo)
{
// Phase 1 stopped because rite reached hi, so only r[left .. lo] is still
// unexamined. Both loops below leave only via goto done, so the right-side
// loops further down are never reached from this branch.
// First loop: spend r[lo .. pivot]
// Each element at left that is greater than the pivot is swapped into the
// slot just below the current pivot position.
for (; lo < pivot; ++left)
{
if (left == lo) goto done;
if (!lp(r[oldPivot], r[left])) continue;
--pivot;
assert(!lp(r[oldPivot], r[pivot]));
r.swapAt(left, pivot);
}
// Second loop: make left and pivot meet
// r[lo .. pivot] is used up: for each element at left that is greater than
// the pivot, walk pivot downwards to an element strictly less than the
// pivot and swap the two.
for (;; ++left)
{
if (left == pivot) goto done;
if (!lp(r[oldPivot], r[left])) continue;
for (;;)
{
if (left == pivot) goto done;
--pivot;
if (lp(r[pivot], r[oldPivot]))
{
r.swapAt(left, pivot);
break;
}
}
}
}
// Reached only when left == lo after phase 1; the elements between hi and
// rite are still unexamined. Mirror image of the left-side loops above.
// First loop: spend r[pivot + 1 .. hi + 1] (hi is a closed bound here)
// Each element at rite that is less than the pivot is swapped into the slot
// just above the current pivot position.
for (; hi != pivot; --rite)
{
if (rite == hi) goto done;
if (!lp(r[rite], r[oldPivot])) continue;
++pivot;
assert(!lp(r[pivot], r[oldPivot]));
r.swapAt(rite, pivot);
}
// Second loop: make rite and pivot meet
// For each element at rite that is less than the pivot, walk pivot upwards
// to an element strictly greater than the pivot and swap the two.
// NOTE(review): the outer label is not referenced anywhere in this function.
outer: for (; rite > pivot; --rite)
{
if (!lp(r[rite], r[oldPivot])) continue;
while (rite > pivot)
{
++pivot;
if (lp(r[oldPivot], r[pivot]))
{
r.swapAt(rite, pivot);
break;
}
}
}
done:
// Move the pivot value from its original slot to its final position.
r.swapAt(oldPivot, pivot);
return pivot;
}
unittest
{
    // Fixed input: the slice a[4 .. 6] == [8, 11] is already partitioned
    // around a[5] == 11. Nine elements are smaller than 11, so the pivot has
    // to end up at index 9.
    auto a = [ 10, 5, 3, 4, 8, 11, 13, 3, 9, 4, 10 ];
    immutable newPivot = expandPartition!((x, y) => x < y)(a, 4, 5, 6);
    assert(newPivot == 9);

    // Random input: a one-element slice in the middle is trivially
    // partitioned; expandPartition's own contracts do the checking.
    a = randomArray;
    if (a.length != 0)
    {
        immutable mid = a.length / 2;
        expandPartition!((x, y) => x < y)(a, mid, mid, mid + 1);
    }
}
/*
Test-only helper that builds an array of random values.

Params:
flag = with `Yes.exactSize` the result has exactly `maxSize` elements;
       otherwise the length is drawn with `uniform(1, maxSize)`
T = element type of the result
maxSize = the exact length, or the upper bound passed to `uniform` for the
          length
minValue = lower bound passed to `uniform` for every element
maxValue = upper bound passed to `uniform` for every element

Returns: a newly allocated `T[]`.

`uniform` is called with its default boundaries, which std.random documents as
the half-open interval "[)": a random length is therefore always less than
`maxSize`, and `maxValue` itself is never generated. For the same reason
`maxSize` must be greater than 1 when `flag` is `No.exactSize`, and `minValue`
must be less than `maxValue`.
*/
version(unittest)
private T[] randomArray(Flag!"exactSize" flag = No.exactSize, T = int)(
    size_t maxSize = 1000,
    T minValue = 0, T maxValue = 255)
{
    import std.algorithm.iteration : map;
    // Only `uniform` is needed; the unused `unpredictableSeed` and `Random`
    // imports have been dropped.
    import std.random : uniform;
    auto size = flag == Yes.exactSize ? maxSize : uniform(1, maxSize);
    return iota(0, size).map!(_ => uniform(minValue, maxValue)).array;
}
@safe unittest
{
import std.algorithm.comparison : max, min;