如何在accelerate-haskell中定义矩阵乘积

eps*_*lbe 5 haskell matrix accelerate-haskell

我试图在加速之上定义一个类型安全的矩阵计算库,部分是出于教育目的,部分是为了看看这是否是一种实用的方法。

\n\n

但当涉及到正确定义矩阵的乘积时,我完全陷入困境 - 即以 GHC 接受/编译我的代码的方式。

\n\n

我进行了一些尝试,这些尝试是以下的变体:

\n\n

Linear.hs

\n\n
{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE KindSignatures #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n\nimport qualified Data.Array.Accelerate as A\n\nimport GHC.TypeLits\nimport Data.Array.Accelerate ( (:.)(..), Array\n                             , Exp, Shape, FullShape, Slice\n                             , DIM0, DIM1, DIM2, Z(Z)\n                             , IsFloating, IsNum, Elt, Acc\n                             , Any(Any), All(All))\nimport           Data.Proxy\n\nnewtype Matrix (rows :: Nat) (cols :: Nat) a = AccMatrix {unMatrix :: Acc (Array DIM2 a)}\n(#*#) :: forall k m n a. (KnownNat k, KnownNat m, KnownNat n, IsNum a, Elt a) =>\n    Matrix k m a -> Matrix m n a -> Matrix k n a\n v #*# w = let v\' = unMatrix v\n               w\' = unMatrix w\n           in AccMatrix $ A.generate (A.index2 k\' n\') undefined\n          where k\' = fromInteger $ natVal (Proxy :: Proxy k)\n                n\' = fromInteger $ natVal (Proxy :: Proxy n)\n                aux :: Acc (Array (FullShape (Z :. Int) :. Int) e) -> Acc (Array (FullShape (Z :. All) :. Int) e) -> Exp ((Z :. Int) :. Int) -> Exp e\n                aux v w sh = let (Z:.i:.j) = A.unlift sh\n                                 v\' = A.slice v (A.lift $ Z:.i:.All)\n                                 w\' = A.slice w (A.lift $ Z:.All:.j)\n                              in A.the $ A.sum $ A.zipWith (*) v\' w\'\n
Run Code Online (Sandbox Code Playgroud)\n\n

错误stack build给我的是

\n\n
.../src/Linear.hs:196:55:\n    Couldn\'t match type \xe2\x80\x98A.Plain ((Z :. head0) :. head1)\xe2\x80\x99\n                   with \xe2\x80\x98(Z :. Int) :. Int\xe2\x80\x99\n    The type variables \xe2\x80\x98head0\xe2\x80\x99, \xe2\x80\x98head1\xe2\x80\x99 are ambiguous\n    Expected type: Exp (A.Plain ((Z :. head0) :. head1))\n      Actual type: Exp ((Z :. Int) :. Int)\n    Relevant bindings include\n      i :: head0 (bound at src/Linear.hs:196:38)\n      j :: head1 (bound at src/Linear.hs:196:41)\n    In the first argument of \xe2\x80\x98A.unlift\xe2\x80\x99, namely \xe2\x80\x98sh\xe2\x80\x99\n    In the expression: A.unlift sh\n\n.../src/Linear.hs:197:47:\n    Couldn\'t match type \xe2\x80\x98FullShape (A.Plain (Z :. head0))\xe2\x80\x99\n                   with \xe2\x80\x98Z :. Int\xe2\x80\x99\n    The type variable \xe2\x80\x98head0\xe2\x80\x99 is ambiguous\n    Expected type: Acc\n                     (Array (FullShape (A.Plain (Z :. head0) :. All)) e)\n      Actual type: Acc (Array (FullShape (Z :. Int) :. Int) e)\n    Relevant bindings include\n      v\' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)\n        (bound at src/Linear.hs:197:34)\n      i :: head0 (bound at src/Linear.hs:196:38)\n    In the first argument of \xe2\x80\x98A.slice\xe2\x80\x99, namely \xe2\x80\x98v\xe2\x80\x99\n    In the expression: A.slice v (A.lift $ Z :. i :. All)\n\n.../src/Linear.hs:198:39:\n    Couldn\'t match type \xe2\x80\x98A.SliceShape (A.Plain ((Z :. All) :. head1))\xe2\x80\x99\n                   with \xe2\x80\x98A.SliceShape (A.Plain (Z :. head0)) :. Int\xe2\x80\x99\n    The type variables \xe2\x80\x98head0\xe2\x80\x99, \xe2\x80\x98head1\xe2\x80\x99 are ambiguous\n    Expected type: Acc\n                     (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)\n      Actual type: Acc\n                     (Array (A.SliceShape (A.Plain ((Z :. All) :. head1))) e)\n    Relevant bindings include\n      w\' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)\n        (bound at src/Linear.hs:198:34)\n      v\' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)\n        (bound at src/Linear.hs:197:34)\n      i :: head0 (bound at src/Linear.hs:196:38)\n      j :: head1 (bound at src/Linear.hs:196:41)\n    In the expression: A.slice w (A.lift $ Z :. All :. j)\n    In an equation for \xe2\x80\x98w\'\xe2\x80\x99: w\' = A.slice w (A.lift $ Z :. All :. j)\n\n.../src/Linear.hs:198:47:\n    Couldn\'t match type \xe2\x80\x98FullShape (A.Plain ((Z :. All) :. head1))\xe2\x80\x99\n                   with \xe2\x80\x98(Z :. Int) :. Int\xe2\x80\x99\n    The type variable \xe2\x80\x98head1\xe2\x80\x99 is ambiguous\n    Expected type: Acc\n                     (Array (FullShape (A.Plain ((Z :. All) :. head1))) e)\n      Actual type: Acc (Array (FullShape (Z :. All) :. Int) e)\n    Relevant bindings include\n      j :: head1 (bound at src/Linear.hs:196:41)\n    In the first argument of \xe2\x80\x98A.slice\xe2\x80\x99, namely \xe2\x80\x98w\xe2\x80\x99\n    In the expression: A.slice w (A.lift $ Z :. All :. j)\n
Run Code Online (Sandbox Code Playgroud)\n\n

我查阅了Accelerate的文档,并且还在阅读Accelerate-arithmetic,它具有类似的目标,但不用于TypeLits断言数组/向量维度。

\n\n

我还尝试制作一个普通版本(即没有我自己的矩阵类型),以防我的类型错误,我相信这对slice. 我只是为了完整性而添加此内容,我可以添加错误消息,但我选择忽略它们,因为我相信它们与上述问题无关。

\n\n
(#*#) :: forall a. (IsNum a, Elt a) =>\n    Acc (Array DIM2 a) -> Acc (Array DIM2 a) -> Maybe (Acc (Array DIM2 a))   \nv #*# w = let Z:.k :.m = A.unlift $ A.arrayShape $ I.run v\n              Z:.m\':.n = A.unlift $ A.arrayShape $ I.run w\n           in if m /= m\'\n                 then Nothing\n                 else Just $ AccMatrix $ A.generate (A.index2 k n) (aux v w)\n          where aux :: Acc (Array DIM2 a) -> Acc (Array DIM2 a) -> Exp DIM2 -> Exp a\n                aux v w sh = let (Z:.i:.j) = A.unlift sh\n                                 v\' = A.slice v (A.lift $ Z:.i:.All)\n                                 w\' = A.slice w (A.lift $ Z:.All:.j)\n                              in A.the $ A.sum $ A.zipWith (*) v\' w\'\n
Run Code Online (Sandbox Code Playgroud)\n

use*_*038 3

你的代码实际上是正确的。不幸的是,类型检查器不够聪明,无法弄清楚,所以你必须帮助它:

let (Z:.i:.j) = A.unlift sh
Run Code Online (Sandbox Code Playgroud)

变成

let (Z:.i:.j) = A.unlift sh :: (Z :. Exp Int) :. Exp Int
Run Code Online (Sandbox Code Playgroud)

这里的关键是A.unlift :: A.Unlift c e => c (A.Plain e) -> ebutA.Plain是一个关联的类型族(因此是非单射的),因此e如果没有类型签名就无法确定该类型,并且e需要选择一个实例来用于Unlift c e. 这就是“模糊类型”错误的来源——它确实是e模糊的。


您还有一个不相关的错误。aux应该有类型

aux :: (IsNum e, Elt e) => ...
Run Code Online (Sandbox Code Playgroud)

或者

aux :: (e ~ a) => ... 
Run Code Online (Sandbox Code Playgroud)

在后一种情况下,它a是 的类型签名,(#*#)因此它已经具有约束IsNum, Elt