Haskell data type definition depending on GADTs and function output

Time:03-31

I would like to have a tensor data structure

data Nat where
    Zero :: Nat
    Succ :: Nat -> Nat

-- | A list of type a and of length n
data ListN a (dim :: Nat) where
    Nil  :: ListN a Zero
    Cons :: a -> ListN a n -> ListN a (Succ n)    

data Tensor a where
        Dense :: ListN a n -> ListN Int Nat -> Tensor a

A tensor is represented by a list of elements and a list of integers giving the dimensions of the tensor. For example, [3,4,5,6] in the second ListN would mean the tensor has 4 dimensions that are 3, 4, 5 and 6 elements long respectively. Now I want the n of the first ListN to depend on the product of all integers stored in the second ListN, because that is the number of elements the first ListN can hold. How should I do that?

CodePudding user response:

To do this, you'll need a type-level dimension vector for your Tensor type, not just a ListN Int Nat value, so it's probably better to define Tensor with a dims type parameter. You may also find it more convenient to have the dimensions first and the element type second, so something like:

data ListN (dim :: Nat) a where
    Nil  :: ListN Zero a
    Cons :: a -> ListN n a -> ListN (Succ n) a
infixr 5 `Cons`

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

The missing piece here is Product, a type-level function that multiplies the dimensions. It's a little tedious to multiply Peano naturals, but the following works:

type family Plus m n where
  Plus (Succ m) n = Plus m (Succ n)
  Plus Zero n = n

type family Times m n where
  Times (Succ m) n = Plus n (Times m n)
  Times Zero n = Zero

type family Product (dims) where
  Product '[] = Succ Zero
  Product (m : ns) = Times m (Product ns)
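
As a quick sanity check on the three type families above, one option is to ask GHC to prove that a few applications reduce to the expected Peano numeral. The sketch below repeats the definitions with explicit kind annotations and tick marks. The synonyms One, Two, Three and Six and the names twoTimesThree and productCheck are hypothetical, introduced only for this illustration. Each Refl compiles only if both sides of :~: reduce to the same type.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Type.Equality ((:~:) (Refl))

data Nat where
  Zero :: Nat
  Succ :: Nat -> Nat

-- Same equations as in the answer, with explicit kinds.
type family Plus (m :: Nat) (n :: Nat) :: Nat where
  Plus ('Succ m) n = Plus m ('Succ n)
  Plus 'Zero n = n

type family Times (m :: Nat) (n :: Nat) :: Nat where
  Times ('Succ m) n = Plus n (Times m n)
  Times 'Zero n = 'Zero

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 'Succ 'Zero
  Product (m ': ns) = Times m (Product ns)

-- Hypothetical abbreviations, only to keep the checks readable.
type One = 'Succ 'Zero
type Two = 'Succ One
type Three = 'Succ Two
type Six = 'Succ ('Succ ('Succ Three))

-- Accepted only if Times Two Three reduces to Six.
twoTimesThree :: Times Two Three :~: Six
twoTimesThree = Refl

-- Accepted only if the product of the dimensions 1, 2, 3 reduces to Six.
productCheck :: Product '[One, Two, Three] :~: Six
productCheck = Refl
```

Changing Six to any other numeral in either signature should turn the corresponding Refl into a type error, which is the same mechanism the Dense constructor relies on.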

After that, the following type checks. Note that I've made Cons an infixr operator up above to avoid a lot of parentheses:

t1 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)

If the number of elements is wrong, the constraint fails, so the following does not type check:

t2 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
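
The reason t2 is rejected is that Product dims reduces to six applications of Succ, while the five-element list has a length index of only five, so GHC cannot solve the equality in Dense. Once a tensor has been accepted, functions can pattern match on Dense as usual. The following is a minimal sketch under that assumption. The helpers toList and elements are hypothetical names that are not part of the answer.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

data Nat where
  Zero :: Nat
  Succ :: Nat -> Nat

data ListN (dim :: Nat) a where
  Nil  :: ListN 'Zero a
  Cons :: a -> ListN n a -> ListN ('Succ n) a
infixr 5 `Cons`

type family Plus (m :: Nat) (n :: Nat) :: Nat where
  Plus ('Succ m) n = Plus m ('Succ n)
  Plus 'Zero n = n

type family Times (m :: Nat) (n :: Nat) :: Nat where
  Times ('Succ m) n = Plus n (Times m n)
  Times 'Zero n = 'Zero

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 'Succ 'Zero
  Product (m ': ns) = Times m (Product ns)

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

-- Hypothetical helper: forget the length index and return an ordinary list.
toList :: ListN n a -> [a]
toList Nil = []
toList (x `Cons` xs) = x : toList xs

-- Hypothetical helper: the flat element list stored inside a tensor.
elements :: Tensor dims a -> [a]
elements (Dense xs) = toList xs

-- The 1 x 2 x 3 tensor from the answer.
t1 :: Tensor '[ 'Succ 'Zero, 'Succ ('Succ 'Zero), 'Succ ('Succ ('Succ 'Zero))] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
```

With these definitions, elements t1 evaluates to [1,2,3,4,5,6].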

The full example:

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

data Nat where
    Zero :: Nat
    Succ :: Nat -> Nat

data ListN (dim :: Nat) a where
    Nil  :: ListN Zero a
    Cons :: a -> ListN n a -> ListN (Succ n) a
infixr 5 `Cons`

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

type family Plus m n where
  Plus (Succ m) n = Plus m (Succ n)
  Plus Zero n = n

type family Times m n where
  Times (Succ m) n = Plus n (Times m n)
  Times Zero n = Zero

type family Product (dims) where
  Product '[] = Succ Zero
  Product (m : ns) = Times m (Product ns)

-- type checks
t1 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)

-- won't type check
t2 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)

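As noted in the comments, there is a built-in non-Peano Nat kind, provided by GHC.TypeLits, that you may find easier to work with. It supports numeric literals and the type-level operators + and *; the NoStarIsType extension is needed so that * is read as multiplication rather than as the kind of types. Rewritten to use that, the code would look like this: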

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import GHC.TypeLits

data ListN (dim :: Nat) a where
    Nil  :: ListN 0 a
    Cons :: a -> ListN n a -> ListN (1 + n) a
infixr 5 `Cons`

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

type family Product dims where
  Product '[] = 1
  Product (m : ns) = m * Product ns

-- type checks
t1 :: Tensor '[1,2,3] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)

-- won't type check
t2 :: Tensor '[1,2,3] Int
t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
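
One practical benefit of the GHC.TypeLits version is that the dimensions, which exist only at the type level, can be recovered at runtime through KnownNat and natVal. The sketch below shows one possible way to do that. The class KnownDims, its method dimsVal and the function shape are hypothetical names introduced for this illustration and are not part of the answer or of GHC.TypeLits.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits

data ListN (dim :: Nat) a where
  Nil  :: ListN 0 a
  Cons :: a -> ListN n a -> ListN (1 + n) a
infixr 5 `Cons`

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 1
  Product (m ': ns) = m * Product ns

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

-- Hypothetical class: turn a type-level list of naturals into runtime values.
class KnownDims (dims :: [Nat]) where
  dimsVal :: proxy dims -> [Integer]

instance KnownDims '[] where
  dimsVal _ = []

instance (KnownNat d, KnownDims ds) => KnownDims (d ': ds) where
  dimsVal _ = natVal (Proxy :: Proxy d) : dimsVal (Proxy :: Proxy ds)

-- Hypothetical helper: the dimensions of a tensor, read off its type.
shape :: forall dims a. KnownDims dims => Tensor dims a -> [Integer]
shape _ = dimsVal (Proxy :: Proxy dims)

t1 :: Tensor '[1, 2, 3] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
```

Here shape t1 evaluates to [1,2,3]. The tensor value itself is never inspected, since the dimensions come entirely from the type.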