All the side effects were never mentioned to me
I am innocent of uncontrolled abuse

  • 19 Posts
  • 2 Comments
Joined 1 year ago
Cake day: June 17th, 2023










  • Stream fusion does work:

    -- | A stream: a step function together with its initial state. The state
    -- type @s@ is existentially quantified, so a consumer can only advance it
    -- through the step function. (The @forall@ requires the
    -- ExistentialQuantification extension — presumably enabled elsewhere.)
    data Stream a = forall s. Stream !(s -> Step s a) !s
    -- | Outcome of one step: emit an element and continue ('Yield'), continue
    -- without emitting anything ('Skip'), or stop ('Done').
    data Step s a = Yield a !s | Skip !s | Done

    -- | Pair with strict fields, used as a composite stream state so that
    -- both components are evaluated (to WHNF) whenever a state is built.
    data Tup a b = Tup !a !b
    
    -- | Every pairing of an element of the first stream with an element of the
    -- second. Output is grouped by the first component: all pairs for the
    -- first outer element, then all pairs for the second, and so on.
    --
    -- The state is (current outer state, current inner state). The outer step
    -- is re-run from the same state on every inner step; being pure, it yields
    -- the same element each time. When the inner stream finishes, the outer
    -- state advances and the inner state is reset to its starting value.
    cartesianProduct :: Stream a -> Stream b -> Stream (a, b)
    cartesianProduct (Stream outerStep outerStart) (Stream innerStep innerStart) =
      Stream next (Tup outerStart innerStart)
      where
        next (Tup o i) = case outerStep o of
          Done -> Done
          Skip o' -> Skip (Tup o' i)
          Yield x o' -> case innerStep i of
            Done -> Skip (Tup o' innerStart)
            Skip i' -> Skip (Tup o i')
            Yield y i' -> Yield (x, y) (Tup o i')
    
    -- | @eft x y@ streams the integers from @x@ up to @y@ inclusive. The
    -- stream is empty when @x > y@.
    --
    -- The state becomes 'Nothing' once @y@ has been emitted. The loop stops on
    -- that explicit marker rather than on "counter > y". Otherwise
    -- @eft x maxBound@ would never end: no 'Int' is greater than 'maxBound',
    -- and incrementing it wraps around to 'minBound'.
    eft :: Int -> Int -> Stream Int
    eft x y = Stream step start
      where
        start
          | x > y = Nothing
          | otherwise = Just x
        step Nothing = Done
        step (Just s)
          | s >= y = Yield s Nothing
          | otherwise = Yield s (Just $! s + 1)
    
    -- | All pairs @(i, j)@ with @i@ and @j@ each ranging over 0..10 inclusive,
    -- with @i@ varying slowest (121 pairs in total).
    fooS :: Stream (Int, Int)
    fooS = cartesianProduct zeroToTen zeroToTen
      where
        zeroToTen = eft 0 10
    
    -- | Drain a stream into a list, keeping the yielded elements in order.
    -- 'Skip' steps contribute nothing. The list is built lazily: each cons
    -- cell is produced as soon as its element is yielded.
    toList :: Stream a -> [a]
    toList (Stream next start) = loop start
      where
        loop !st = case next st of
          Done -> []
          Skip st' -> loop st'
          Yield v st' -> v : loop st'
    
    -- | The pairs produced by 'fooS', collected into an ordinary list.
    foo :: [(Int,Int)]
    foo = toList fooS
    
    


  • Admittedly, it gets more complicated when summing two things at the same time. Here is the Haskell version, followed by the equivalent C loop:

    -- Compute two means over i = 0 .. cc-1 with the same per-element data:
    --   dnormI  = weight[i] * dout[off + i]
    --   normBti = (inp[off + i] - meanBt) * rstdBt
    --   dnormMean     = (sum of dnormI) / cc
    --   dnormNormMean = (sum of normBti * dnormI) / cc
    -- Each index becomes @Pair normBti dnormI@. Each fold picks its summand
    -- (dimap's first argument), sums, then divides by cc (dimap's second
    -- argument). The two folds are combined with <$>/<*>, presumably so that
    -- the list is traversed once. Assumes 'fold', 'dimap' and 'Pair' come
    -- from a fold library such as Control.Foldl — TODO confirm against the
    -- imports.
    let Pair dnormMean dnormNormMean =
          fold (Pair <$> dimap (\(Pair _ dnormI) -> dnormI) (/ fromIntegral cc) sum
                     <*> dimap (\(Pair normBti dnormI) -> normBti * dnormI) (/ fromIntegral cc) sum)
            $ map (\i -> Pair (((inp ! (off + i)) - meanBt) * rstdBt)
                              ((weight ! i) * (dout ! (off + i))))
              [0 .. cc - 1]
    
    
    // One pass over i in [0, C) that accumulates two sums at once:
    //   dnorm_mean      = sum of dnorm_i
    //   dnorm_norm_mean = sum of dnorm_i * norm_bti
    // dnorm_i is weight[i] * dout_bt[i]. norm_bti is inp_bt[i] shifted by
    // mean_bt and scaled by rstd_bt. Both sums are accumulated in float, in
    // index order; reordering the loop would change the rounding.
    float dnorm_mean = 0.0f;
    float dnorm_norm_mean = 0.0f;
    for (int i = 0; i < C; i++) {
        float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
        float dnorm_i = weight[i] * dout_bt[i];
        dnorm_mean += dnorm_i;
        dnorm_norm_mean += dnorm_i * norm_bti;
    }
    // Divide the sums by C to turn them into means.
    dnorm_mean = dnorm_mean / C;
    dnorm_norm_mean = dnorm_norm_mean / C;