mirror of
https://github.com/dlang/phobos.git
synced 2025-05-07 11:37:24 +03:00
topN rewrite
This commit is contained in:
parent
336f5c47d2
commit
f19c92a1bf
1 changed files with 319 additions and 34 deletions
|
@ -640,8 +640,8 @@ if (isRandomAccessRange!Range && hasLength!Range && hasSlicing!Range)
|
|||
version(unittest)
|
||||
{
|
||||
import std.algorithm.searching;
|
||||
assert(r[0 .. lo].all!(x => x <= p));
|
||||
assert(r[hi + 1 .. $].all!(x => x >= p));
|
||||
assert(r[0 .. lo].all!(x => !lt(p, x)));
|
||||
assert(r[hi + 1 .. $].all!(x => !lt(x, p)));
|
||||
}
|
||||
do ++lo; while (lt(r[lo], p));
|
||||
r[hi] = r[lo];
|
||||
|
@ -2934,52 +2934,120 @@ auto topN(alias less = "a < b",
|
|||
{
|
||||
static assert(ss == SwapStrategy.unstable,
|
||||
"Stable topN not yet implemented");
|
||||
|
||||
if (nth >= r.length) return r[0 .. r.length];
|
||||
|
||||
auto ret = r[0 .. nth];
|
||||
if (false)
|
||||
{
|
||||
// Workaround for https://issues.dlang.org/show_bug.cgi?id=16528
|
||||
// Safety checks: enumerate all potentially unsafe generic primitives
|
||||
// then use a @trusted implementation.
|
||||
binaryFun!less(r[0], r[0]);
|
||||
import std.algorithm.mutation : swapAt;
|
||||
r.swapAt(0, 0);
|
||||
}
|
||||
bool useSampling = true;
|
||||
topNImpl!(binaryFun!less)(r, nth, useSampling);
|
||||
return ret;
|
||||
}
|
||||
|
||||
private @trusted
|
||||
void topNImpl(alias less, R)(R r, size_t n, ref bool useSampling)
|
||||
{
|
||||
for (;;)
|
||||
{
|
||||
assert(nth < r.length);
|
||||
import std.algorithm.mutation : swap;
|
||||
import std.algorithm.searching : minPos;
|
||||
if (nth == 0)
|
||||
import std.algorithm.mutation : swapAt;
|
||||
assert(n < r.length);
|
||||
size_t pivot = void;
|
||||
|
||||
// Decide strategy for partitioning
|
||||
if (n == 0)
|
||||
{
|
||||
// Special-case "min"
|
||||
swap(r.front, r.minPos!less.front);
|
||||
break;
|
||||
pivot = 0;
|
||||
foreach (i; 1 .. r.length)
|
||||
if (less(r[i], r[pivot])) pivot = i;
|
||||
r.swapAt(n, pivot);
|
||||
return;
|
||||
}
|
||||
if (nth + 1 == r.length)
|
||||
if (n + 1 == r.length)
|
||||
{
|
||||
// Special-case "max"
|
||||
swap(r.back, r.minPos!((a, b) => binaryFun!less(b, a)).front);
|
||||
break;
|
||||
pivot = 0;
|
||||
foreach (i; 1 .. r.length)
|
||||
if (less(r[pivot], r[i])) pivot = i;
|
||||
r.swapAt(n, pivot);
|
||||
return;
|
||||
}
|
||||
auto pivot = r.getPivot!less;
|
||||
assert(!binaryFun!less(r[pivot], r[pivot]));
|
||||
swap(r[pivot], r.back);
|
||||
auto right = r.partition!(a => binaryFun!less(a, r.back), ss);
|
||||
assert(right.length >= 1);
|
||||
pivot = r.length - right.length;
|
||||
if (pivot > nth)
|
||||
if (r.length <= 12)
|
||||
{
|
||||
pivot = pivotPartition!less(r, r.length / 2);
|
||||
}
|
||||
else if (n * 16 <= (r.length - 1) * 7)
|
||||
{
|
||||
pivot = topNPartitionOffMedian!(less, No.leanRight)
|
||||
(r, n, useSampling);
|
||||
// Quality check
|
||||
if (useSampling)
|
||||
{
|
||||
if (pivot < n)
|
||||
{
|
||||
if (pivot * 4 < r.length)
|
||||
{
|
||||
useSampling = false;
|
||||
}
|
||||
}
|
||||
else if ((r.length - pivot) * 8 < r.length * 3)
|
||||
{
|
||||
useSampling = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (n * 16 >= (r.length - 1) * 9)
|
||||
{
|
||||
pivot = topNPartitionOffMedian!(less, Yes.leanRight)
|
||||
(r, n, useSampling);
|
||||
// Quality check
|
||||
if (useSampling)
|
||||
{
|
||||
if (pivot < n)
|
||||
{
|
||||
if (pivot * 8 < r.length * 3)
|
||||
{
|
||||
useSampling = false;
|
||||
}
|
||||
}
|
||||
else if ((r.length - pivot) * 4 < r.length)
|
||||
{
|
||||
useSampling = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
pivot = topNPartition!less(r, n, useSampling);
|
||||
// Quality check
|
||||
if (useSampling &&
|
||||
(pivot * 9 < r.length * 2 || pivot * 9 > r.length * 7))
|
||||
{
|
||||
// Failed - abort sampling going forward
|
||||
useSampling = false;
|
||||
}
|
||||
}
|
||||
|
||||
assert(pivot != size_t.max);
|
||||
// See how the pivot fares
|
||||
if (pivot == n)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if (pivot > n)
|
||||
{
|
||||
// We don't care to swap the pivot back, won't be visited anymore
|
||||
assert(pivot < r.length);
|
||||
r = r[0 .. pivot];
|
||||
continue;
|
||||
}
|
||||
// Swap the pivot to where it should be
|
||||
swap(right.front, r.back);
|
||||
if (pivot == nth)
|
||||
else
|
||||
{
|
||||
// Found Waldo!
|
||||
break;
|
||||
n -= pivot + 1;
|
||||
r = r[pivot + 1 .. $];
|
||||
}
|
||||
++pivot; // skip the pivot
|
||||
r = r[pivot .. r.length];
|
||||
nth -= pivot;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
///
|
||||
|
@ -2993,6 +3061,223 @@ auto topN(alias less = "a < b",
|
|||
assert(v[n] == 9);
|
||||
}
|
||||
|
||||
private size_t topNPartition(alias lp, R)(R r, size_t n, bool useSampling)
|
||||
{
|
||||
assert(r.length >= 9 && n < r.length);
|
||||
immutable ninth = r.length / 9;
|
||||
auto pivot = ninth / 2;
|
||||
// Position subrange r[lo .. hi] to have length equal to ninth and its upper
|
||||
// median r[lo .. hi][$ / 2] in exactly the same place as the upper median
|
||||
// of the entire range r[$ / 2]. This is to improve behavior for searching
|
||||
// the median in already sorted ranges.
|
||||
immutable lo = r.length / 2 - pivot, hi = lo + ninth;
|
||||
// We have either one straggler on the left, one on the right, or none.
|
||||
assert(lo - (r.length - hi) <= 1 || (r.length - hi) - lo <= 1);
|
||||
assert(lo >= ninth * 4);
|
||||
assert(r.length - hi >= ninth * 4);
|
||||
|
||||
// Partition in groups of 3, and the mid tertile again in groups of 3
|
||||
if (!useSampling)
|
||||
p3!lp(r, lo - ninth, hi + ninth);
|
||||
p3!lp(r, lo, hi);
|
||||
|
||||
// Get the median of medians of medians
|
||||
// Map the full interval of n to the full interval of the ninth
|
||||
pivot = (n * (ninth - 1)) / (r.length - 1);
|
||||
topNImpl!lp(r[lo .. hi], pivot, useSampling);
|
||||
return expandPartition!lp(r, lo, pivot + lo, hi);
|
||||
}
|
||||
|
||||
private void p3(alias less, Range)(Range r, size_t lo, immutable size_t hi)
|
||||
{
|
||||
assert(lo <= hi && hi < r.length);
|
||||
immutable ln = hi - lo;
|
||||
for (; lo < hi; ++lo)
|
||||
{
|
||||
assert(lo >= ln);
|
||||
assert(lo + ln < r.length);
|
||||
medianOf!less(r, lo - ln, lo, lo + ln);
|
||||
}
|
||||
}
|
||||
|
||||
private void p4(alias less, Flag!"leanRight" f, Range)
|
||||
(Range r, size_t lo, immutable size_t hi)
|
||||
{
|
||||
assert(lo <= hi && hi < r.length);
|
||||
immutable ln = hi - lo, _2ln = ln * 2;
|
||||
for (; lo < hi; ++lo)
|
||||
{
|
||||
assert(lo >= ln);
|
||||
assert(lo + ln < r.length);
|
||||
static if (f == Yes.leanRight)
|
||||
medianOf!(less, f)(r, lo - _2ln, lo - ln, lo, lo + ln);
|
||||
else
|
||||
medianOf!(less, f)(r, lo - ln, lo, lo + ln, lo + _2ln);
|
||||
}
|
||||
}
|
||||
|
||||
@trusted private size_t topNPartitionOffMedian(alias lp, Flag!"leanRight" f, R)
|
||||
(R r, size_t n, bool useSampling)
|
||||
{
|
||||
assert(r.length >= 12);
|
||||
assert(n < r.length);
|
||||
immutable _4 = r.length / 4;
|
||||
static if (f == Yes.leanRight)
|
||||
immutable leftLimit = 2 * _4;
|
||||
else
|
||||
immutable leftLimit = _4;
|
||||
// Partition in groups of 4, and the left quartile again in groups of 3
|
||||
if (!useSampling)
|
||||
{
|
||||
p4!(lp, f)(r, leftLimit, leftLimit + _4);
|
||||
}
|
||||
immutable _12 = _4 / 3;
|
||||
immutable lo = leftLimit + _12, hi = lo + _12;
|
||||
p3!lp(r, lo, hi);
|
||||
|
||||
// Get the median of medians of medians
|
||||
// Map the full interval of n to the full interval of the ninth
|
||||
immutable pivot = (n * (_12 - 1)) / (r.length - 1);
|
||||
topNImpl!lp(r[lo .. hi], pivot, useSampling);
|
||||
return expandPartition!lp(r, lo, pivot + lo, hi);
|
||||
}
|
||||
|
||||
/*
|
||||
Params:
|
||||
less = predicate
|
||||
r = range to partition
|
||||
pivot = pivot to partition around
|
||||
lo = value such that r[lo .. pivot] already less than r[pivot]
|
||||
hi = value such that r[pivot .. hi] already greater than r[pivot]
|
||||
|
||||
Returns: new position of pivot
|
||||
*/
|
||||
private
|
||||
size_t expandPartition(alias lp, R)(R r, size_t lo, size_t pivot, size_t hi)
|
||||
in
|
||||
{
|
||||
import std.algorithm.searching : all;
|
||||
assert(lo <= pivot);
|
||||
assert(pivot < hi);
|
||||
assert(hi <= r.length);
|
||||
assert(r[lo .. pivot + 1].all!(x => !lp(r[pivot], x)));
|
||||
assert(r[pivot + 1 .. hi].all!(x => !lp(x, r[pivot])));
|
||||
}
|
||||
out
|
||||
{
|
||||
import std.algorithm.searching : all;
|
||||
assert(r[0 .. pivot + 1].all!(x => !lp(r[pivot], x)));
|
||||
assert(r[pivot + 1 .. $].all!(x => !lp(x, r[pivot])));
|
||||
}
|
||||
body
|
||||
{
|
||||
import std.algorithm.mutation : swapAt;
|
||||
import std.algorithm.searching : all;
|
||||
// We work with closed intervals!
|
||||
--hi;
|
||||
|
||||
size_t left = 0, rite = r.length - 1;
|
||||
loop: for (;; ++left, --rite)
|
||||
{
|
||||
for (;; ++left)
|
||||
{
|
||||
if (left == lo) break loop;
|
||||
if (!lp(r[left], r[pivot])) break;
|
||||
}
|
||||
for (;; --rite)
|
||||
{
|
||||
if (rite == hi) break loop;
|
||||
if (!lp(r[pivot], r[rite])) break;
|
||||
}
|
||||
r.swapAt(left, rite);
|
||||
}
|
||||
|
||||
assert(r[lo .. pivot + 1].all!(x => !lp(r[pivot], x)));
|
||||
assert(r[pivot + 1 .. hi + 1].all!(x => !lp(x, r[pivot])));
|
||||
assert(r[0 .. left].all!(x => !lp(r[pivot], x)));
|
||||
assert(r[rite + 1 .. $].all!(x => !lp(x, r[pivot])));
|
||||
|
||||
immutable oldPivot = pivot;
|
||||
|
||||
if (left < lo)
|
||||
{
|
||||
// First loop: spend r[lo .. pivot]
|
||||
for (; lo < pivot; ++left)
|
||||
{
|
||||
if (left == lo) goto done;
|
||||
if (!lp(r[oldPivot], r[left])) continue;
|
||||
--pivot;
|
||||
assert(!lp(r[oldPivot], r[pivot]));
|
||||
r.swapAt(left, pivot);
|
||||
}
|
||||
// Second loop: make left and pivot meet
|
||||
for (;; ++left)
|
||||
{
|
||||
if (left == pivot) goto done;
|
||||
if (!lp(r[oldPivot], r[left])) continue;
|
||||
for (;;)
|
||||
{
|
||||
if (left == pivot) goto done;
|
||||
--pivot;
|
||||
if (lp(r[pivot], r[oldPivot]))
|
||||
{
|
||||
r.swapAt(left, pivot);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// First loop: spend r[lo .. pivot]
|
||||
for (; hi != pivot; --rite)
|
||||
{
|
||||
if (rite == hi) goto done;
|
||||
if (!lp(r[rite], r[oldPivot])) continue;
|
||||
++pivot;
|
||||
assert(!lp(r[pivot], r[oldPivot]));
|
||||
r.swapAt(rite, pivot);
|
||||
}
|
||||
// Second loop: make left and pivot meet
|
||||
outer: for (; rite > pivot; --rite)
|
||||
{
|
||||
if (!lp(r[rite], r[oldPivot])) continue;
|
||||
while (rite > pivot)
|
||||
{
|
||||
++pivot;
|
||||
if (lp(r[oldPivot], r[pivot]))
|
||||
{
|
||||
r.swapAt(rite, pivot);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
r.swapAt(oldPivot, pivot);
|
||||
return pivot;
|
||||
}
|
||||
|
||||
unittest
|
||||
{
|
||||
auto a = [ 10, 5, 3, 4, 8, 11, 13, 3, 9, 4, 10 ];
|
||||
assert(expandPartition!((a, b) => a < b)(a, 4, 5, 6) == 9);
|
||||
a = randomArray;
|
||||
if (a.length == 0) return;
|
||||
expandPartition!((a, b) => a < b)(a, a.length / 2, a.length / 2,
|
||||
a.length / 2 + 1);
|
||||
}
|
||||
|
||||
version(unittest)
|
||||
private T[] randomArray(Flag!"exactSize" flag = No.exactSize, T = int)(
|
||||
size_t maxSize = 1000,
|
||||
T minValue = 0, T maxValue = 255)
|
||||
{
|
||||
import std.random : unpredictableSeed, Random, uniform;
|
||||
import std.algorithm.iteration : map;
|
||||
auto size = flag == Yes.exactSize ? maxSize : uniform(1, maxSize);
|
||||
return iota(0, size).map!(_ => uniform(minValue, maxValue)).array;
|
||||
}
|
||||
|
||||
@safe unittest
|
||||
{
|
||||
import std.algorithm.comparison : max, min;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue