Spark:更高效的聚合以连接来自不同行的字符串

Chr*_*ers 8 python apache-spark pyspark

我目前正在研究DNA序列数据,我遇到了一些性能障碍.

我有两个查找字典/哈希(作为RDD),其中DNA"单词"(短序列)作为键,索引位置列表作为值.一个用于较短的查询序列,另一个用于数据库序列.即使是非常非常大的序列,创建表也非常快.

对于下一步,我需要将它们配对并找到"命中"(每个常用词的索引位置对).

我首先加入查找字典,速度相当快.但是,我现在需要这些对,所以我必须进行两次flatmap,一次是从查询中扩展索引列表,第二次是从数据库中扩展索引列表.这不是理想的,但我没有看到另一种方法.至少它表现不错.

此时的输出是:(query_index, (word_length, diagonal_offset)),其中对角线偏移量是database_sequence_index减去查询序列索引.

但是,我现在需要找到具有相同对角线偏移量的索引对(db_index - query_index)并合理地靠近并加入它们(因此我增加了单词的长度),但仅作为对(即一旦我加入一个索引)与另一个,我不希望任何其他东西与它合并).

我使用一个名为Seed()的特殊对象使用aggregateByKey操作.

PARALELLISM = 16 # I have 4 cores with hyperthreading
def generateHsps(query_lookup_table_rdd, database_lookup_table_rdd):
    global broadcastSequences

    def mergeValueOp(seedlist, (query_index, seed_length)):
        return seedlist.addSeed((query_index, seed_length))

    def mergeSeedListsOp(seedlist1, seedlist2):
        return seedlist1.mergeSeedListIntoSelf(seedlist2)

    hits_rdd = (query_lookup_table_rdd.join(database_lookup_table_rdd)
                .flatMap(lambda (word, (query_indices, db_indices)): [(query_index, db_indices) for query_index in query_indices], preservesPartitioning=True)
                .flatMap(lambda (query_index, db_indices): [(db_index - query_index, (query_index, WORD_SIZE)) for db_index in db_indices], preservesPartitioning=True)
                .aggregateByKey(SeedList(), mergeValueOp, mergeSeedListsOp, PARALLELISM)
                .map(lambda (diagonal, seedlist): (diagonal, seedlist.mergedSeedList))
                .flatMap(lambda (diagonal, seedlist): [(query_index, seed_length, diagonal) for query_index, seed_length in seedlist])
                )

    return hits_rdd
Run Code Online (Sandbox Code Playgroud)

种子():

class SeedList():
    def __init__(self):
        self.unmergedSeedList = []
        self.mergedSeedList = []


    #Try to find a more efficient way to do this
    def addSeed(self, (query_index1, seed_length1)):
        for i in range(0, len(self.unmergedSeedList)):
            (query_index2, seed_length2) = self.unmergedSeedList[i]
            #print "Checking ({0}, {1})".format(query_index2, seed_length2)
            if min(abs(query_index2 + seed_length2 - query_index1), abs(query_index1 + seed_length1 - query_index2)) <= WINDOW_SIZE:
                self.mergedSeedList.append((min(query_index1, query_index2), max(query_index1+seed_length1, query_index2+seed_length2)-min(query_index1, query_index2)))
                self.unmergedSeedList.pop(i)
                return self
        self.unmergedSeedList.append((query_index1, seed_length1))
        return self

    def mergeSeedListIntoSelf(self, seedlist2):
        print "merging seed"
        for (query_index2, seed_length2) in seedlist2.unmergedSeedList:
            wasmerged = False
            for i in range(0, len(self.unmergedSeedList)):
                (query_index1, seed_length1) = self.unmergedSeedList[i]
                if min(abs(query_index2 + seed_length2 - query_index1), abs(query_index1 + seed_length1 - query_index2)) <= WINDOW_SIZE:
                    self.mergedSeedList.append((min(query_index1, query_index2), max(query_index1+seed_length1, query_index2+seed_length2)-min(query_index1, query_index2)))
                    self.unmergedSeedList.pop(i)
                    wasmerged = True
                    break
            if not wasmerged:
                self.unmergedSeedList.append((query_index2, seed_length2))
        return self
Run Code Online (Sandbox Code Playgroud)

对于即使是中等长度的序列,这也是性能真正崩溃的地方.

有没有更好的方法来进行这种聚合?我的直觉是肯定的,但我无法想出来.

我知道这是一个非常漫长的技术问题,即使没有简单的解决方案,我也非常感谢任何见解.

编辑:这是我如何制作查找表:

def createLookupTable(sequence_rdd, sequence_name, word_length):
    global broadcastSequences
    blank_list = []

    def addItemToList(lst, val):
        lst.append(val)
        return lst

    def mergeLists(lst1, lst2):
        #print "Merging"
        return lst1+lst2

    return (sequence_rdd
            .flatMap(lambda seq_len: range(0, seq_len - word_length + 1))
            .repartition(PARALLELISM)
            #.partitionBy(PARALLELISM)
            .map(lambda index: (str(broadcastSequences.value[sequence_name][index:index + word_length]), index), preservesPartitioning=True)
            .aggregateByKey(blank_list, addItemToList, mergeLists, PARALLELISM))
            #.map(lambda (word, indices): (word, sorted(indices))))
Run Code Online (Sandbox Code Playgroud)

这是运行整个操作的函数:

def run(query_seq, database_sequence, translate_query=False):
    global broadcastSequences
    scoring_matrix = 'nucleotide' if isinstance(query_seq.alphabet, DNAAlphabet) else 'blosum62'
    sequences = {'query': query_seq,
                 'database': database_sequence}

    broadcastSequences = sc.broadcast(sequences)
    query_rdd = sc.parallelize([len(query_seq)])
    query_rdd.cache()
    database_rdd = sc.parallelize([len(database_sequence)])
    database_rdd.cache()
    query_lookup_table_rdd = createLookupTable(query_rdd, 'query', WORD_SIZE)
    query_lookup_table_rdd.cache()
    database_lookup_table_rdd = createLookupTable(database_rdd, 'database', WORD_SIZE)
    seeds_rdd = generateHsps(query_lookup_table_rdd, database_lookup_table_rdd)
    return seeds_rdd
Run Code Online (Sandbox Code Playgroud)

编辑2:通过更换以下内容,我稍微调整了一些内容并略微提高了性能:

                .flatMap(lambda (word, (query_indices, db_indices)): [(query_index, db_indices) for query_index in query_indices], preservesPartitioning=True)
                .flatMap(lambda (query_index, db_indices): [(db_index - query_index, (query_index, WORD_SIZE)) for db_index in db_indices], preservesPartitioning=True)
Run Code Online (Sandbox Code Playgroud)

在hits_rdd中:

.flatMap(lambda (word, (query_indices, db_indices)): itertools.product(query_indices, db_indices))
                .map(lambda (query_index, db_index): (db_index - query_index, (query_index, WORD_SIZE) ))
Run Code Online (Sandbox Code Playgroud)

至少现在我没有用中间数据结构烧掉存储空间.

max*_*moo 1

让我们忘记您正在做的事情的技术细节,并“从功能上”思考所涉及的步骤,忘记实现的细节。像这样的函数式思维是并行数据分析的重要组成部分;理想情况下,如果我们能像这样分解问题,我们就可以更清楚地推理所涉及的步骤,并最终得到更清晰、往往更简洁的结论。考虑表格数据模型,我认为您的问题包含以下步骤:

  1. 在序列列上连接两个数据集。
  2. delta创建一个包含索引之间差异的新列。
  3. 按(任一)索引排序以确保子序列的顺序正确。
  4. delta对序列列中的字符串进行分组和连接,以获得数据集之间的完全匹配。

对于前 3 个步骤,我认为使用 DataFrame 是有意义的,因为这个数据模型在我的脑海中对我们正在进行的处理是有意义的。(实际上,我也可能在步骤 4 中使用 DataFrame,除了 pyspark 目前不支持 DataFrame 的自定义聚合函数,尽管 Scala 支持)。

对于第四步(如果我正确理解你在问题中真正要问的内容),考虑如何在功能上做到这一点有点棘手,但是我认为一个优雅而有效的解决方案是使用reduce (也称为右折叠);这种模式可以应用于任何可以用迭代应用关联二元函数来表述的问题,即任何 3 个参数的“分组”并不重要的函数(尽管顺序肯定可能很重要),象征性地,这是一个函数,x,y -> f(x,y)其中f(x, f(y, z)) = f(f(x, y), z). 字符串(或更普遍的列表)连接就是这样一个函数。

下面是一个示例,展示了它的外观pyspark;希望您可以根据您的数据详细信息进行调整:

#setup some sample data
query = [['abcd', 30] ,['adab', 34] ,['dbab',38]]
reference = [['dbab', 20], ['ccdd', 24], ['abcd', 50], ['adab',54], ['dbab',58], ['dbab', 62]]

#create data frames
query_df = sqlContext.createDataFrame(query, schema = ['sequence1', 'index1'])
reference_df = sqlContext.createDataFrame(reference, schema = ['sequence2', 'index2'])

#step 1: join
matches = query_df.join(reference_df, query_df.sequence1 == reference_df.sequence2)

#step 2: calculate delta column
matches_delta = matches.withColumn('delta', matches.index2 - matches.index1)

#step 3: sort by index
matches_sorted = matches_delta.sort('delta').sort('index2')

#step 4: convert back to rdd and reduce
#note that + is just string concatenation for strings
r = matches_sorted['delta', 'sequence1'].rdd
r.reduceByKey(lambda x, y : x + y).collect()

#expected output:
#[(24, u'dbab'), (-18, u'dbab'), (20, u'abcdadabdbab')]
Run Code Online (Sandbox Code Playgroud)