Learning String Sorting Algorithms
Scenario
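Suppose we have an array of phone numbers and need to sort it. What is the fastest way to do it?
Normally we would reach for quicksort, so here is one written from scratch: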
```python
import random
from collections import deque

def swap(strs, i, j):
    strs[i], strs[j] = strs[j], strs[i]

def quick_sort(strs):
    # Iterative quicksort: the deque holds pending (start, end) ranges as consecutive pairs
    queue = deque([0, len(strs) - 1])
    while len(queue) > 0:
        start = queue.popleft()
        end = queue.popleft()
        if start >= end:
            continue
        # Use the middle element as the pivot and park it at the end
        pivot_idx = (start + end) // 2
        swap(strs, pivot_idx, end)
        # Position of the last element in the "<= pivot" region
        last_left = start - 1
        for i in range(start, end):
            if strs[i] <= strs[end]:
                last_left += 1
                swap(strs, last_left, i)
        # Move the pivot into its final position
        last_left += 1
        swap(strs, last_left, end)
        # Queue the left and right sub-ranges if they still hold 2+ elements
        if start < last_left - 1:
            queue.append(start)
            queue.append(last_left - 1)
        if last_left + 1 < end:
            queue.append(last_left + 1)
            queue.append(end)

# Ten random 11-digit "phone numbers"
phones = []
for i in range(0, 10):
    phone = ''.join([chr(random.randint(0, 9) + ord('0')) for _ in range(0, 11)])
    phones.append(phone)
print(phones)
quick_sort(phones)
print(phones)
```
['25900803422', '82533056198', '36977030221', '82656350178', '50197010257', '55052712489', '80915203421', '78663941583', '72021821280', '23338950918']
['23338950918', '25900803422', '36977030221', '50197010257', '55052712489', '72021821280', '78663941583', '80915203421', '82533056198', '82656350178']
Quicksort's average time complexity is $O(N\log N)$, and its worst case is $O(N^2)$.
With Python's built-in `sort`, the code looks like this:
```python
import random

phones = []
for i in range(0, 10):
    phone = ''.join([chr(random.randint(0, 9) + ord('0')) for _ in range(0, 11)])
    phones.append(phone)
print(phones)
phones.sort()
print(phones)
```
['49993400585', '90540095748', '93341457109', '42794844210', '72054270540', '13814446283', '79563846731', '26819845915', '14973623007', '31301649154']
['13814446283', '14973623007', '26819845915', '31301649154', '42794844210', '49993400585', '72054270540', '79563846731', '90540095748', '93341457109']
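Python's `sort` is implemented with TimSort, a hybrid of merge sort and insertion sort, with an average time complexity of $O(N\log N)$. It is very fast; its C implementation in CPython is said to run to more than 1,000 lines.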
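TimSort's central idea is to detect runs that are already in order and merge them. The sketch below is a much-simplified natural merge sort written only to illustrate that idea; it is not TimSort, which additionally enforces a minimum run length (extending short runs with binary insertion sort), uses galloping merges, and keeps a stack of pending runs. The helper names `find_runs`, `merge` and `natural_merge_sort` are hypothetical.

```python
def find_runs(a):
    """Split a into maximal non-decreasing runs, as half-open (start, end) ranges."""
    runs = []
    start = 0
    for i in range(1, len(a) + 1):
        # A run ends at the end of the list or where the order breaks
        if i == len(a) or a[i] < a[i - 1]:
            runs.append((start, i))
            start = i
    return runs


def merge(left, right):
    """Stable merge of two sorted lists: on ties, take from left first."""
    out = []
    i = j = 0
    while i < len(left) and j < len(right):
        if right[j] < left[i]:
            out.append(right[j])
            j += 1
        else:
            out.append(left[i])
            i += 1
    out.extend(left[i:])
    out.extend(right[j:])
    return out


def natural_merge_sort(a):
    """Return a new sorted list by repeatedly merging adjacent runs pairwise."""
    if not a:
        return []
    runs = [a[s:e] for s, e in find_runs(a)]
    while len(runs) > 1:
        merged = []
        for k in range(0, len(runs), 2):
            if k + 1 < len(runs):
                merged.append(merge(runs[k], runs[k + 1]))
            else:
                merged.append(runs[k])
        runs = merged
    return runs[0]


phones = ['49993400585', '90540095748', '93341457109', '42794844210']
print(find_runs(phones))           # [(0, 3), (3, 4)]
print(natural_merge_sort(phones))
```

Because `merge` takes from the left run on ties, this sort is stable, a property Python's `list.sort` also guarantees.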
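For this scenario, the book *Algorithms* describes three other string sorting algorithms. The first, LSD string sort, runs in $O(N \cdot W)$ time, where $W$ is the string length. The second, MSD string sort, runs in between $O(N)$ and $O(N \cdot W)$ time. The third, three-way string quicksort, runs in between $O(N)$ and $O(N\log N)$ time.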
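To understand these three algorithms, you first need to understand bucket sort (more precisely, counting sort, which *Algorithms* calls key-indexed counting); LSD and MSD string sort are built on it.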
Bucket Sort
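Bucket sort is a sort for a specific situation. Suppose we know that every element of an array lies in the range 0 to 9. The idea is to use an auxiliary array to count how many times each value occurs, then write the values back into the original array in order. Here is the code: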
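```python
import random

nums = [random.randint(0, 9) for i in range(10)]

def bucket_sort(arr):
    # aux[v] counts how many times value v (0-9) occurs
    aux = [0] * 10
    for v in arr:
        aux[v] += 1
    # Write each value v back aux[v] times, in increasing order
    j = 0
    for v in range(0, len(aux)):
        while aux[v] > 0:
            arr[j] = v
            aux[v] -= 1
            j += 1

print(nums)
bucket_sort(nums)
print(nums)
```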
[0, 1, 0, 9, 3, 2, 9, 2, 4, 5]
[0, 0, 1, 2, 2, 3, 4, 5, 9, 9]
As you can see, bucket sort runs in $O(N + M)$ time, where $M$ is the size of the value range, and needs $O(M)$ extra space.
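The `bucket_sort` above only counts bare integers, so it can simply regenerate the values. To sort records (for example, whole strings) by a small integer key, the counts have to be turned into starting positions with prefix sums, and the records copied in their original order, which makes the sort stable. The sketch below shows this key-indexed counting approach; `counting_sort_by_key` and its `key` and `radix` parameters are names chosen for illustration.

```python
def counting_sort_by_key(items, key, radix=10):
    """Stable sort of items by key(item), an int in range(radix). Returns a new list."""
    # count[k + 1] = number of items whose key is k
    count = [0] * (radix + 1)
    for item in items:
        count[key(item) + 1] += 1
    # Prefix sums: count[k] becomes the first output index for key k
    for k in range(radix):
        count[k + 1] += count[k]
    # Copy items in input order, so equal keys keep their relative order
    out = [None] * len(items)
    for item in items:
        k = key(item)
        out[count[k]] = item
        count[k] += 1
    return out


codes = ['23', '15', '21', '13']
# Sort by the last digit only
print(counting_sort_by_key(codes, key=lambda s: ord(s[-1]) - ord('0')))
# ['21', '23', '13', '15']
```

Here `'23'` stays ahead of `'13'` because both have key 3 and `'23'` came first in the input; that stability is exactly what LSD string sort relies on.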
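Since some kinds of strings have a fixed range of values (a phone number, for example, is 11 characters, each a digit from 0 to 9), could bucket sort make sorting them much faster?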
LSD String Sort
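LSD (Least-Significant-Digit-first) string sort works like this: starting from the lowest (rightmost) position of each string, bucket sort the strings by that character, then move one position to the left and bucket sort again, for $W$ passes in total, where $W$ is the string length. This only works because each pass is stable: strings with the same character at the current position keep the order produced by the previous passes. Note that in our scenario all strings have the same length. The code is as follows: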
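```python
import random

phones = []
for i in range(0, 10):
    phone = ''.join([chr(random.randint(0, 9) + ord('0')) for _ in range(0, 11)])
    phones.append(phone)

def lsd_sort(strs):
    # Assumes all strings have the same length and contain only '0'-'9'
    if not strs:
        return
    # From the lowest (rightmost) position to the highest
    for i in range(len(strs[0]) - 1, -1, -1):
        # Bucket d collects, in input order, the strings whose character at i is digit d
        aux = [[] for _ in range(10)]
        for j in range(0, len(strs)):
            idx = ord(strs[j][i]) - ord('0')
            aux[idx].append(strs[j])
        # Write the buckets back in order; appending keeps each pass stable
        k = 0
        for j in range(0, 10):
            for item in aux[j]:
                strs[k] = item
                k += 1

print(phones)
lsd_sort(phones)
print(phones)
```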
['04145432564', '86748121667', '26710905823', '51612671904', '86609844213', '30603651543', '51116348132', '48212721677', '70985842642', '46470938975']
['04145432564', '26710905823', '30603651543', '46470938975', '48212721677', '51116348132', '51612671904', '70985842642', '86609844213', '86748121667']
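In each of the $W$ passes, this algorithm runs one loop over the $N$ strings and one loop over the 10 buckets, so the total time complexity is $O(W \cdot (N + 10))$, i.e. $O(N \cdot W)$.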
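The `lsd_sort` above builds ten new Python lists on every pass. *Algorithms* implements LSD with key-indexed counting instead: an integer count array plus one auxiliary array of size $N$. Below is a sketch in that spirit (not a transcription of the book's Java code), assuming, like the original, that all strings have the same length and contain only the digits 0 to 9; `lsd_sort_counting` is a hypothetical name.

```python
import random


def lsd_sort_counting(strs, radix=10):
    """Sort equal-length strings of the digits '0'-'9' in place, LSD first.

    Hypothetical variant of lsd_sort that uses key-indexed counting
    (a count array plus one auxiliary list) instead of ten bucket lists.
    """
    if not strs:
        return
    width = len(strs[0])
    aux = [None] * len(strs)
    # One stable key-indexed counting pass per position, right to left
    for d in range(width - 1, -1, -1):
        count = [0] * (radix + 1)
        for s in strs:
            count[ord(s[d]) - ord('0') + 1] += 1
        # Prefix sums: count[r] is now the first output slot for digit r
        for r in range(radix):
            count[r + 1] += count[r]
        # Distribute in input order, which keeps the pass stable
        for s in strs:
            r = ord(s[d]) - ord('0')
            aux[count[r]] = s
            count[r] += 1
        # Copy back so the next pass starts from this pass's order
        strs[:] = aux


phones = [''.join(random.choice('0123456789') for _ in range(11)) for _ in range(1000)]
expected = sorted(phones)
lsd_sort_counting(phones)
print(phones == expected)  # True
```

The amount of work per pass is the same $O(N + 10)$ as before; the difference is that the memory is allocated once up front instead of in ten growing lists per pass.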
MSD String Sort
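MSD (Most-Significant-Digit-first) string sort works like this: starting from the highest (leftmost) position of each string, bucket sort the strings by that character. Many strings may share the same leading character, so each group of strings with the same character is then MSD sorted recursively on the next position. To improve performance, a group below a threshold size (here, 8 or fewer strings) is sorted with insertion sort instead. This way, in most cases only the first few characters need to be examined to finish the whole sort. The code is as follows: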
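```python
import random

phones = []
for i in range(0, 10):
    phone = ''.join([chr(random.randint(0, 9) + ord('0')) for _ in range(0, 11)])
    phones.append(phone)

def insert_sort(strs, low, high, index):
    # Insertion sort on strs[low..high], comparing only from position `index` on
    # (strings in this range already share the prefix before `index`)
    for i in range(low + 1, high + 1):
        temp = strs[i]
        j = i - 1
        # Shift larger strings right; stop at the first one that is not larger
        while j >= low and temp[index:] < strs[j][index:]:
            strs[j + 1] = strs[j]
            j -= 1
        strs[j + 1] = temp

def msd_sort(strs):
    # Start from the highest (leftmost) position
    msd_sort_r(strs, 0, len(strs) - 1, 0)

def msd_sort_r(strs, low, high, index):
    if high <= low or index >= len(strs[0]):
        return
    # Small groups: insertion sort is cheaper than another bucket pass
    if high - low <= 7:
        insert_sort(strs, low, high, index)
        return
    # Bucket strs[low..high] by the character at `index`
    aux = [[] for _ in range(10)]
    for i in range(low, high + 1):
        idx = ord(strs[i][index]) - ord('0')
        aux[idx].append(strs[i])
    # Write the buckets back in order
    k = low
    for bucket in aux:
        for item in bucket:
            strs[k] = item
            k += 1
    # Recursively sort each bucket on the next position
    for bucket in aux:
        msd_sort_r(strs, low, low + len(bucket) - 1, index + 1)
        low += len(bucket)

print(phones)
msd_sort(phones)
print(phones)
```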
['08914475530', '36382806030', '17292598106', '79423140725', '61884265921', '96644424367', '48257134818', '92088637322', '12274980975', '99064046401']
['08914475530', '12274980975', '17292598106', '36382806030', '48257134818', '61884265921', '79423140725', '92088637322', '96644424367', '99064046401']
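`msd_sort` reads `len(strs[0])` and indexes every string at the same position, so it assumes all strings have the same length. *Algorithms* handles strings of different lengths by treating "end of string" as a key smaller than any character. The sketch below applies that idea to MSD, assuming digit-only strings; `msd_sort_var`, `_msd_var` and `char_key` are hypothetical names, and the insertion-sort cutoff for small groups is left out to keep it short.

```python
def char_key(s, d):
    """Bucket index of s at position d: 0 once s has ended, otherwise digit + 1."""
    return 0 if d >= len(s) else ord(s[d]) - ord('0') + 1


def msd_sort_var(strs):
    """Sort digit strings of any length in place (hypothetical variant of msd_sort)."""
    _msd_var(strs, 0, len(strs) - 1, 0)


def _msd_var(strs, low, high, d):
    if high <= low:
        return
    # Bucket 0: strings with no character at position d; buckets 1-10: digits 0-9
    buckets = [[] for _ in range(11)]
    for i in range(low, high + 1):
        buckets[char_key(strs[i], d)].append(strs[i])
    k = low
    for bucket in buckets:
        for item in bucket:
            strs[k] = item
            k += 1
    # Strings in bucket 0 share the same prefix and have all ended, so they are
    # identical and need no further sorting; recurse on the digit buckets only
    start = low + len(buckets[0])
    for bucket in buckets[1:]:
        _msd_var(strs, start, start + len(bucket) - 1, d + 1)
        start += len(bucket)


codes = ['123', '12', '1', '21', '120', '2', '1']
msd_sort_var(codes)
print(codes)  # ['1', '1', '12', '120', '123', '2', '21']
```

Putting the ended strings in bucket 0 is what makes a prefix such as `'12'` sort before its extensions `'120'` and `'123'`, matching Python's own string order.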
Three-Way String Quicksort
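We all know how quicksort works: pick a pivot element, move the smaller elements to its left and the larger ones to its right, then recurse on both sides.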
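Three-way string quicksort works like this: pick a pivot string and compare one character position at a time, starting from the leading character. Strings whose character at that position is smaller than the pivot's go to the left, larger ones go to the right, and those with the same character go in the middle. Then, besides recursing on the left and right groups at the same position, recursively sort the middle group on the next position. The code is as follows: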
```python
import random

phones = []
for i in range(0, 10):
    phone = ''.join([chr(random.randint(0, 9) + ord('0')) for _ in range(0, 11)])
    phones.append(phone)

def quick3_sort(strs):
    # Start from the highest (leftmost) position
    quick3_sort_r(strs, 0, len(strs) - 1, 0)

def quick3_sort_r(strs, low, high, index):
    if high <= low or index >= len(strs[0]):
        return
    # Small ranges: reuse insert_sort from the MSD section
    if high - low <= 7:
        insert_sort(strs, low, high, index)
        return
    # Partition on the character at `index`, with strs[low] as the pivot:
    # strs[low..last_left] < pivot, strs[last_left+1..i-1] == pivot,
    # strs[first_right..high] > pivot, strs[i..first_right-1] not yet examined
    pivot = strs[low][index]
    i = low + 1
    last_left = low - 1
    first_right = high + 1
    while i < first_right:
        if strs[i][index] < pivot:
            last_left += 1
            strs[last_left], strs[i] = strs[i], strs[last_left]
            i += 1
        elif strs[i][index] > pivot:
            first_right -= 1
            strs[first_right], strs[i] = strs[i], strs[first_right]
        else:
            i += 1
    # Left and right groups: same position; middle group: next position
    quick3_sort_r(strs, low, last_left, index)
    quick3_sort_r(strs, last_left + 1, first_right - 1, index + 1)
    quick3_sort_r(strs, first_right, high, index)

print(phones)
quick3_sort(phones)
print(phones)
```
['08104132906', '31709678996', '08273091795', '79987194126', '83552801655', '59365186752', '97600677349', '22396353108', '81547158280', '73500845523']
['08104132906', '08273091795', '22396353108', '31709678996', '59365186752', '73500845523', '79987194126', '81547158280', '83552801655', '97600677349']
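`quick3_sort` also reads `len(strs[0])`, so it assumes equal-length strings. The variant below follows the structure of the three-way string quicksort in *Algorithms*: a position past the end of a string is treated as -1 so that shorter strings sort first, and the middle group is not recursed on once the pivot string has ended. It compares arbitrary characters, not just digits. `quick3_sort_var`, `_quick3_var` and `char_at` are hypothetical names, and there is no insertion-sort cutoff.

```python
def char_at(s, d):
    """Character code of s at position d, or -1 once s has ended (sorts first)."""
    return ord(s[d]) if d < len(s) else -1


def quick3_sort_var(strs):
    """Three-way string quicksort for strings of any length, in place."""
    _quick3_var(strs, 0, len(strs) - 1, 0)


def _quick3_var(strs, low, high, d):
    if high <= low:
        return
    pivot = char_at(strs[low], d)
    lt, i, gt = low, low + 1, high
    # Invariant: strs[low..lt-1] < pivot, strs[lt..i-1] == pivot,
    # strs[gt+1..high] > pivot, strs[i..gt] not yet examined
    while i <= gt:
        c = char_at(strs[i], d)
        if c < pivot:
            strs[lt], strs[i] = strs[i], strs[lt]
            lt += 1
            i += 1
        elif c > pivot:
            strs[i], strs[gt] = strs[gt], strs[i]
            gt -= 1
        else:
            i += 1
    _quick3_var(strs, low, lt - 1, d)
    # If the pivot string has ended, the middle group holds identical strings
    if pivot >= 0:
        _quick3_var(strs, lt, gt, d + 1)
    _quick3_var(strs, gt + 1, high, d)


words = ['she', 'sells', 'seashells', 'by', 'the', 'sea', 'shore', 'the', 'sea']
quick3_sort_var(words)
print(words)
# ['by', 'sea', 'sea', 'seashells', 'sells', 'she', 'shore', 'the', 'the']
```

This input has shared prefixes (`se`, `sh`) and duplicates, the kind of data where grouping equal characters in the middle pays off.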
Let's write a performance test to compare which of these sorting algorithms is fastest.
```python
import random
import timeit

# Assumes lsd_sort, msd_sort, quick3_sort and quick_sort from the sections above
# are defined in this same script
phone_len = 11
N = 10000

def random_phones():
    # N random digit strings of length phone_len (both globals are read at call time)
    phones = []
    for i in range(0, N):
        phone = ''.join([chr(random.randint(0, 9) + ord('0')) for _ in range(0, phone_len)])
        phones.append(phone)
    return phones

def lsd_sort_test():
    lsd_sort(random_phones())

def msd_sort_test():
    msd_sort(random_phones())

def quick3_sort_test():
    quick3_sort(random_phones())

def timsort_test():
    random_phones().sort()

def quick_sort_test():
    quick_sort(random_phones())

print(timeit.timeit('lsd_sort_test()', number=10, setup="from __main__ import lsd_sort_test"))
print(timeit.timeit('msd_sort_test()', number=10, setup="from __main__ import msd_sort_test"))
print(timeit.timeit('quick3_sort_test()', number=10, setup="from __main__ import quick3_sort_test"))
print(timeit.timeit('timsort_test()', number=10, setup="from __main__ import timsort_test"))
print(timeit.timeit('quick_sort_test()', number=10, setup="from __main__ import quick_sort_test"))
```
2.858616407000227
2.4598556860000826
2.801925394014688
2.146542837988818
2.5571339470043313
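The test shows that, for 10,000 random 11-character strings, the built-in TimSort is the fastest, followed by MSD, then quicksort, then three-way string quicksort, with LSD the slowest. That was humbling: I had assumed LSD's $O(N \cdot W)$ would beat $O(N\log N)$, yet it came out slower than quicksort. The likely reason is that with $W = 11$ and $N = 10000$, $\log_2 10000 \approx 13.3$ is only a little larger than 11, and LSD also allocates new bucket lists and writes everything back to the original array on every pass, which costs extra. With shorter strings or more strings, LSD should be faster than quicksort. Three-way string quicksort is slower than plain quicksort because it pays off when strings share long common prefixes or contain many duplicates, and random phone numbers have neither. Two more factors shape these numbers: each timed function also generates its own random input, a fixed cost that narrows the gaps, and `list.sort` runs as compiled C while the other four sorts are pure Python. Now suppose the strings are only 3 characters long: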
```python
# The *_test functions read phone_len and N when called, so changing the globals is enough
phone_len = 3
N = 10000
print(timeit.timeit('lsd_sort_test()', number=10, setup="from __main__ import lsd_sort_test"))
print(timeit.timeit('msd_sort_test()', number=10, setup="from __main__ import msd_sort_test"))
print(timeit.timeit('quick3_sort_test()', number=10, setup="from __main__ import quick3_sort_test"))
print(timeit.timeit('timsort_test()', number=10, setup="from __main__ import timsort_test"))
print(timeit.timeit('quick_sort_test()', number=10, setup="from __main__ import quick_sort_test"))
```
0.8569560649921186
0.8670862840081099
1.2120806089951657
0.7727651970053557
1.326724906975869
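With 3-character strings, LSD (0.86 s for 10 runs) and MSD (0.87 s) now clearly beat quicksort (1.33 s) and three-way string quicksort (1.21 s), but TimSort (0.77 s) still comes out on top. TimSort really is fast; its reputation is deserved, and it is no wonder it is the built-in sort in Python and, for object arrays, in Java.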
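To separate sorting time from data generation, one possible variant of the benchmark is sketched below: it builds the input once, times only the sort on a fresh copy with `time.perf_counter`, and checks each result against `sorted()`. `bench` and `timsort` are hypothetical helper names. The loop picks up `lsd_sort`, `msd_sort`, `quick3_sort` and `quick_sort` only if they are defined in the same script, and no timings are quoted here because they depend on the machine.

```python
import random
import time


def bench(sort_fn, data, repeat=10):
    """Total seconds spent in sort_fn over `repeat` runs, each on a fresh copy of data.

    Copying happens outside the timed region, and every result is checked
    against sorted(data).
    """
    expected = sorted(data)
    total = 0.0
    for _ in range(repeat):
        arr = list(data)
        start = time.perf_counter()
        sort_fn(arr)
        total += time.perf_counter() - start
        assert arr == expected, sort_fn.__name__ + ' produced a wrong order'
    return total


def timsort(arr):
    # Wrap the built-in list.sort so it has the same in-place interface
    arr.sort()


phone_len, N = 11, 10000
# Generate the input once, so every algorithm sorts exactly the same strings
data = [''.join(random.choice('0123456789') for _ in range(phone_len)) for _ in range(N)]

# Benchmark whichever of the article's sort functions exist in this script;
# names that are not defined are skipped
for name in ('lsd_sort', 'msd_sort', 'quick3_sort', 'timsort', 'quick_sort'):
    fn = globals().get(name)
    if fn is not None:
        print(name, bench(fn, data))
```

The correctness check would also have caught the duplicated entries that a missing `break` in an insertion sort produces, which a timing-only benchmark silently ignores.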
References
- *Algorithms*, 4th Edition, by Robert Sedgewick and Kevin Wayne