如何有效地传递功能?

mat*_*ath 16 python algorithm numpy scipy python-2.7

动机

看看下面的图片.

在此输入图像描述

给出的是红色,蓝色和绿色曲线.我想在x轴上的每个点找到主导曲线.这显示为图中的黑色图形.从红色,绿色和蓝色曲线的属性(在一段时间后增加和恒定),这可以归结为在右手边找到主导曲线,然后向左侧移动,找到所有交叉点并更新主导曲线.

这个概述的问题应该解决T一次.这个问题有一个最后的转折点.下一次迭代的蓝色,绿色和红色曲线是通过前一次迭代的主导解决方案加上一些变化的参数构建的.作为上图中的示例:解决方案是黑色功能.此功能用于生成新的蓝色,绿色和红色曲线.然后问题又开始找到这些新曲线的主导者等.

问题简而言之
在每次迭代中,我都从固定的右手边开始,并评估所有三个函数,看看哪个是主导函数.这种评估在迭代中花费的时间越来越长.我的感觉是,我没有最佳地通过旧的主导功能来构建新的蓝色,绿色和红色曲线.原因:我在早期版本中遇到了最大递归深度错误.代码的其他部分,其中当前支配函数的值(其中绿色,红色或蓝色曲线必不可少)在迭代时也需要越来越长的时间.

对于5次迭代,只评估右侧一点上的函数增长:

结果是通过

test = A(5, 120000, 100000) 
Run Code Online (Sandbox Code Playgroud)

然后运行

test.find_all_intersections()

>>> test.find_all_intersections()
iteration 4
to compute function values it took
0.0102479457855
iteration 3
to compute function values it took
0.0134601593018
iteration 2
to compute function values it took
0.0294270515442
iteration 1
to compute function values it took
0.109843969345
iteration 0
to compute function values it took
0.823768854141
Run Code Online (Sandbox Code Playgroud)

我想知道为什么会这样,如果可以更有效地编程它.

详细的代码说明

我很快总结了最重要的功能.完整的代码可以在下面找到.如果对代码有任何其他问题,我非常乐意详细说明/澄清.

  1. 方法u:对于上面生成新一批绿色,红色和蓝色曲线的重复任务,我们需要旧的主导曲线. u是在第一次迭代中使用的初始化.

  2. 方法_function_template:该函数使用不同的参数生成绿色,蓝色和红色曲线的版本.它返回单个输入的函数.

  3. 方法eval:这是每次生成蓝色,绿色和红色版本的核心功能.每次迭代需要三个不同的参数:vfunction这是前一步骤的主导函数m,并且s是影响所得曲线形状的两个参数(flaots).其他参数在每次迭代中都是相同的.在代码存在用于样本值ms对每个迭代.对于更令人讨厌的人:它是近似积分的位置,m并且s是基础正态分布的预期均值和标准差.近似是通过Gauss-Hermite节点/权重完成的.

  4. 方法find_all_intersections:这是在每次迭代中找到主导方法的核心方法.它通过蓝色,绿色和红色曲线的分段连接构建了一个主导功能.这是通过该功能实现的piecewise.

这是完整的代码

import numpy as np
import pandas as pd
from scipy.optimize import brentq
import multiprocessing as mp
import pathos as pt
import timeit
import math
class A(object):
    def u(self, w):
        _w = np.asarray(w).copy()
        _w[_w >= 120000] = 120000
        _p = np.maximum(0, 100000 - _w)
        return _w - 1000*_p**2

    def __init__(self, T, upper_bound, lower_bound):
        self.T = T
        self.upper_bound = upper_bound
        self.lower_bound = lower_bound

    def _function_template(self, *args):
        def _f(x):
            return self.evalv(x, *args)
        return _f

    def evalv(self, w, c, vfunction, g, m, s, gauss_weights, gauss_nodes):
        _A = np.tile(1 + m + math.sqrt(2) * s * gauss_nodes, (np.size(w), 1))
        _W = (_A.T * w).T
        _W = gauss_weights * vfunction(np.ravel(_W)).reshape(np.size(w),
                                                             len(gauss_nodes))
        evalue = g*1/math.sqrt(math.pi)*np.sum(_W, axis=1)
        return c + evalue

    def find_all_intersections(self):

        # the hermite gauss weights and nodes for integration
        # and additional paramters used for function generation

        gauss = np.polynomial.hermite.hermgauss(10)
        gauss_nodes = gauss[0]
        gauss_weights = gauss[1]
        r = np.asarray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1.])
        m = [[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624]]

        s = [[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142]]

        self.solution = []

        n_cpu = mp.cpu_count()
        pool = pt.multiprocessing.ProcessPool(n_cpu)

        # this function is used for multiprocessing
        def call_f(f, x):
            return f(x)

        # this function takes differences for getting cross points
        def _diff(f_dom, f_other):
            def h(x):
                return f_dom(x) - f_other(x)
            return h

        # finds the root of two function
        def find_roots(F, u_bound, l_bound):
                try:
                    sol = brentq(F, a=l_bound,
                                 b=u_bound)
                    if np.absolute(sol - u_bound) > 1:
                        return sol
                    else:
                        return l_bound
                except ValueError:
                    return l_bound

        # piecewise function
        def piecewise(l_comp, l_f):
            def f(x):
                _ind_f = np.digitize(x, l_comp) - 1
                if np.isscalar(x):
                    return l_f[_ind_f](x)
                else:
                    return np.asarray([l_f[_ind_f[i]](x[i])
                                       for i in range(0, len(x))]).ravel()
            return f

        _u = self.u

        for t in range(self.T-1, -1, -1):
            print('iteration' + ' ' + str(t))

            l_bound, u_bound = 0.5*self.lower_bound, self.upper_bound
            l_ordered_functions = []
            l_roots = []
            l_solution = []

            # build all function variations

            l_functions = [self._function_template(0, _u, r[t], m[t][i], s[t][i],
                                                   gauss_weights, gauss_nodes)
                           for i in range(0, len(m[t]))]

            # get the best solution for the upper bound on the very
            # right hand side of wealth interval

            array_functions = np.asarray(l_functions)
            start_time = timeit.default_timer()
            functions_values = pool.map(call_f, array_functions.tolist(),
                                        len(m[t]) * [u_bound])
            elapsed = timeit.default_timer() - start_time
            print('to compute function values it took')
            print(elapsed)

            ind = np.argmax(functions_values)
            cross_points = len(m[t]) * [u_bound]
            l_roots.insert(0, u_bound)
            max_m = m[t][ind]
            l_solution.insert(0, max_m)

            # move from the upper bound twoards the lower bound
            # and find the dominating solution by exploring all cross
            # points.

            test = True

            while test:
                l_ordered_functions.insert(0, array_functions[ind])
                current_max = l_ordered_functions[0]

                l_c_max = len(m[t]) * [current_max]
                l_u_cross = len(m[t]) * [cross_points[ind]]

                # Find new cross points on the smaller interval

                diff = pool.map(_diff, l_c_max, array_functions.tolist())
                cross_points = pool.map(find_roots, diff,
                                        l_u_cross, len(m[t]) * [l_bound])

                # update the solution, cross points and current
                # dominating function.

                ind = np.argmax(cross_points)
                l_roots.insert(0, cross_points[ind])
                max_m = m[t][ind]
                l_solution.insert(0, max_m)

                if cross_points[ind] <= l_bound:
                    test = False

            l_ordered_functions.insert(0, l_functions[0])
            l_roots.insert(0, 0)
            l_roots[-1] = np.inf

            l_comp = l_roots[:]
            l_f = l_ordered_functions[:]

            # build piecewise function which is used for next
            # iteration.

            _u = piecewise(l_comp, l_f)
            _sol = pd.DataFrame(data=l_solution,
                                index=np.asarray(l_roots)[0:-1])
            self.solution.insert(0, _sol)
        return self.solution
Run Code Online (Sandbox Code Playgroud)

Kir*_*rst 3

您的代码确实太复杂了,无法解释您的问题 - 争取更简单的东西。有时您必须编写代码只是为了演示问题。

我正在尝试,纯粹基于您的描述而不是您的代码(尽管我运行了代码并进行了验证)。这是你的问题:

eval方法:这是每次生成蓝、绿、红版本的核心函数。每次迭代需要三个不同的参数:vfunction,它是上一步的主导函数,m 和 s,它们是影响结果曲线形状的两个参数(flaots)。

vfunction每次迭代时您的参数都更加复杂。您正在传递一个在先前迭代中构建的嵌套函数,这会导致递归执行。每次迭代都会增加递归调用的深度。

如何避免这种情况?没有简单或内置的方法。最简单的答案是 - 假设这些函数的输入是一致的 - 存储函数结果(即数字)而不是函数本身。只要您有有限数量的已知输入,您就可以做到这一点。

如果底层函数的输入不一致,那么就没有捷径。您需要反复评估这些底层功能。max我看到您正在对底层函数进行一些分段拼接 - 您可以测试这样做的成本是否超过简单地获取每个底层函数的成本。

我运行的测试(10 次迭代)花了几秒钟。我不认为这是一个问题。