Quicksort Revisited

転職活動のコーディングインタビューに備えてクイックソートを勉強し直していました。 それで、クイックソートなんて実装することはまずないだろうと思って今まで適当にやっていたので、今さら気が付いたのですが、クイックソートは pivot を適当に選ぶと停止しなかったり、セグフォしたりするのでアルゴリズム全体より pivot の選び方の方が重要だなと。 というわけで pivot の選び方の注意点やクイックソートのアルゴリズムについて触れたいと思います。

まず、よく知られているクイックソートのアルゴリズムの復習です。

#include <string>
#include <cstdio>
#include <cstdint>

template <typename T>
T med3(T x, T y, T z) {
  if (x < y) {
    if (y < z) return y; else if (z < x) return x; else return z;
  } else {
    if (z < y) return y; else if (x < z) return x; else return z;
  }
}

template <typename T>
void quick_sort_med3(T* xs, std::size_t N, int l, int r)
{
  if (l >= r) return;

  auto i = l, j = r;
  auto pi = med3(xs[i], xs[i + (j - i) / 2], xs[j]);

  while (true) {
    while (xs[i] < pi) ++i;
    while (xs[j] > pi) --j;

    if (i >= j) break;
    
    std::swap(xs[i], xs[j]);
    ++i;
    --j;
  }
  
  quick_sort_med3(xs, N, l, i - 1);
  quick_sort_med3(xs, N, j + 1, r);
}

まぁだいたいこのような実装になります。quick_sort_med3 は pivot として 3 つの要素の中央値を選んでいます。こうするのは、なるべく均等に 2 分割になるようにして、クイックソートの最良のオーダー(O(N logN))に近づけるためです。 もちろん他の pivot の選択方法でも構わなくて、例えば、常にデータの最小の添字(あるいは最大の添字)を pivot として選ぶなど手抜きをする方法があります。

ただ、ここで注意しなければならないのは、pivot としてソート対象のデータの最小値以下(または最大値以上)の値を選択しないことです。なぜなら、上記クイックソートは再帰的な実装になっていますが、ソート対象のデータを 2 分割した際に片方が空になってしまうと無限に再帰し続けて停止しなくなるからです。

実際に停止しなくなることを確認するには pivot を以下のように選ぶように変更してプログラムを実行してみてください。

auto pivot = std::min_element(xs, xs + N);

std::min_element を使うと計算量が O(N logN) におさまらなくなりますけど、あくまで実験ということで。

さて、ついでにクイックソートの他の実装も見ておきましょう。実装にいくつかバリエーションがあって、例えば Functional Programming in Scala という本の Chapter 14 で ST Monad でクイックソートを実装する章があるわけですが、その章で使われているアルゴリズムを C++ にすると以下のようになります。

#include <string>
#include <cstdio>
#include <cstdint>

template <typename T>
int partition(T* xs, int l, int r, int pivot)
{
  auto v = xs[pivot];
  auto pi = l;
  
  std::swap(xs[pivot], xs[r]);

  for (auto i = l; i < r; ++i) {
    if (xs[i] >= v) continue;

    std::swap(xs[i], xs[pi]);
    pi += 1;
  }

  std::swap(xs[pi], xs[r]);
  
  return pi;
}

template <typename T>
void quick_sort(T* xs, std::size_t N, int l, int r)
{
  if (l >= r) return;
  auto pi = partition(xs, l, r, l + (r - l) / 2);
  quick_sort(xs, N, l, pi - 1);
  quick_sort(xs, N, pi + 1, r);
}

この partition 関数は初見の方が多いのではないでしょうか。やっていることは本質的には同じです。調べてみたら英語版の Wikipedia では言及されていましたが、日本語版の Wikipedia には記載されていませんでした。

余談ですが、ST Monad を使ったクイックソートは Scala だと以下のようになります。実装というか変数名が若干適当ですが面倒だったということで…。

sealed trait ST[S, A] {
  self =>

  protected def run(s: S): (A, S)

  def map[B](f: A => B): ST[S, B] =
    new ST[S, B] {
      override protected def run(s: S): (B, S) = {
        val (a, s1) = self.run(s)
        (f(a), s1)
      }
    }

  def flatMap[B](f: A => ST[S, B]): ST[S, B] =
    new ST[S, B] {
      override protected def run(s: S): (B, S) = {
        val (a, s1) = self.run(s)
        f(a).run(s1)
      }
    }
}

object ST {
  def apply[S, A](a: => A): ST[S, A] = {
    lazy val memo = a
    new ST[S, A] {
      override protected def run(s: S): (A, S) = {
        (memo, s)
      }
    }
  }

  def runST[A](st: RunnableST[A]): A =
    st.apply[Unit].run(())._1
}

trait RunnableST[A] {
  def apply[S]: ST[S, A]
}

sealed trait STRef[S, A] {
  protected var cell: A

  def read: ST[S, A] = ST(cell)

  def write(x: A): ST[S, Unit] =
    new ST[S, Unit] {
      override protected def run(s: S): (Unit, S) = {
        cell = x
        ((), s)
      }
    }
}

object STRef {
  def apply[S, A](x: => A): ST[S, STRef[S, A]] =
    ST(
      new STRef[S, A] {
        override protected var cell: A = x
      }
    )
}

sealed abstract class STArray[S, A](implicit m: Manifest[A]) {
  protected def value: Array[A]

  def size: ST[S, Int] = ST(value.size)

  def write(i: Int, x: A): ST[S, Unit] =
    new ST[S, Unit] {
      override protected def run(s: S): (Unit, S) = {
        value(i) = x
        ((), s)
      }
    }

  def read(i: Int): ST[S, A] = ST(value(i))

  def swap(i: Int, j: Int): ST[S, Unit] =
    for {
      x <- read(i)
      y <- read(j)
      _ <- write(i, y)
      _ <- write(j, x)
    } yield ()

  def freeze: ST[S, List[A]] = ST(value.toList)
}

object STArray {
  def apply[S, A: Manifest](sz :Int, x: A): ST[S, STArray[S, A]] =
    ST(new STArray[S, A] {
      override protected val value: Array[A] = Array.fill(sz)(x)
    })

  def fromList[S, A: Manifest](xs: List[A]): ST[S, STArray[S, A]] =
    ST(new STArray[S, A] {
      override protected val value: Array[A] = xs.toArray
    })
}

object QuickSort {
  def partition[S](arr: STArray[S, Int], n: Int, r: Int, pivot: Int): ST[S, Int] =
    for {
      v <- arr.read(pivot)

      _ <- arr.swap(pivot, r)

      j <- STRef(n)

      _ <- (n until r).foldLeft(ST[S, Unit](())) { (s, i) =>
        for {
          _  <- s
          p  <- arr.read(i)
          _  <- if (p >= v) ST[S, Unit](()) else
            for {
              j0 <- j.read
              _  <- arr.swap(i, j0)
              _  <- j.write(j0 + 1)
            } yield ()
        } yield ()
      }

      j0 <- j.read

      _ <- arr.swap(j0, r)
    } yield j0

  def sort[S](arr: STArray[S, Int], n: Int, r: Int): ST[S, Unit] =
    if (n >= r) ST[S, Unit](()) else for {
      pi <- partition(arr, n, r, n + (r - n) / 2)
      _  <- sort(arr, n, pi - 1)
      _  <- sort(arr, pi + 1, r)
    } yield ()

  def sort(xs: List[Int]): List[Int] = {
    if (xs.isEmpty) xs else
      ST.runST(new RunnableST[List[Int]] {
        override def apply[S]: ST[S, List[Int]] = {
          for {
            arr <- STArray.fromList(xs)
            _   <- sort(arr, 0, xs.size - 1)
            ys  <- arr.freeze
          } yield ys
        }
      })
  }
}

Comments

comments powered by Disqus