Clone a function in Julia

Luk*_*rns 4 julia

I want to overwrite a function in Julia using its old definition. It seems the way to do this would be to clone the function and overwrite the original using the copy — something like the following. However, it appears deepcopy(f) just returns a reference to f, so this doesn't work.

f(x) = x
f_old = deepcopy(f)
f(x) = 1 + f_old(x)
Run Code Online (Sandbox Code Playgroud)

How can I clone a function?

Background: I'm interesting in writing a macro @override that allows me to override functions pointwise (or maybe even piecewise).

fib(n::Int) = fib(n-1) + fib(n-2)
@override fib(0) = 1
@override fib(1) = 1
Run Code Online (Sandbox Code Playgroud)

This particular example would be slow and could be made more efficient using @memoize. There may be good reasons not to do this, but there may also be situations in which one does not know a function fully when it is defined and overriding is necessary.

Mas*_*son 6

我们可以使用IRTools.jl做到这一点

using IRTools

fib(n::Int) = fib(n-1) + fib(n-2)

const fib_ir  = IRTools.code_ir(fib, Tuple{Int})
const fib_old = IRTools.func(fib_ir)

fib(n::Int) = n < 2 ? 1 : fib_old(fib, n)

julia> fib(10)
89
Run Code Online (Sandbox Code Playgroud)

我们在那里所做的工作捕获了该函数的中间表示fib,然后将其重建为我们称为的新函数fib_old。然后,我们可以自由地覆盖的定义fib来讲fib_old!请注意,由于fib_old被定义为递归调用fib,而不是递归调用fib_old,因此在调用时没有堆栈溢出fib(10)

注意的另一件事是,当我们打电话时fib_old,我们写的fib_old(fib, n)不是fib_old(n)。这是由于IRTools.func工作原理。

根据Mike Innes在Slack上的说法:

在Julia IR中,所有函数都带有一个表示该函数本身的隐藏额外参数。其原因是闭包是具有字段的结构,您需要在IR中对其进行访问

这是您的@override宏的实现,语法略有不同:

function _get_type_sig(fdef)
    d = splitdef(fdef)
    types = []
    for arg in d[:args]
        if arg isa Symbol
            push!(types, :Any)
        elseif @capture(arg, x_::T_) 
            push!(types, T)
        else
            error("whoops!")
        end
    end
    if isempty(d[:whereparams])
        :(Tuple{$(types...)})
    else
        :((Tuple{$(types...)} where {$(d[:whereparams]...)}).body)
    end
end

macro override(cond, fdef)
    d = splitdef(fdef)
    shadowf = gensym()
    sig = _get_type_sig(fdef)
    f = d[:name]
    quote
        const $shadowf = IRTools.func(IRTools.code_ir($(d[:name]), $sig))
        function $f($(d[:args]...)) where {$(d[:whereparams]...)}
            if $cond
                $(d[:body])
            else
                $shadowf($f, $(d[:args]...))
            end
        end
    end |> esc
end
Run Code Online (Sandbox Code Playgroud)

现在可以输入

fib(n::Int) = fib(n-1) + fib(n-2)
@override n < 2 fib(n::Int) = 1

julia> fib(10)
89
Run Code Online (Sandbox Code Playgroud)

最好的部分是,这几乎就像我们将条件写入原始函数一样快(在运行时,而不是编译时!)!

n = 15

fib2(n::Int) = n < 2 ? 1 : fib2(n-1) + fib2(n-2)

julia> @btime fib($(Ref(15))[])
  4.239 ?s (0 allocations: 0 bytes)
89

julia> @btime fib2($(Ref(15))[])
  3.022 ?s (0 allocations: 0 bytes)
89
Run Code Online (Sandbox Code Playgroud)