转换表达式模板树

All*_*lan 6 c++ metaprogramming expression-templates c++11

给定一个表达式模板树,我想在处理它之前创建一个新的优化树.请考虑以下乘法运算示例:

a * b * c * d,
Run Code Online (Sandbox Code Playgroud)

由于operator*表达式树的从左到右的相关性,它产生:

(((a * b) * c) * d).
Run Code Online (Sandbox Code Playgroud)

我想生成一个转换的表达式树,其中乘法从右到左发生:

(a * (b * (c * d))).
Run Code Online (Sandbox Code Playgroud)

考虑二进制表达式类型:

template<typename Left, typename Right>
struct BinaryTimesExpr
{
    BinaryTimesExpr() = default;
    BinaryTimesExpr(const BinaryTimesExpr&) = default;
    BinaryTimesExpr(BinaryTimesExpr&&) = default;
    BinaryTimesExpr(Left&& l, Right&& r) : left(forward<Left>(l)), right(forward<Right>(r)) {}

    BinaryTimesExpr& operator=(const BinaryTimesExpr&) = default;
    BinaryTimesExpr& operator=(BinaryTimesExpr&&) = default;

    Left left;
    Right right;
};
Run Code Online (Sandbox Code Playgroud)

定义乘法运算符operator*:

template<typename Left, typename Right>
BinaryTimesExpr<Constify<Left>, Constify<Right>> operator*(Left&& l, Right&& r)
{
    return {forward<Left>(l), forward<Right>(r)};
}
Run Code Online (Sandbox Code Playgroud)

其中Constify定义如下:

template<typename T> struct HelperConstifyRef     { using type = T;  };
template<typename T> struct HelperConstifyRef<T&> { using type = const T&; };
template<typename T>
using ConstifyRef = typename HelperConstifyRef<T>::type;
Run Code Online (Sandbox Code Playgroud)

并且用于确保从lvalues构造时具有const lvalue-references的子表达式,以及从rvalues构造时rvalues的副本(通过复制/移动).

定义转换函数,该函数使用以前的条件创建新的表达式模板树:

template<typename Expr>
auto Transform(const Expr& expr) -> Expr
{
    return expr;
}

template<typename Left, typename Right>
auto Transform(const BinaryTimesExpr<Left, Right>& expr) -> type(???)
{
    return {(Transform(expr.left), Transform(expr.right))};
}

template<typename Left, typename Right>
auto Transform(const BinaryTimesExpr<BinaryTimesExpr<LeftLeft, LeftRight>, Right>& expr) -> type(???)
{
    return Transform({Transform(expr.left.left), {Transform(expr.left.right), Transform(expr.right)}}); // this sintax is invalid...how can I write this?
}
Run Code Online (Sandbox Code Playgroud)

我的问题是:

1)如何确定Transform函数的返回类型?我尝试过使用类型特征:

template<typename Expr>
struct HelperTransformedExpr
{
    using type = Expr;
};

template<typename Left, typename Right>
struct HelperTransformedExpr<BinaryTimesExpr<Left, Right>>
{
    using type = BinaryTimesExpr<typename HelperTransformedExpr<Left>::type, typename HelperTransformedExpr<Right>::type>;
};

template<typename LeftLeft, typename LeftRight, typename Right>
struct HelperTransformedExpr<BinaryTimesExpr<BinaryTimesExpr<LeftLeft, LeftRight>, Right>>
{
    using type = BinaryTimesExpr<typename HelperTransformedExpr<LeftLeft>::type,
        BinaryTimesExpr<typename HelperTransformedExpr<LeftRight>::type, typename HelperTransformedExpr<Right>::type>>;
};

template<typename Expr>
using TransformedExpr = typename HelperTransformedExpr<Expr>::type;
Run Code Online (Sandbox Code Playgroud)

但不知道如何应用这个来解决我的问题(2).

2)如何编写递归行:

return Transform({Transform(expr.left.left), {Transform(expr.left.right), Transform(expr.right)}});
Run Code Online (Sandbox Code Playgroud)

3)这个问题有更清洁的解决方案吗?


编辑: DyP提出了上述问题的部分解决方案.以下是基于他答案的完整解决方案:

template<typename Expr>
auto Transform(const Expr& expr) -> Expr
{
    return expr;
}

template<typename Left, typename Right>
auto Transform(BinaryTimesExpr<Left, Right> const& expr)
-> decltype(BinaryTimesExpr<decltype(Transform(expr.left)), decltype(Transform(expr.right))>{Transform(expr.left), Transform(expr.right)})
{
    return BinaryTimesExpr<decltype(Transform(expr.left)), decltype(Transform(expr.right))>{Transform(expr.left), Transform(expr.right)};
}

template<typename LeftLeft, typename LeftRight, typename Right>
auto Transform(BinaryTimesExpr<BinaryTimesExpr<LeftLeft, LeftRight>, Right> const& expr)
-> decltype(Transform(BinaryTimesExpr<decltype(Transform(expr.left.left)), BinaryTimesExpr<decltype(Transform(expr.left.right)), decltype(Transform(expr.right))>>{Transform(expr.left.left), {Transform(expr.left.right), Transform(expr.right)}}))
{
    return Transform(BinaryTimesExpr<decltype(Transform(expr.left.left)), BinaryTimesExpr<decltype(Transform(expr.left.right)), decltype(Transform(expr.right))>>{Transform(expr.left.left), {Transform(expr.left.right), Transform(expr.right)}});
}

int main()
{
    BinaryTimesExpr<int, int> beg{1,2};
    auto res = beg*3*4*5*beg;
    std::cout << res << std::endl;
    std::cout << Transform(res) << std::endl;
}
Run Code Online (Sandbox Code Playgroud)

输出:

(((((1*2)*3)*4)*5)*(1*2))
(1*(2*(3*(4*(5*(1*2))))))
Run Code Online (Sandbox Code Playgroud)

请注意,Transform除了最外部Transform调用之外,还必须在每个子表达式上应用该函数(请参阅上一次Transform重载).

完整的源代码可以在这里找到.

dyp*_*dyp 0

不包含完美转发:

#include <iostream>

// simplified by making it an aggregate
template<typename Left, typename Right>
struct BinaryTimesExpr
{
    Left left;
    Right right;
};

// "debug" / demo output
template<typename Left, typename Right>
std::ostream& operator<<(std::ostream& o, BinaryTimesExpr<Left, Right> const& p)
{
    o << "(" << p.left << "*" << p.right << ")";
    return o;
}

// NOTE: removed reference as universal-ref yields a reference type for lvalues!
template<typename Left, typename Right>
BinaryTimesExpr < typename std::remove_reference<Left>::type,
                  typename std::remove_reference<Right>::type >
operator*(Left&& l, Right&& r)
{
    return {std::forward<Left>(l), std::forward<Right>(r)};
}


// overload to end recursion (no-op)
template<typename Expr>
auto Transform(const Expr& expr) -> Expr
{
    return expr;
}

template<typename LeftLeft, typename LeftRight, typename Right>
auto Transform(BinaryTimesExpr < BinaryTimesExpr<LeftLeft, LeftRight>,
                                 Right > const& expr)
-> decltype(Transform(
     BinaryTimesExpr < LeftLeft,
                       BinaryTimesExpr<LeftRight, Right>
                     > {expr.left.left, {expr.left.right, expr.right}}
   ))
{
    return Transform(
        BinaryTimesExpr < LeftLeft,
                          BinaryTimesExpr<LeftRight, Right>
                        > {expr.left.left, {expr.left.right, expr.right}}
    );
}


int main()
{
    BinaryTimesExpr<int, int> beg{1,2};
    auto res = beg*3*4*5*6;
    std::cout << res << std::endl;
    std::cout << Transform(res) << std::endl;
}
Run Code Online (Sandbox Code Playgroud)

输出:

(((((1*2)*3)*4)*5)*6)

(1*(2*(3*(4*(5*6)))))