Sal*_*ali 9 python algorithm performance
我试图在python中实现Burrows-Wheeler变换.(这是在线课程的任务之一,但我希望我做了一些工作才有资格寻求帮助).
该算法的工作原理如下.取一个以特殊字符结尾的字符串(在我的情况下是$)并从该字符串创建所有循环字符串.按字母顺序对所有这些字符串进行排序,使特殊字符始终小于任何其他字符.在此之后获取每个字符串的最后一个元素.
这给了我一个oneliner:
''.join([i[-1] for i in sorted([text[i:] + text[0:i] for i in xrange(len(text))])]
Run Code Online (Sandbox Code Playgroud)
对于相当大的字符串来说哪个是正确且合理的快(这足以解决问题):
60 000 chars - 16 secs
40 000 chars - 07 secs
25 000 chars - 02 secs
Run Code Online (Sandbox Code Playgroud)
但是当我试图用几百万个字符处理一个非常庞大的字符串时,我失败了(处理需要太多时间).
我认为问题在于在内存中存储太多字符串.
有没有办法克服这个问题?
PS只想指出这也可能看起来像是一个家庭作业问题,我的解决方案已经通过了分级机,我只是想找到一种方法来加快速度.此外,我并没有破坏其他人的乐趣,因为如果他们想找到解决方案,维基文章就有一个类似于我的解决方案.我还检查了这个听起来相似的问题,但回答了一个更难的问题,如何解码用这种算法编码的字符串.
Kar*_*tel 10
使用长字符串制作所有字符串切片需要很长时间.它至少是 O(N ^ 2)(因为你创建N个N长度的字符串,并且每个字符串都必须从原始数据库中复制到内存中),这会破坏整体性能并使排序无关紧要.更不用说内存要求了!
而不是实际切片串,在下以为是订购i用于创建循环串,在如何生成的字符串命令值会比较-而不实际创建它.事实证明这有点棘手.(删除/编辑了一些错误的内容;请参阅@TimPeters的回答.)
我在这里采用的方法是绕过标准库 - 这使得"按需"比较这些字符串变得困难(尽管不是不可能) - 并且我自己进行排序.这里算法的自然选择是基数排序,因为我们无论如何都需要一次考虑一个字符串.
我们先设置好.我正在编写版本3.2的代码,所以季节尝试.(特别是在3.3及以上,我们可以利用yield from.)我使用以下导入:
from random import choice
from timeit import timeit
from functools import partial
Run Code Online (Sandbox Code Playgroud)
我写了一个通用的基数排序函数,如下所示:
def radix_sort(values, key, step=0):
if len(values) < 2:
for value in values:
yield value
return
bins = {}
for value in values:
bins.setdefault(key(value, step), []).append(value)
for k in sorted(bins.keys()):
for r in radix_sort(bins[k], key, step + 1):
yield r
Run Code Online (Sandbox Code Playgroud)
当然,我们不需要是通用的(我们的'bins'只能用单个字符标记,并且可能你真的意味着将算法应用于一个字节序列;)),但它不会伤害.还有可重复使用的东西,对吧?无论如何,这个想法很简单:我们处理一个基本情况,然后我们根据key函数的结果将每个元素放入一个"bin",然后我们按照排序的bin顺序从bin中提取值,递归地排序每个bin的内容.
界面要求key(value, n)给我们的n"基数" value.因此对于简单的情况,比如直接比较字符串,这可能很简单lambda v, n: return v[n].但是,这里的想法是根据该点的字符串中的数据(周期性地考虑)将索引与字符串进行比较.所以让我们定义一个关键:
def bw_key(text, value, step):
return text[(value + step) % len(text)]
Run Code Online (Sandbox Code Playgroud)
现在获得正确结果的技巧是记住我们在概念上加入了我们实际上没有创建的字符串的最后一个字符.如果我们考虑使用索引创建的虚拟字符串n,它的最后一个字符是索引n - 1,因为我们如何环绕 - 并且片刻的想法会向你确认这仍然适用于n == 0;).[但是,当我们向前包装时,我们仍然需要保持字符串索引入界 - 因此在键函数中进行模运算.]
这是一个通用的键函数,需要text在转换values进行比较时引用它.这就是functools.partial进入的地方- 你也可能只是乱七八糟lambda,但这可以说是更清洁,而且我发现它通常也更快.
无论如何,现在我们可以使用键轻松编写实际的转换:
def burroughs_wheeler_custom(text):
return ''.join(text[i - 1] for i in radix_sort(range(len(text)), partial(bw_key, text)))
# Notice I've dropped the square brackets; this means I'm passing a generator
# expression to `join` instead of a list comprehension. In general, this is
# a little slower, but uses less memory. And the underlying code uses lazy
# evaluation heavily, so :)
Run Code Online (Sandbox Code Playgroud)
很漂亮.让我们看看它是怎么做的,不是吗?我们需要一个标准来比较它:
def burroughs_wheeler_standard(text):
return ''.join([i[-1] for i in sorted([text[i:] + text[:i] for i in range(len(text))])])
Run Code Online (Sandbox Code Playgroud)
和计时例程:
def test(n):
data = ''.join(choice('abcdefghijklmnopqrstuvwxyz') for i in range(n)) + '$'
custom = partial(burroughs_wheeler_custom, data)
standard = partial(burroughs_wheeler_standard, data)
assert custom() == standard()
trials = 1000000 // n
custom_time = timeit(custom, number=trials)
standard_time = timeit(standard, number=trials)
print("custom: {} standard: {}".format(custom_time, standard_time))
Run Code Online (Sandbox Code Playgroud)
注意我已做过的数学决定了一些数字trials,与test字符串的长度成反比.这应该将用于测试的总时间保持在合理的范围内 - 对吧?;)(当然,错了,因为我们确定standard算法至少是O(N ^ 2).)
让我们看看它是如何做的(*drumroll*):
>>> imp.reload(burroughs_wheeler)
<module 'burroughs_wheeler' from 'burroughs_wheeler.py'>
>>> burroughs_wheeler.test(100)
custom: 4.7095093091438684 standard: 0.9819262643716229
>>> burroughs_wheeler.test(1000)
custom: 5.532266880287807 standard: 2.1733253807396977
>>> burroughs_wheeler.test(10000)
custom: 5.954826800612864 standard: 42.50686064849015
Run Code Online (Sandbox Code Playgroud)
哇,这有点可怕的跳跃.无论如何,正如您所看到的,新方法在短字符串上增加了大量开销,但却使实际排序成为瓶颈而不是字符串切片.:)
只是添加一点@KarlKnechtel的现场响应.
首先,加速循环置换提取的"标准方法"就是将两个副本粘贴在一起并直接索引到其中.后:
N = len(text)
text2 = text * 2
Run Code Online (Sandbox Code Playgroud)
那么从索引i处开始的循环置换就是正确的text2[i: i+N],并且j该置换中的字符就是正确的text2[i+j].无需将两个切片粘贴在一起,也不需要进行模数(%)操作.
其次,内置sort()可以用于此,但是:
作为概念验证,这里是Karl代码部分的替代品(虽然这很适合Python 2):
def burroughs_wheeler_custom(text):
N = len(text)
text2 = text * 2
class K:
def __init__(self, i):
self.i = i
def __lt__(a, b):
i, j = a.i, b.i
for k in xrange(N): # use `range()` in Python 3
if text2[i+k] < text2[j+k]:
return True
elif text2[i+k] > text2[j+k]:
return False
return False # they're equal
inorder = sorted(range(N), key=K)
return "".join(text2[i+N-1] for i in inorder)
Run Code Online (Sandbox Code Playgroud)
需要注意的是内置sort()的执行计算的关键正好一次在其输入的每个元素,并不会保存这些结果排序的持续时间.在这种情况下,结果是K只记得起始索引的惰性小实例,并且其__lt__方法一次比较一个字符对,直到"小于!".或"大于!" 已经解决了.