Rubyで素朴なSA-IS(suffix array induced sorting)を書いてみた

FM-indexと同様に、ほぼ最適化をしていない素朴な実装を作ってみました。

元の論文の実装は空間効率まで考慮して書かれていてすごい! ……のですが、まずはそれ以外の部分(アルゴリズムの基本的な仕組み)を理解したかったので、以下のコードでは空間効率や実行速度の最適化は無視しています。

コード

require "minitest/autorun"

# Terminator symbol; the tests in this file append it as the final element
# of every input array.
SENTINEL = "$"

# Markers stored in the type array built by make_types.
TYPE_L = "L"
TYPE_S = "S"

# Pool of symbols handed out by get_name: the 26 lowercase ASCII letters,
# in ascending order.
NAME_CHARS = ("a".."z").to_a

# Internal invariant check.
#
# Raises RuntimeError with the message "must not happen" unless +exp+ is
# truthy. nil is rejected as well as false, so a missing value can no longer
# slip through as if the check had passed.
def _assert(exp)
  raise "must not happen" unless exp
end

# Yields each integer i with from <= i <= to, in ascending order.
# Nothing is yielded when from > to.
def _each_up(from, to)
  from.upto(to) do |i|
    yield i
  end
end

# Yields each integer i with from >= i >= to, in descending order.
# Nothing is yielded when from < to.
def _each_down(from, to)
  from.downto(to) do |i|
    yield i
  end
end


# Slot array partitioned into one bucket per distinct symbol of the input.
#
# The backing array has one slot per input element. Buckets are laid out in
# sorted symbol order, and each bucket is as wide as the number of times its
# symbol occurs in the input. An empty slot holds nil.
class Buckets

  # s_cs: array of symbols (one-character strings in this file). The symbols
  # must be usable as Hash keys and sortable.
  def initialize(s_cs)
    @ary = []
    @freq_map = make_freq_map(s_cs)
    @sorted_uniq_chars = @freq_map.keys().sort()

    clear(s_cs.length())
  end

  # Empties slots 0 .. length-1 (sets them to nil).
  def clear(length)
    length.times {|i| @ary[i] = nil }
  end

  # Returns a Hash mapping every distinct element of s_cs to the number of
  # times it occurs.
  def make_freq_map(s_cs)
    freq_map = {}
    s_cs.each {|c| freq_map[c] = freq_map.fetch(c, 0) + 1 }
    freq_map
  end

  # Index of the first slot of the bucket for symbol c, i.e. the total width
  # of all buckets whose symbol sorts before c.
  #
  # Raises ArgumentError when c does not occur in the input; silently
  # returning an offset past the last bucket would only fail later with an
  # unrelated error.
  def first_pos(c)
    unless @freq_map.key?(c)
      raise ArgumentError, "character not present in the input: #{c.inspect}"
    end

    @sorted_uniq_chars
      .take_while {|iter_c| iter_c != c }
      .sum {|iter_c| num_bucket_elements(iter_c) }
  end

  # Width of the bucket for symbol c (nil when c is unknown).
  def num_bucket_elements(c)
    @freq_map[c]
  end

  # Stores si in the first empty slot of c's bucket, searching from the
  # head of the bucket. Raises RuntimeError when the bucket is full.
  def add_l(c, si)
    fill_first_free(bucket_range(c).each, si)
  end

  # Stores si in the first empty slot of c's bucket, searching from the
  # tail of the bucket. Raises RuntimeError when the bucket is full.
  def add_s(c, si)
    fill_first_free(bucket_range(c).reverse_each, si)
  end

  # Total number of slots.
  def num_chars()
    @ary.length()
  end

  # Value stored in slot i (nil when empty).
  def get(i)
    @ary[i]
  end

  # Overwrites slot i with value.
  def set(i, value)
    @ary[i] = value
  end

  # Returns the backing array itself (not a copy).
  def to_array()
    @ary
  end

  private

  # Slot indices covered by the bucket for symbol c.
  def bucket_range(c)
    head = first_pos(c)
    head...(head + num_bucket_elements(c))
  end

  # Linear scan (deliberately simple, slow) over the candidate slot indices;
  # the first empty one receives si.
  def fill_first_free(slots, si)
    i = slots.find {|bi| @ary[bi].nil? }
    raise "must not happen" if i.nil?
    @ary[i] = si
  end
end


# Builds the L/S type array for s_cs.
#
# The last position is always TYPE_S. Going right to left, a position is
# TYPE_S when its symbol is smaller than the symbol to its right, TYPE_L when
# it is larger, and takes the right neighbour's type when both are equal.
# The symbol to the right of position length-2 is taken to be SENTINEL.
#
# Returns an array of TYPE_L / TYPE_S with the same length as s_cs.
def make_types(s_cs)
  last_si = s_cs.length() - 1
  types = []
  types[last_si] = TYPE_S # the sentinel position is S

  right_c = SENTINEL
  right_type = TYPE_S
  (last_si - 1).downto(0) do |si|
    c = s_cs[si]
    type =
      if c < right_c
        TYPE_S
      elsif c > right_c
        TYPE_L
      else
        right_type
      end
    types[si] = type

    right_c = c
    right_type = type
  end

  types
end

# Returns, in ascending order, every position whose left neighbour is TYPE_L
# and which is itself TYPE_S (a leftmost-S position).
#
# The last position is always treated as TYPE_S (the sentinel), whatever
# types actually holds there.
def extract_lms_positions(types)
  last_si = types.length() - 1

  (1..last_si).select do |si|
    own_type = (si == last_si) ? TYPE_S : types[si]
    types[si - 1] == TYPE_L && own_type == TYPE_S
  end
end

# Inserts every position of lms_sis, in the given order, into the bucket of
# the symbol found at that position, using bkts.add_s.
# Returns bkts.
def add_lms_to_bkts(bkts, s_cs, lms_sis)
  lms_sis.each do |si|
    bkts.add_s(s_cs[si], si)
  end

  bkts
end

# Scans the slots of bkts from the first to the last. For every stored
# position si > 0 whose predecessor si - 1 is TYPE_L, the predecessor is
# inserted into the bucket of its own symbol via bkts.add_l.
# Empty slots and position 0 are skipped. Returns bkts.
def induce_l(bkts, s_cs, types)
  0.upto(bkts.num_chars() - 1) do |bi|
    si = bkts.get(bi)
    next if si == nil || si == 0

    prev_si = si - 1
    bkts.add_l(s_cs[prev_si], prev_si) if types[prev_si] == TYPE_L
  end

  bkts
end

# Scans the slots of bkts from the last to the first. For every stored
# position si > 0 whose predecessor si - 1 is TYPE_S, the predecessor is
# inserted into the bucket of its own symbol via bkts.add_s.
# Empty slots and position 0 are skipped. Returns bkts.
def induce_s(bkts, s_cs, types)
  (bkts.num_chars() - 1).downto(0) do |bi|
    si = bkts.get(bi)
    next if si == nil || si == 0

    prev_si = si - 1
    bkts.add_s(s_cs[prev_si], prev_si) if types[prev_si] == TYPE_S
  end

  bkts
end

# True when si is one of the positions listed in lms_sis (linear search).
def is_lms(lms_sis, si)
  lms_sis.any? {|lms_si| lms_si == si }
end

# One round of induced sorting on bkts:
#   1. induce_l,
#   2. empty every slot from index 1 onward that holds a position listed in
#      lms_sis (slot 0, where the sentinel sits, is left untouched),
#   3. induce_s.
# Returns the object returned by induce_s.
def induced_sort(bkts, s_cs, types, lms_sis)
  bkts = induce_l(bkts, s_cs, types)

  # drop the LMS entries other than the sentinel
  1.upto(bkts.num_chars() - 1) do |bi|
    bkts.set(bi, nil) if is_lms(lms_sis, bkts.get(bi))
  end

  induce_s(bkts, s_cs, types)
end

# Compares the substrings of s_cs starting at ai and at bi symbol by symbol.
#
# Returns false at the first offset where the symbols differ. From offset 2
# onward the LMS status of both current positions is also checked: true when
# both are LMS positions (both substrings end here), false when exactly one
# is. The original paper additionally compares the L/S types; this
# implementation does not.
#
# Raises RuntimeError (via _assert) if the offset reaches s_cs.length.
def is_same_substring(s_cs, ai, bi, lms_sis)
  offset = 0

  loop do
    return false if s_cs[ai + offset] != s_cs[bi + offset]

    if offset >= 2
      a_ends = is_lms(lms_sis, ai + offset)
      b_ends = is_lms(lms_sis, bi + offset)
      # Either side hitting an LMS position decides the comparison.
      return (a_ends && b_ends) if a_ends || b_ends
    end

    offset += 1
    _assert(offset < s_cs.length())
  end
end

# Returns the i-th entry of NAME_CHARS.
# Raises RuntimeError when NAME_CHARS has no entry at index i.
def get_name(i)
  NAME_CHARS[i] || raise("names (LMS-substring) is too many")
end

# Assigns a name to every LMS substring and builds the reduced string.
#
# s_cs                - the input symbols
# lms_sis             - LMS positions in text (ascending position) order
# sorted_lms_sis_temp - the same positions ordered by their LMS substring
#
# Walking sorted_lms_sis_temp, the first position receives get_name(0); each
# following position receives the same name as its predecessor when
# is_same_substring reports the two substrings equal, and the next name
# otherwise. Names are therefore dense (no unused name between two
# neighbours) and increase with the sort order.
#
# Returns [names, is_unique]:
#   names     - names[k] is the name of lms_sis[k], i.e. the names are listed
#               in text order. sa_is maps the suffix array of this reduced
#               string back through lms_sis[i], so this order is required.
#   is_unique - false when at least two neighbouring substrings were equal.
#
# Raises KeyError when lms_sis contains a position missing from
# sorted_lms_sis_temp, and RuntimeError (from get_name) when the names run out.
def to_names(s_cs, lms_sis, sorted_lms_sis_temp)
  return [[], true] if sorted_lms_sis_temp.empty?

  is_unique = true
  # name index
  ni = 0
  name_of = { sorted_lms_sis_temp[0] => get_name(ni) }

  sorted_lms_sis_temp.each_cons(2) do |sia, sib|
    if is_same_substring(s_cs, sia, sib, lms_sis)
      is_unique = false
    else
      ni += 1
    end

    name_of[sib] = get_name(ni)
  end

  names = lms_sis.map {|si| name_of.fetch(si) }

  [names, is_unique]
end

# Builds the suffix array of s_cs with (naive) SA-IS.
#
# s_cs: array of symbols whose last element is a sentinel that is smaller
#       than every other symbol and occurs only once (the tests use SENTINEL).
#
# Returns an array of start positions, ordered by the suffix starting there.
# Inputs with fewer than two elements are returned directly ([] or [0]).
def sa_is(s_cs)
  # Nothing to sort; also avoids running the passes below without any LMS
  # position to seed them.
  return (0...s_cs.length()).to_a if s_cs.length() <= 1

  bkts = Buckets.new(s_cs)
  types = make_types(s_cs)
  lms_sis = extract_lms_positions(types)

  # --------------------------------
  # Induced sort, 1st round
  # Purpose: sort the LMS substrings.

  bkts = add_lms_to_bkts(bkts, s_cs, lms_sis)
  bkts = induced_sort(bkts, s_cs, types, lms_sis)

  # The LMS substrings are now sorted
  # (the relative order of equal LMS substrings is still undetermined).

  # --------------------------------
  # Order of the LMS suffixes

  # Pick the LMS positions out of the buckets, smallest substring first.
  sorted_lms_sis_temp = []
  0.upto(bkts.num_chars() - 1) {|bi|
    si = bkts.get(bi)
    sorted_lms_sis_temp << si if is_lms(lms_sis, si)
  }

  names, is_unique = to_names(s_cs, lms_sis, sorted_lms_sis_temp)

  # LMS positions in ascending suffix order.
  sorted_lms_sis =
    if is_unique
      # All LMS substrings differ, so their order is already the suffix order.
      sorted_lms_sis_temp
    else
      # Ties remain: sort the reduced string recursively and map its
      # suffix array back to positions in s_cs.
      sa_is(names).map {|i| lms_sis[i] }
    end

  # --------------------------------
  # Induced sort, 2nd round

  # The 1st round only served to order the LMS substrings,
  # so the buckets can be emptied.
  bkts.clear(s_cs.length())

  # add_s fills each bucket from its tail, so the LMS suffixes have to be
  # inserted largest first to end up in ascending order inside a bucket.
  # This applies to both branches above.
  bkts = add_lms_to_bkts(bkts, s_cs, sorted_lms_sis.reverse())
  bkts = induced_sort(bkts, s_cs, types, sorted_lms_sis)

  bkts.to_array()
end


# Minitest suite for extract_lms_positions, is_same_substring and sa_is.
class SaIsTest < Minitest::Test

  # ---- extract_lms_positions ----

  def test_make_lms_sis_1
    # index: 0 1
    types = %w[L S]
    assert_equal([1], extract_lms_positions(types))
  end

  def test_make_lms_sis_2
    # index: 0 1 2
    types = %w[S L S]
    assert_equal([2], extract_lms_positions(types))
  end

  def test_make_lms_sis_3
    # index: 0 1 2 3
    types = %w[S L S S]
    assert_equal([2], extract_lms_positions(types))
  end

  def test_make_lms_sis_4
    # index: 0 1 2 3 4 5
    types = %w[S L S L L S]
    assert_equal([2, 5], extract_lms_positions(types))
  end

  # ---- is_same_substring ----

  def test_is_same_substring_1
    # index: 0 1 2 3 4 5 6 7
    # type : L S L S L S L S
    t_cs = "bababab$".chars
    lms_sis = [1, 3, 5, 7]
    assert_equal(true, is_same_substring(t_cs, 1, 3, lms_sis))
  end

  def test_is_same_substring_2
    # index: 0 1 2 3 4 5 6 7
    # type : L S L S L S L S
    s_cs = "bacabab$".chars
    lms_sis = [1, 3, 5, 7]
    assert_equal(false, is_same_substring(s_cs, 1, 3, lms_sis))
  end

  # ---- sa_is ----

  def test_sa_is_bbaaddaaddaaccaa
    # index: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
    # type : L L S S L L S S L L S  S  L  L  L  L  S
    s_cs = "bbaaddaaddaaccaa".chars + [SENTINEL]
    expected = [16, 15, 14, 10, 6, 2, 11, 7, 3, 1, 0, 13, 12, 9, 5, 8, 4]

    assert_equal(expected, sa_is(s_cs))
  end

  def test_sa_is_eaefaegaefaag
    # index: 0 1 2 3 4 5 6 7 8 9 10 11 12 13
    # type : L S S L S S L S S L S  S  L  S
    s_cs = "eaefaegaefaag".chars + [SENTINEL]
    expected = [13, 10, 7, 1, 4, 11, 0, 8, 2, 5, 9, 3, 12, 6]

    assert_equal(expected, sa_is(s_cs))
  end

  def test_sa_is_mississippi
    # index: 0 1 2 3 4 5 6 7 8 9 10 11
    # type : L S L L S L L S L L L  S
    s_cs = "mississippi".chars + [SENTINEL]
    expected = [11, 10, 7, 4, 1, 0, 9, 8, 6, 3, 5, 2]

    assert_equal(expected, sa_is(s_cs))
  end

  def test_sa_is_abracadabra
    # index: 0 1 2 3 4 5 6 7 8 9 10 11
    # type : S S L S L S L S S L L  S
    s_cs = "abracadabra".chars + [SENTINEL]
    expected = [11, 10, 7, 0, 3, 5, 8, 1, 4, 6, 9, 2]

    assert_equal(expected, sa_is(s_cs))
  end

end

(追記 2021-02-14) Ruby 3.0.0 向けに minitest まわりを微修正しました。

参考