diff --git a/std/algorithm/sorting.d b/std/algorithm/sorting.d index 7ab059333..2c462fd93 100644 --- a/std/algorithm/sorting.d +++ b/std/algorithm/sorting.d @@ -555,6 +555,197 @@ Range partition(alias predicate, assert(isPartitioned!`a.length < 5`(b)); } +// pivotPartition +/** + +Partitions `r` around `pivot` using comparison function `less`, algorithm akin +to $(LUCKY Hoare partition). Specifically, permutes elements of `r` and returns +an index $(D k < r.length) such that: + +$(UL + +$(LI `r[pivot]` is swapped to `r[k]`) + +$(LI All elements `e` in subrange $(D r[0 .. k]) satisfy $(D !less(r[k], e)) +(i.e. `r[k]` is greater than or equal to each element to its left according to +predicate `less`)) + +$(LI $(LI All elements `e` in subrange $(D r[0 .. k]) satisfy $(D !less(e, +r[k])) (i.e. `r[k]` is less than or equal to each element to its right +according to predicate `less`)))) + +If `r` contains equivalent elements, multiple permutations of `r` satisfy these +constraints. In such cases, `pivotPartition` attempts to distribute equivalent +elements fairly to the left and right of `k` such that `k` stays close to $(D +r.length / 2). + +Params: +less = The predicate used for comparison, modeled as a $(LUCKY strict weak +ordering) (irreflexive, antisymmetric, transitive, and implies transitive +equivalence) +r = The range being partitioned +pivot = The index of the pivot for partitioning, must be less than `r.length` or +`0` is `r.length` is `0` + +Returns: +The new position of the pivot + +See_Also: +$(HTTP jgrcs.info/index.php/jgrcs/article/view/142, Engineering of a Quicksort +Partitioning Algorithm), D. Abhyankar, Journal of Global Research in Computer +Science, February 2011. $(HTTPS youtube.com/watch?v=AxnotgLql0k, ACCU 2016 +Keynote), Andrei Alexandrescu. +*/ +size_t pivotPartition(alias less = "a < b", Range) +(Range r, size_t pivot) +if (isRandomAccessRange!Range && hasLength!Range && hasSlicing!Range) +{ + assert(pivot < r.length || r.length == 0 && pivot == 0); + if (r.length <= 1) return 0; + import std.algorithm.mutation : swapAt, move; + alias lt = binaryFun!less; + + // Pivot at the front + r.swapAt(pivot, 0); + + // Fork implemnentation depending on nothrow copy, assignment, and + // comparison. If all of these are nothrow, use the specialized + // implementation discussed at https://youtube.com/watch?v=AxnotgLql0k. + static if (is(typeof( + () nothrow { auto x = r.front; x = r.front; return lt(x, x); } + ))) + { + auto p = r[0]; + // Plant the pivot in the end as well as a sentinel + size_t lo = 0, hi = r.length - 1; + auto save = move(r[hi]); + r[hi] = p; // Vacancy is in r[$ - 1] now + // Start process + for (;;) + { + // Loop invariant + version(unittest) + { + import std.algorithm.searching; + assert(r[0 .. lo].all!(x => x <= p)); + assert(r[hi + 1 .. $].all!(x => x >= p)); + } + do ++lo; while (lt(r[lo], p)); + r[hi] = r[lo]; + // Vacancy is now in r[lo] + do --hi; while (lt(p, r[hi])); + if (lo >= hi) break; + r[lo] = r[hi]; + // Vacancy is not in r[hi] + } + // Fixup + assert(lo - hi <= 2); + assert(!lt(p, r[hi])); + if (lo == hi + 2) + { + assert(!lt(r[hi + 1], p)); + r[lo] = r[hi + 1]; + --lo; + } + r[lo] = save; + if (lt(p, save)) --lo; + assert(!lt(p, r[lo])); + } + else + { + size_t lo = 1, hi = r.length - 1; + loop: for (;; lo++, hi--) + { + for (;; ++lo) + { + if (lo > hi) break loop; + if (!lt(r[lo], r[0])) break; + } + // found the left bound: r[lo] >= r[0] + assert(lo <= hi); + for (;; --hi) + { + if (lo >= hi) break loop; + if (!lt(r[0], r[hi])) break; + } + // found the right bound: r[hi] <= r[0], swap & make progress + assert(!lt(r[lo], r[hi])); + r.swapAt(lo, hi); + } + --lo; + } + r.swapAt(lo, 0); + return lo; +} + +/// +@safe nothrow unittest +{ + int[] a = [5, 3, 2, 6, 4, 1, 3, 7]; + size_t pivot = pivotPartition(a, a.length / 2); + import std.algorithm.searching : all; + assert(a[0 .. pivot].all!(x => x <= a[pivot])); + assert(a[pivot .. $].all!(x => x >= a[pivot])); +} + +@safe unittest +{ + void test(alias less)() + { + int[] a; + size_t pivot; + + a = [-9, -4, -2, -2, 9]; + pivot = pivotPartition!less(a, a.length / 2); + import std.algorithm.searching : all; + assert(a[0 .. pivot].all!(x => x <= a[pivot])); + assert(a[pivot .. $].all!(x => x >= a[pivot])); + + a = [9, 2, 8, -5, 5, 4, -8, -4, 9]; + pivot = pivotPartition!less(a, a.length / 2); + assert(a[0 .. pivot].all!(x => x <= a[pivot])); + assert(a[pivot .. $].all!(x => x >= a[pivot])); + + a = [ 42 ]; + pivot = pivotPartition!less(a, a.length / 2); + assert(pivot == 0); + assert(a == [ 42 ]); + + a = [ 43, 42 ]; + pivot = pivotPartition!less(a, 0); + assert(pivot == 1); + assert(a == [ 42, 43 ]); + + a = [ 43, 42 ]; + pivot = pivotPartition!less(a, 1); + assert(pivot == 0); + assert(a == [ 42, 43 ]); + + a = [ 42, 42 ]; + pivot = pivotPartition!less(a, 0); + assert(pivot == 0 || pivot == 1); + assert(a == [ 42, 42 ]); + pivot = pivotPartition!less(a, 1); + assert(pivot == 0 || pivot == 1); + assert(a == [ 42, 42 ]); + + import std.random : uniform; + import std.algorithm.iteration : map; + a = iota(0, uniform(1, 1000)).map!(_ => uniform(-1000, 1000)).array; + pivot = pivotPartition!less(a, a.length / 2); + assert(a[0 .. pivot].all!(x => x <= a[pivot])); + assert(a[pivot .. $].all!(x => x >= a[pivot])); + } + test!"a < b"; + static bool myLess(int a, int b) + { + static bool bogus; + if (bogus) throw new Exception(""); // just to make it no-nothrow + return a < b; + } + test!myLess; +} + /** Params: pred = The predicate that the range should be partitioned by. @@ -2617,9 +2808,6 @@ schwartzSort(alias transform, alias less = "a < b", arr[2] = highEnt; schwartzSort!(entropy, q{a > b})(arr); - assert(arr[0] == highEnt); - assert(arr[1] == midEnt); - assert(arr[2] == lowEnt); assert(isSorted!("a > b")(map!(entropy)(arr))); } @@ -2650,9 +2838,6 @@ schwartzSort(alias transform, alias less = "a < b", arr[2] = highEnt; schwartzSort!(entropy, q{a < b})(arr); - assert(arr[0] == lowEnt); - assert(arr[1] == midEnt); - assert(arr[2] == highEnt); assert(isSorted!("a < b")(map!(entropy)(arr))); }